Updated the contraction code to ensure that a full contraction returns a tensor of rank 0

Benoit Steiner 2016-05-05 08:37:47 -07:00
parent b300a84989
commit 06d774bf58
2 changed files with 11 additions and 24 deletions
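
For context, a minimal usage sketch of the behavior this commit targets, assuming an Eigen build that includes this change: a full contraction of two rank-1 tensors now evaluates into a rank-0 tensor, read with an empty index list, mirroring the updated test_scalar case in the second changed file below.

// Sketch only: assumes Eigen's unsupported Tensor module with this commit applied.
#include <unsupported/Eigen/CXX11/Tensor>
#include <cstdio>

int main() {
  Eigen::Tensor<float, 1> vec1(6);
  Eigen::Tensor<float, 1> vec2(6);
  vec1.setRandom();
  vec2.setRandom();

  // Contract the only dimension of each operand: rank 1 + 1 - 2*1 = 0.
  Eigen::array<Eigen::IndexPair<int>, 1> dims = {{Eigen::IndexPair<int>(0, 0)}};
  Eigen::Tensor<float, 0> result = vec1.contract(vec2, dims);

  // A rank-0 tensor is read with no indices, as in the updated test below.
  float value = result();
  std::printf("dot product = %f\n", value);
  return 0;
}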


@@ -37,7 +37,7 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
   typedef typename remove_reference<RhsNested>::type _RhsNested;
   // From NumDims below.
-  static const int NumDimensions = max_n_1<traits<RhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value>::size;
+  static const int NumDimensions = traits<RhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value;
   static const int Layout = traits<LhsXprType>::Layout;
   enum {
@@ -65,7 +65,7 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_,
   typedef Device_ Device;
   // From NumDims below.
-  static const int NumDimensions = max_n_1<traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value>::size;
+  static const int NumDimensions = traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value;
 };
 } // end namespace internal
@@ -140,7 +140,7 @@ struct TensorContractionEvaluatorBase
   static const int RDims =
       internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
   static const int ContractDims = internal::array_size<Indices>::value;
-  static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size;
+  static const int NumDims = LDims + RDims - 2 * ContractDims;
   typedef array<Index, ContractDims> contract_t;
   typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
@@ -218,11 +218,9 @@ struct TensorContractionEvaluatorBase
       rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i];
     }
-    m_i_strides[0] = 1;
-    m_j_strides[0] = 1;
-    if(ContractDims) {
-      m_k_strides[0] = 1;
-    }
+    if (m_i_strides.size() > 0) m_i_strides[0] = 1;
+    if (m_j_strides.size() > 0) m_j_strides[0] = 1;
+    if (m_k_strides.size() > 0) m_k_strides[0] = 1;
     m_i_size = 1;
     m_j_size = 1;
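
The size() guards above are needed because these fixed-size index arrays can now legitimately be empty: the non-contracting stride arrays have length zero for a full contraction, and the contraction strides have length zero when ContractDims is zero. A hypothetical standalone illustration (not Eigen code):

// Illustration only: a fixed-size index array may have zero length, so
// element 0 must only be written when the array is non-empty.
#include <array>
#include <cstdio>

template <std::size_t N>
void init_first(std::array<int, N>& strides) {
  if (strides.size() > 0) strides[0] = 1;  // skipped entirely when N == 0
}

int main() {
  std::array<int, 3> nonempty{};
  std::array<int, 0> empty{};
  init_first(nonempty);  // writes nonempty[0] = 1
  init_first(empty);     // no write; the guard keeps it in bounds
  std::printf("%d %zu\n", nonempty[0], empty.size());  // prints "1 0"
  return 0;
}
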
@@ -318,11 +316,6 @@
       }
     }
-    // Scalar case. We represent the result as a 1d tensor of size 1.
-    if (LDims + RDims == 2 * ContractDims) {
-      m_dimensions[0] = 1;
-    }
     // If the layout is RowMajor, we need to reverse the m_dimensions
     if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) {
       for (int i = 0, j = NumDims - 1; i < j; i++, j--) {
@@ -607,15 +600,14 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
   static const int ContractDims = internal::array_size<Indices>::value;
   typedef array<Index, ContractDims> contract_t;
-  typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
-  typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t;
+  typedef array<Index, LDims - ContractDims> left_nocontract_t;
+  typedef array<Index, RDims - ContractDims> right_nocontract_t;
-  static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size;
+  static const int NumDims = LDims + RDims - 2 * ContractDims;
   // Could we use NumDimensions here?
   typedef DSizes<Index, NumDims> Dimensions;
   EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
       Base(op, device) { }
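
With the max_n_1 padding removed, the rank arithmetic above is used as-is: a full contraction gives NumDims == 0, so Dimensions becomes a rank-0 DSizes and the removed scalar-case fix-up earlier in this file is no longer needed. A hypothetical standalone check of that arithmetic (not Eigen code):

// Hypothetical compile-time check of the rank formula used in the hunks above.
constexpr int LDims = 1;         // rank of the left operand
constexpr int RDims = 1;         // rank of the right operand
constexpr int ContractDims = 1;  // number of contracted index pairs
static_assert(LDims + RDims - 2 * ContractDims == 0,
              "a full contraction of two rank-1 tensors has rank 0");

int main() { return 0; }

The second changed file updates the test_scalar case in the tensor contraction test to match.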


@@ -87,19 +87,14 @@ static void test_scalar()
   vec1.setRandom();
   vec2.setRandom();
-  Tensor<float, 1, DataLayout> scalar(1);
-  scalar.setZero();
   Eigen::array<DimPair, 1> dims = {{DimPair(0, 0)}};
-  typedef TensorEvaluator<decltype(vec1.contract(vec2, dims)), DefaultDevice> Evaluator;
-  Evaluator eval(vec1.contract(vec2, dims), DefaultDevice());
-  eval.evalTo(scalar.data());
-  EIGEN_STATIC_ASSERT(Evaluator::NumDims==1ul, YOU_MADE_A_PROGRAMMING_MISTAKE);
+  Tensor<float, 0, DataLayout> scalar = vec1.contract(vec2, dims);
   float expected = 0.0f;
   for (int i = 0; i < 6; ++i) {
     expected += vec1(i) * vec2(i);
   }
-  VERIFY_IS_APPROX(scalar(0), expected);
+  VERIFY_IS_APPROX(scalar(), expected);
 }
 template<int DataLayout>
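
For contrast, a hedged sketch (not part of this commit) of a partial contraction, which still yields a tensor of nonzero rank under the same formula:

// Sketch only: one contracted index pair between a rank-2 and a rank-1 tensor
// leaves rank 2 + 1 - 2*1 = 1, i.e. an ordinary matrix-vector product.
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 2> mat(2, 3);
  Eigen::Tensor<float, 1> vec(3);
  mat.setRandom();
  vec.setRandom();

  Eigen::array<Eigen::IndexPair<int>, 1> dims = {{Eigen::IndexPair<int>(1, 0)}};
  Eigen::Tensor<float, 1> result = mat.contract(vec, dims);
  return result.dimension(0) == 2 ? 0 : 1;  // result has 2 entries
}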