Properly fixing the PointerType for TensorCustomOp.h. As the output type here should be based on CoeffreturnType not the Scalar type. Therefore, Similar to reduction and evalTo function, it should have its own MakePointer class. In this case, for other device the type is defaulted to CoeffReturnType and no changes is required on users' code. However, in SYCL, on the device, we can recunstruct the device Type.

This commit is contained in:
Mehdi Goli 2018-08-09 13:57:43 +01:00
parent 3055e3a7c2
commit 8c083bfd0e
2 changed files with 48 additions and 33 deletions

View File

@ -20,8 +20,8 @@ namespace Eigen {
* *
*/ */
namespace internal { namespace internal {
template<typename CustomUnaryFunc, typename XprType> template<typename CustomUnaryFunc, typename XprType, template <class> class MakePointer_>
struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> > struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType, MakePointer_> >
{ {
typedef typename XprType::Scalar Scalar; typedef typename XprType::Scalar Scalar;
typedef typename XprType::StorageKind StorageKind; typedef typename XprType::StorageKind StorageKind;
@ -30,27 +30,35 @@ struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
typedef typename remove_reference<Nested>::type _Nested; typedef typename remove_reference<Nested>::type _Nested;
static const int NumDimensions = traits<XprType>::NumDimensions; static const int NumDimensions = traits<XprType>::NumDimensions;
static const int Layout = traits<XprType>::Layout; static const int Layout = traits<XprType>::Layout;
typedef typename traits<XprType>::PointerType PointerType;
template <class T> struct MakePointer {
// Intermediate typedef to workaround MSVC issue.
typedef MakePointer_<T> MakePointerT;
typedef typename MakePointerT::Type Type;
typedef typename MakePointerT::RefType RefType;
typedef typename MakePointerT::ScalarType ScalarType;
};
typedef typename MakePointer<typename internal::remove_const<typename XprType::CoeffReturnType>::type>::Type PointerType;
}; };
template<typename CustomUnaryFunc, typename XprType> template<typename CustomUnaryFunc, typename XprType, template <class> class MakePointer_>
struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense> struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType, MakePointer_>, Eigen::Dense>
{ {
typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType>& type; typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType, MakePointer_>& type;
}; };
template<typename CustomUnaryFunc, typename XprType> template<typename CustomUnaryFunc, typename XprType, template <class> class MakePointer_>
struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> > struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType, MakePointer_> >
{ {
typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type; typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType, MakePointer_> type;
}; };
} // end namespace internal } // end namespace internal
template<typename CustomUnaryFunc, typename XprType> template<typename CustomUnaryFunc, typename XprType, template <class> class MakePointer_>
class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors> class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType, MakePointer_>, ReadOnlyAccessors>
{ {
public: public:
typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar; typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar;
@ -77,10 +85,10 @@ class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFun
// Eval as rvalue // Eval as rvalue
template<typename CustomUnaryFunc, typename XprType, typename Device> template<typename CustomUnaryFunc, typename XprType, template <class> class MakePointer_, typename Device>
struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Device> struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType, MakePointer_>, Device>
{ {
typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> ArgType; typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType, MakePointer_> ArgType;
typedef typename internal::traits<ArgType>::Index Index; typedef typename internal::traits<ArgType>::Index Index;
static const int NumDims = internal::traits<ArgType>::NumDimensions; static const int NumDims = internal::traits<ArgType>::NumDimensions;
typedef DSizes<Index, NumDims> Dimensions; typedef DSizes<Index, NumDims> Dimensions;
@ -88,7 +96,7 @@ struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Devi
typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType; typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
static const int PacketSize = PacketType<CoeffReturnType, Device>::size; static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
typedef typename internal::remove_all<typename Eigen::internal::traits<XprType>::PointerType>::type * PointerType; typedef typename Eigen::internal::traits<ArgType>::PointerType PointerType;
enum { enum {
IsAligned = false, IsAligned = false,
@ -112,7 +120,7 @@ struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Devi
evalTo(data); evalTo(data);
return false; return false;
} else { } else {
m_result = static_cast<CoeffReturnType*>( m_result = static_cast<PointerType>(
m_device.allocate_temp(dimensions().TotalSize() * sizeof(Scalar))); m_device.allocate_temp(dimensions().TotalSize() * sizeof(Scalar)));
evalTo(m_result); evalTo(m_result);
return true; return true;
@ -168,8 +176,8 @@ struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Devi
* *
*/ */
namespace internal { namespace internal {
template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, template <class> class MakePointer_>
struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> > struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType, MakePointer_> >
{ {
typedef typename internal::promote_storage_type<typename LhsXprType::Scalar, typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
typename RhsXprType::Scalar>::ret Scalar; typename RhsXprType::Scalar>::ret Scalar;
@ -185,28 +193,35 @@ struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
typedef typename remove_reference<RhsNested>::type _RhsNested; typedef typename remove_reference<RhsNested>::type _RhsNested;
static const int NumDimensions = traits<LhsXprType>::NumDimensions; static const int NumDimensions = traits<LhsXprType>::NumDimensions;
static const int Layout = traits<LhsXprType>::Layout; static const int Layout = traits<LhsXprType>::Layout;
typedef typename conditional<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
typename traits<LhsXprType>::PointerType, typename traits<RhsXprType>::PointerType>::type PointerType; template <class T> struct MakePointer {
// Intermediate typedef to workaround MSVC issue.
typedef MakePointer_<T> MakePointerT;
typedef typename MakePointerT::Type Type;
typedef typename MakePointerT::RefType RefType;
typedef typename MakePointerT::ScalarType ScalarType;
};
typedef typename MakePointer<CoeffReturnType>::Type PointerType;
}; };
template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, template <class> class MakePointer_>
struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense> struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType, MakePointer_>, Eigen::Dense>
{ {
typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>& type; typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>& type;
}; };
template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, template <class> class MakePointer_>
struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> > struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType, MakePointer_> >
{ {
typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> type; typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType, MakePointer_> type;
}; };
} // end namespace internal } // end namespace internal
template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType,template <class> class MakePointer_>
class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors> class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType, MakePointer_>, ReadOnlyAccessors>
{ {
public: public:
typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar; typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar;
@ -239,10 +254,10 @@ class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinary
// Eval as rvalue // Eval as rvalue
template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, typename Device> template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, template <class> class MakePointer_, typename Device>
struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Device> struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType, MakePointer_>, Device>
{ {
typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> XprType; typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType, MakePointer_> XprType;
typedef typename internal::traits<XprType>::Index Index; typedef typename internal::traits<XprType>::Index Index;
static const int NumDims = internal::traits<XprType>::NumDimensions; static const int NumDims = internal::traits<XprType>::NumDimensions;
typedef DSizes<Index, NumDims> Dimensions; typedef DSizes<Index, NumDims> Dimensions;
@ -250,7 +265,7 @@ struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType,
typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType; typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
static const int PacketSize = PacketType<CoeffReturnType, Device>::size; static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
typedef typename internal::remove_all<typename Eigen::internal::traits<XprType>::PointerType>::type * PointerType; typedef typename Eigen::internal::traits<XprType>::PointerType PointerType;
enum { enum {
IsAligned = false, IsAligned = false,

View File

@ -89,8 +89,8 @@ template<typename LeftXprType, typename RightXprType> class TensorAssignOp;
template<typename Op, typename XprType> class TensorScanOp; template<typename Op, typename XprType> class TensorScanOp;
template<typename Dims, typename XprType> class TensorTraceOp; template<typename Dims, typename XprType> class TensorTraceOp;
template<typename CustomUnaryFunc, typename XprType> class TensorCustomUnaryOp; template<typename CustomUnaryFunc, typename XprType, template <class> class MakePointer_ = MakePointer> class TensorCustomUnaryOp;
template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> class TensorCustomBinaryOp; template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, template <class> class MakePointer_ = MakePointer> class TensorCustomBinaryOp;
template<typename XprType, template <class> class MakePointer_ = MakePointer> class TensorEvalToOp; template<typename XprType, template <class> class MakePointer_ = MakePointer> class TensorEvalToOp;
template<typename XprType> class TensorForcedEvalOp; template<typename XprType> class TensorForcedEvalOp;