Properly record the rank of reduced tensors in the tensor traits.

This commit is contained in:
Benoit Steiner 2016-01-13 14:24:37 -08:00
parent 79b69b7444
commit 9f013a9d86
2 changed files with 7 additions and 4 deletions

View File

@ -134,7 +134,7 @@ struct traits<TensorTupleReducerOp<ReduceOp, Dims, XprType> > : public traits<Xp
typedef Index Scalar;
typedef typename XprType::Nested Nested;
typedef typename remove_reference<Nested>::type _Nested;
static const int NumDimensions = XprTraits::NumDimensions;
static const int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
static const int Layout = XprTraits::Layout;
};

View File

@ -24,11 +24,14 @@ template<typename Op, typename Dims, typename XprType>
struct traits<TensorReductionOp<Op, Dims, XprType> >
: traits<XprType>
{
typedef typename traits<XprType>::Scalar Scalar;
typedef traits<XprType> XprTraits;
typedef typename XprTraits::Scalar Scalar;
typedef typename internal::packet_traits<Scalar>::type Packet;
typedef typename traits<XprType>::StorageKind StorageKind;
typedef typename traits<XprType>::Index Index;
typedef typename XprTraits::StorageKind StorageKind;
typedef typename XprTraits::Index Index;
typedef typename XprType::Nested Nested;
static const int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
static const int Layout = XprTraits::Layout;
};
template<typename Op, typename Dims, typename XprType>