mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-30 17:40:05 +08:00
Updated the contraction code to ensure that full contraction return a tensor of rank 0
This commit is contained in:
parent
b300a84989
commit
06d774bf58
@ -37,7 +37,7 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
|
||||
typedef typename remove_reference<RhsNested>::type _RhsNested;
|
||||
|
||||
// From NumDims below.
|
||||
static const int NumDimensions = max_n_1<traits<RhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value>::size;
|
||||
static const int NumDimensions = traits<RhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value;
|
||||
static const int Layout = traits<LhsXprType>::Layout;
|
||||
|
||||
enum {
|
||||
@ -65,7 +65,7 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_,
|
||||
typedef Device_ Device;
|
||||
|
||||
// From NumDims below.
|
||||
static const int NumDimensions = max_n_1<traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value>::size;
|
||||
static const int NumDimensions = traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value;
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
@ -140,7 +140,7 @@ struct TensorContractionEvaluatorBase
|
||||
static const int RDims =
|
||||
internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
|
||||
static const int ContractDims = internal::array_size<Indices>::value;
|
||||
static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size;
|
||||
static const int NumDims = LDims + RDims - 2 * ContractDims;
|
||||
|
||||
typedef array<Index, ContractDims> contract_t;
|
||||
typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
|
||||
@ -218,11 +218,9 @@ struct TensorContractionEvaluatorBase
|
||||
rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i];
|
||||
}
|
||||
|
||||
m_i_strides[0] = 1;
|
||||
m_j_strides[0] = 1;
|
||||
if(ContractDims) {
|
||||
m_k_strides[0] = 1;
|
||||
}
|
||||
if (m_i_strides.size() > 0) m_i_strides[0] = 1;
|
||||
if (m_j_strides.size() > 0) m_j_strides[0] = 1;
|
||||
if (m_k_strides.size() > 0) m_k_strides[0] = 1;
|
||||
|
||||
m_i_size = 1;
|
||||
m_j_size = 1;
|
||||
@ -318,11 +316,6 @@ struct TensorContractionEvaluatorBase
|
||||
}
|
||||
}
|
||||
|
||||
// Scalar case. We represent the result as a 1d tensor of size 1.
|
||||
if (LDims + RDims == 2 * ContractDims) {
|
||||
m_dimensions[0] = 1;
|
||||
}
|
||||
|
||||
// If the layout is RowMajor, we need to reverse the m_dimensions
|
||||
if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) {
|
||||
for (int i = 0, j = NumDims - 1; i < j; i++, j--) {
|
||||
@ -607,15 +600,14 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
static const int ContractDims = internal::array_size<Indices>::value;
|
||||
|
||||
typedef array<Index, ContractDims> contract_t;
|
||||
typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
|
||||
typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t;
|
||||
typedef array<Index, LDims - ContractDims> left_nocontract_t;
|
||||
typedef array<Index, RDims - ContractDims> right_nocontract_t;
|
||||
|
||||
static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size;
|
||||
static const int NumDims = LDims + RDims - 2 * ContractDims;
|
||||
|
||||
// Could we use NumDimensions here?
|
||||
typedef DSizes<Index, NumDims> Dimensions;
|
||||
|
||||
|
||||
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
|
||||
Base(op, device) { }
|
||||
|
||||
|
@ -87,19 +87,14 @@ static void test_scalar()
|
||||
vec1.setRandom();
|
||||
vec2.setRandom();
|
||||
|
||||
Tensor<float, 1, DataLayout> scalar(1);
|
||||
scalar.setZero();
|
||||
Eigen::array<DimPair, 1> dims = {{DimPair(0, 0)}};
|
||||
typedef TensorEvaluator<decltype(vec1.contract(vec2, dims)), DefaultDevice> Evaluator;
|
||||
Evaluator eval(vec1.contract(vec2, dims), DefaultDevice());
|
||||
eval.evalTo(scalar.data());
|
||||
EIGEN_STATIC_ASSERT(Evaluator::NumDims==1ul, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
Tensor<float, 0, DataLayout> scalar = vec1.contract(vec2, dims);
|
||||
|
||||
float expected = 0.0f;
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
expected += vec1(i) * vec2(i);
|
||||
}
|
||||
VERIFY_IS_APPROX(scalar(0), expected);
|
||||
VERIFY_IS_APPROX(scalar(), expected);
|
||||
}
|
||||
|
||||
template<int DataLayout>
|
||||
|
Loading…
Reference in New Issue
Block a user