mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-30 17:40:05 +08:00
unsupported/TensorSymmetry: factor out completely from Tensor module
Remove the symCoeff() method of the the Tensor module and move the functionality into a new operator() of the symmetry classes. This makes the Tensor module now completely self-contained without symmetry support (even though previously it was only a forward declaration and a otherwise harmless trivial templated method) and also removes the inconsistency with the rest of eigen w.r.t. the method's naming scheme.
This commit is contained in:
parent
ea99433523
commit
96cb58fa3b
@ -91,9 +91,6 @@ struct tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor>
|
||||
return std_array_get<RowMajor ? 0 : NumIndices - 1>(indices);
|
||||
}
|
||||
};
|
||||
|
||||
/* Forward-declaration required for the symmetry support. */
|
||||
template<typename Tensor_, typename Symmetry_, int Flags = 0> class tensor_symmetry_value_setter;
|
||||
} // end namespace internal
|
||||
|
||||
template<typename Scalar_, std::size_t NumIndices_, int Options_>
|
||||
@ -285,18 +282,6 @@ class Tensor
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename Symmetry_, typename... IndexTypes>
|
||||
internal::tensor_symmetry_value_setter<Self, Symmetry_> symCoeff(const Symmetry_& symmetry, Index firstIndex, IndexTypes... otherIndices)
|
||||
{
|
||||
return symCoeff(symmetry, std::array<Index, NumIndices>{{firstIndex, otherIndices...}});
|
||||
}
|
||||
|
||||
template<typename Symmetry_, typename... IndexTypes>
|
||||
internal::tensor_symmetry_value_setter<Self, Symmetry_> symCoeff(const Symmetry_& symmetry, std::array<Index, NumIndices> const& indices)
|
||||
{
|
||||
return internal::tensor_symmetry_value_setter<Self, Symmetry_>(*this, symmetry, indices);
|
||||
}
|
||||
|
||||
protected:
|
||||
bool checkIndexRange(const std::array<Index, NumIndices>& indices) const
|
||||
{
|
||||
|
@ -50,6 +50,19 @@ class DynamicSGroup
|
||||
|
||||
inline int globalFlags() const { return m_globalFlags; }
|
||||
inline std::size_t size() const { return m_elements.size(); }
|
||||
|
||||
template<typename Tensor_, typename... IndexTypes>
|
||||
inline internal::tensor_symmetry_value_setter<Tensor_, DynamicSGroup> operator()(Tensor_& tensor, typename Tensor_::Index firstIndex, IndexTypes... otherIndices) const
|
||||
{
|
||||
static_assert(sizeof...(otherIndices) + 1 == Tensor_::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
|
||||
return operator()(tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices>{{firstIndex, otherIndices...}});
|
||||
}
|
||||
|
||||
template<typename Tensor_>
|
||||
inline internal::tensor_symmetry_value_setter<Tensor_, DynamicSGroup> operator()(Tensor_& tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices> const& indices) const
|
||||
{
|
||||
return internal::tensor_symmetry_value_setter<Tensor_, DynamicSGroup>(tensor, *this, indices);
|
||||
}
|
||||
private:
|
||||
struct GroupElement {
|
||||
std::vector<int> representation;
|
||||
|
@ -212,6 +212,19 @@ class StaticSGroup
|
||||
return ge::count;
|
||||
}
|
||||
constexpr static inline int globalFlags() { return group_elements::global_flags; }
|
||||
|
||||
template<typename Tensor_, typename... IndexTypes>
|
||||
inline internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>> operator()(Tensor_& tensor, typename Tensor_::Index firstIndex, IndexTypes... otherIndices) const
|
||||
{
|
||||
static_assert(sizeof...(otherIndices) + 1 == Tensor_::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
|
||||
return operator()(tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices>{{firstIndex, otherIndices...}});
|
||||
}
|
||||
|
||||
template<typename Tensor_>
|
||||
inline internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>> operator()(Tensor_& tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices> const& indices) const
|
||||
{
|
||||
return internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>>(tensor, *this, indices);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace Eigen
|
||||
|
@ -293,7 +293,7 @@ struct tensor_symmetry_calculate_flags
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Tensor_, typename Symmetry_, int Flags>
|
||||
template<typename Tensor_, typename Symmetry_, int Flags = 0>
|
||||
class tensor_symmetry_value_setter
|
||||
{
|
||||
public:
|
||||
|
@ -661,7 +661,7 @@ static void test_tensor_epsilon()
|
||||
Tensor<int, 3> epsilon(3,3,3);
|
||||
|
||||
epsilon.setZero();
|
||||
epsilon.symCoeff(sym, 0, 1, 2) = 1;
|
||||
sym(epsilon, 0, 1, 2) = 1;
|
||||
|
||||
for (int i = 0; i < 3; i++) {
|
||||
for (int j = 0; j < 3; j++) {
|
||||
@ -683,7 +683,7 @@ static void test_tensor_sym()
|
||||
for (int k = l; k < 10; k++) {
|
||||
for (int j = 0; j < 10; j++) {
|
||||
for (int i = j; i < 10; i++) {
|
||||
t.symCoeff(sym, i, j, k, l) = (i + j) * (k + l);
|
||||
sym(t, i, j, k, l) = (i + j) * (k + l);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -712,7 +712,7 @@ static void test_tensor_asym()
|
||||
for (int k = l + 1; k < 10; k++) {
|
||||
for (int j = 0; j < 10; j++) {
|
||||
for (int i = j + 1; i < 10; i++) {
|
||||
t.symCoeff(sym, i, j, k, l) = ((i * j) + (k * l));
|
||||
sym(t, i, j, k, l) = ((i * j) + (k * l));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -751,7 +751,7 @@ static void test_tensor_dynsym()
|
||||
for (int k = l; k < 10; k++) {
|
||||
for (int j = 0; j < 10; j++) {
|
||||
for (int i = j; i < 10; i++) {
|
||||
t.symCoeff(sym, i, j, k, l) = (i + j) * (k + l);
|
||||
sym(t, i, j, k, l) = (i + j) * (k + l);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -787,7 +787,7 @@ static void test_tensor_randacc()
|
||||
std::swap(i, j);
|
||||
if (k < l)
|
||||
std::swap(k, l);
|
||||
t.symCoeff(sym, i, j, k, l) = (i + j) * (k + l);
|
||||
sym(t, i, j, k, l) = (i + j) * (k + l);
|
||||
}
|
||||
|
||||
for (int l = 0; l < 10; l++) {
|
||||
|
Loading…
Reference in New Issue
Block a user