unsupported/TensorSymmetry: factor out completely from Tensor module

Remove the symCoeff() method of the the Tensor module and move the
functionality into a new operator() of the symmetry classes. This makes
the Tensor module now completely self-contained without symmetry
support (even though previously it was only a forward declaration and a
otherwise harmless trivial templated method) and also removes the
inconsistency with the rest of eigen w.r.t. the method's naming scheme.
This commit is contained in:
Christian Seiler 2014-06-04 20:44:22 +02:00
parent ea99433523
commit 96cb58fa3b
5 changed files with 32 additions and 21 deletions

View File

@ -91,9 +91,6 @@ struct tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor>
return std_array_get<RowMajor ? 0 : NumIndices - 1>(indices);
}
};
/* Forward-declaration required for the symmetry support. */
template<typename Tensor_, typename Symmetry_, int Flags = 0> class tensor_symmetry_value_setter;
} // end namespace internal
template<typename Scalar_, std::size_t NumIndices_, int Options_>
@ -285,18 +282,6 @@ class Tensor
#endif
}
template<typename Symmetry_, typename... IndexTypes>
internal::tensor_symmetry_value_setter<Self, Symmetry_> symCoeff(const Symmetry_& symmetry, Index firstIndex, IndexTypes... otherIndices)
{
return symCoeff(symmetry, std::array<Index, NumIndices>{{firstIndex, otherIndices...}});
}
template<typename Symmetry_, typename... IndexTypes>
internal::tensor_symmetry_value_setter<Self, Symmetry_> symCoeff(const Symmetry_& symmetry, std::array<Index, NumIndices> const& indices)
{
return internal::tensor_symmetry_value_setter<Self, Symmetry_>(*this, symmetry, indices);
}
protected:
bool checkIndexRange(const std::array<Index, NumIndices>& indices) const
{

View File

@ -50,6 +50,19 @@ class DynamicSGroup
inline int globalFlags() const { return m_globalFlags; }
inline std::size_t size() const { return m_elements.size(); }
template<typename Tensor_, typename... IndexTypes>
inline internal::tensor_symmetry_value_setter<Tensor_, DynamicSGroup> operator()(Tensor_& tensor, typename Tensor_::Index firstIndex, IndexTypes... otherIndices) const
{
static_assert(sizeof...(otherIndices) + 1 == Tensor_::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
return operator()(tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices>{{firstIndex, otherIndices...}});
}
template<typename Tensor_>
inline internal::tensor_symmetry_value_setter<Tensor_, DynamicSGroup> operator()(Tensor_& tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices> const& indices) const
{
return internal::tensor_symmetry_value_setter<Tensor_, DynamicSGroup>(tensor, *this, indices);
}
private:
struct GroupElement {
std::vector<int> representation;

View File

@ -212,6 +212,19 @@ class StaticSGroup
return ge::count;
}
constexpr static inline int globalFlags() { return group_elements::global_flags; }
template<typename Tensor_, typename... IndexTypes>
inline internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>> operator()(Tensor_& tensor, typename Tensor_::Index firstIndex, IndexTypes... otherIndices) const
{
static_assert(sizeof...(otherIndices) + 1 == Tensor_::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
return operator()(tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices>{{firstIndex, otherIndices...}});
}
template<typename Tensor_>
inline internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>> operator()(Tensor_& tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices> const& indices) const
{
return internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>>(tensor, *this, indices);
}
};
} // end namespace Eigen

View File

@ -293,7 +293,7 @@ struct tensor_symmetry_calculate_flags
}
};
template<typename Tensor_, typename Symmetry_, int Flags>
template<typename Tensor_, typename Symmetry_, int Flags = 0>
class tensor_symmetry_value_setter
{
public:

View File

@ -661,7 +661,7 @@ static void test_tensor_epsilon()
Tensor<int, 3> epsilon(3,3,3);
epsilon.setZero();
epsilon.symCoeff(sym, 0, 1, 2) = 1;
sym(epsilon, 0, 1, 2) = 1;
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
@ -683,7 +683,7 @@ static void test_tensor_sym()
for (int k = l; k < 10; k++) {
for (int j = 0; j < 10; j++) {
for (int i = j; i < 10; i++) {
t.symCoeff(sym, i, j, k, l) = (i + j) * (k + l);
sym(t, i, j, k, l) = (i + j) * (k + l);
}
}
}
@ -712,7 +712,7 @@ static void test_tensor_asym()
for (int k = l + 1; k < 10; k++) {
for (int j = 0; j < 10; j++) {
for (int i = j + 1; i < 10; i++) {
t.symCoeff(sym, i, j, k, l) = ((i * j) + (k * l));
sym(t, i, j, k, l) = ((i * j) + (k * l));
}
}
}
@ -751,7 +751,7 @@ static void test_tensor_dynsym()
for (int k = l; k < 10; k++) {
for (int j = 0; j < 10; j++) {
for (int i = j; i < 10; i++) {
t.symCoeff(sym, i, j, k, l) = (i + j) * (k + l);
sym(t, i, j, k, l) = (i + j) * (k + l);
}
}
}
@ -787,7 +787,7 @@ static void test_tensor_randacc()
std::swap(i, j);
if (k < l)
std::swap(k, l);
t.symCoeff(sym, i, j, k, l) = (i + j) * (k + l);
sym(t, i, j, k, l) = (i + j) * (k + l);
}
for (int l = 0; l < 10; l++) {