Fixing Argmax that was breaking upstream TensorFlow.

This commit is contained in:
Benoit Steiner 2017-07-22 03:19:34 +00:00
parent f0b154a4b0
commit 84d7be103a

View File

@ -40,7 +40,7 @@ struct traits<TensorTupleReducerDeviceOp<StrideDims, XprType> > : public traits<
typedef traits<XprType> XprTraits;
typedef typename XprTraits::StorageKind StorageKind;
typedef typename XprTraits::Index Index;
typedef typename XprType::Scalar Scalar;
typedef Index Scalar;
typedef typename XprType::Nested Nested;
typedef typename remove_reference<Nested>::type _Nested;
static const int NumDimensions = XprTraits::NumDimensions;
@ -58,7 +58,8 @@ class TensorTupleReducerDeviceOp : public TensorBase<TensorTupleReducerDeviceOp<
typedef typename Eigen::internal::nested<TensorTupleReducerDeviceOp>::type Nested;
typedef typename Eigen::internal::traits<TensorTupleReducerDeviceOp>::StorageKind StorageKind;
typedef typename Eigen::internal::traits<TensorTupleReducerDeviceOp>::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::CoeffReturnType TupleType;
typedef Index CoeffReturnType;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTupleReducerDeviceOp(XprType expr,
const Index return_dim,
@ -99,9 +100,9 @@ struct TensorEvaluator<const TensorTupleReducerDeviceOp<StrideDims, ArgType>, Sy
{
typedef TensorTupleReducerDeviceOp<StrideDims, ArgType> XprType;
typedef typename XprType::Index Index;
typedef typename XprType::Index Scalar;
typedef Index CoeffReturnType;
typedef typename XprType::CoeffReturnType TupleType;
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::TupleType TupleType;
typedef typename TensorEvaluator<ArgType, SyclKernelDevice>::Dimensions Dimensions;
enum {