mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-13 18:37:27 +08:00
Allowed tensor contraction operation with an empty array of dimension pairs, which performs a tensor product.
This commit is contained in:
parent
4b3052c54d
commit
2195822df6
@ -66,7 +66,7 @@ class BaseTensorContractionMapper {
|
||||
const bool left = (side == Lhs);
|
||||
Index nocontract_val = left ? row : col;
|
||||
Index linidx = 0;
|
||||
for (int i = array_size<nocontract_t>::value - 1; i > 0; i--) {
|
||||
for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
|
||||
const Index idx = nocontract_val / m_ij_strides[i];
|
||||
linidx += idx * m_nocontract_strides[i];
|
||||
nocontract_val -= idx * m_ij_strides[i];
|
||||
@ -81,17 +81,19 @@ class BaseTensorContractionMapper {
|
||||
}
|
||||
|
||||
Index contract_val = left ? col : row;
|
||||
for (int i = array_size<contract_t>::value - 1; i > 0; i--) {
|
||||
for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
|
||||
const Index idx = contract_val / m_k_strides[i];
|
||||
linidx += idx * m_contract_strides[i];
|
||||
contract_val -= idx * m_k_strides[i];
|
||||
}
|
||||
EIGEN_STATIC_ASSERT(array_size<contract_t>::value > 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
if (side == Rhs && inner_dim_contiguous) {
|
||||
eigen_assert(m_contract_strides[0] == 1);
|
||||
linidx += contract_val;
|
||||
} else {
|
||||
linidx += contract_val * m_contract_strides[0];
|
||||
|
||||
if(array_size<contract_t>::value > 0) {
|
||||
if (side == Rhs && inner_dim_contiguous) {
|
||||
eigen_assert(m_contract_strides[0] == 1);
|
||||
linidx += contract_val;
|
||||
} else {
|
||||
linidx += contract_val * m_contract_strides[0];
|
||||
}
|
||||
}
|
||||
|
||||
return linidx;
|
||||
@ -102,7 +104,7 @@ class BaseTensorContractionMapper {
|
||||
const bool left = (side == Lhs);
|
||||
Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
|
||||
Index linidx[2] = {0, 0};
|
||||
for (int i = array_size<nocontract_t>::value - 1; i > 0; i--) {
|
||||
for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
|
||||
const Index idx0 = nocontract_val[0] / m_ij_strides[i];
|
||||
const Index idx1 = nocontract_val[1] / m_ij_strides[i];
|
||||
linidx[0] += idx0 * m_nocontract_strides[i];
|
||||
@ -122,7 +124,7 @@ class BaseTensorContractionMapper {
|
||||
}
|
||||
|
||||
Index contract_val[2] = {left ? col : row, left ? col : row + distance};
|
||||
for (int i = array_size<contract_t>::value - 1; i > 0; i--) {
|
||||
for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
|
||||
const Index idx0 = contract_val[0] / m_k_strides[i];
|
||||
const Index idx1 = contract_val[1] / m_k_strides[i];
|
||||
linidx[0] += idx0 * m_contract_strides[i];
|
||||
@ -130,7 +132,7 @@ class BaseTensorContractionMapper {
|
||||
contract_val[0] -= idx0 * m_k_strides[i];
|
||||
contract_val[1] -= idx1 * m_k_strides[i];
|
||||
}
|
||||
EIGEN_STATIC_ASSERT(array_size<contract_t>::value > 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
|
||||
if (side == Rhs && inner_dim_contiguous) {
|
||||
eigen_assert(m_contract_strides[0] == 1);
|
||||
linidx[0] += contract_val[0];
|
||||
@ -509,8 +511,6 @@ struct TensorContractionEvaluatorBase
|
||||
static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
|
||||
YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
|
||||
eigen_assert((internal::array_size<contract_t>::value > 0) && "Must contract on some indices");
|
||||
|
||||
|
||||
DSizes<Index, LDims> eval_left_dims;
|
||||
DSizes<Index, RDims> eval_right_dims;
|
||||
@ -558,7 +558,9 @@ struct TensorContractionEvaluatorBase
|
||||
|
||||
m_i_strides[0] = 1;
|
||||
m_j_strides[0] = 1;
|
||||
m_k_strides[0] = 1;
|
||||
if(ContractDims) {
|
||||
m_k_strides[0] = 1;
|
||||
}
|
||||
|
||||
m_i_size = 1;
|
||||
m_j_size = 1;
|
||||
|
@ -448,6 +448,31 @@ static void test_small_blocking_factors()
|
||||
}
|
||||
}
|
||||
|
||||
template<int DataLayout>
|
||||
static void test_tensor_product()
|
||||
{
|
||||
Tensor<float, 2, DataLayout> mat1(2, 3);
|
||||
Tensor<float, 2, DataLayout> mat2(4, 1);
|
||||
mat1.setRandom();
|
||||
mat2.setRandom();
|
||||
|
||||
Tensor<float, 4, DataLayout> result = mat1.contract(mat2, Eigen::array<DimPair, 0>{{}});
|
||||
|
||||
VERIFY_IS_EQUAL(result.dimension(0), 2);
|
||||
VERIFY_IS_EQUAL(result.dimension(1), 3);
|
||||
VERIFY_IS_EQUAL(result.dimension(2), 4);
|
||||
VERIFY_IS_EQUAL(result.dimension(3), 1);
|
||||
for (int i = 0; i < result.dimension(0); ++i) {
|
||||
for (int j = 0; j < result.dimension(1); ++j) {
|
||||
for (int k = 0; k < result.dimension(2); ++k) {
|
||||
for (int l = 0; l < result.dimension(3); ++l) {
|
||||
VERIFY_IS_APPROX(result(i, j, k, l), mat1(i, j) * mat2(k, l) );
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void test_cxx11_tensor_contraction()
|
||||
{
|
||||
@ -477,4 +502,6 @@ void test_cxx11_tensor_contraction()
|
||||
CALL_SUBTEST(test_tensor_vector<RowMajor>());
|
||||
CALL_SUBTEST(test_small_blocking_factors<ColMajor>());
|
||||
CALL_SUBTEST(test_small_blocking_factors<RowMajor>());
|
||||
CALL_SUBTEST(test_tensor_product<ColMajor>());
|
||||
CALL_SUBTEST(test_tensor_product<RowMajor>());
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user