Make Transpositions use evaluators

This commit is contained in:
Gael Guennebaud 2015-06-19 11:50:24 +02:00
parent 82b6ac0864
commit 3af4c6c1c9
4 changed files with 128 additions and 76 deletions

View File

@ -537,9 +537,6 @@ class PermutationWrapper : public PermutationBase<PermutationWrapper<_IndicesTyp
};
// TODO: Do we need to define these operator* functions? Would it be better to have them inherited
// from MatrixBase?
/** \returns the matrix with the permutation applied to the columns.
*/
template<typename MatrixDerived, typename PermutationDerived>

View File

@ -929,6 +929,79 @@ struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, PermutationShape,
}
};
/***************************************************************************
* Products with transpositions matrices
***************************************************************************/
// FIXME could we unify Transpositions and Permutation into a single "shape"??
/** \internal
* \class transposition_matrix_product
* Internal helper class implementing the product between a permutation matrix and a matrix.
*/
template<typename MatrixType, int Side, bool Transposed, typename MatrixShape>
struct transposition_matrix_product
{
template<typename Dest, typename TranspositionType>
static inline void evalTo(Dest& dst, const TranspositionType& tr, const MatrixType& mat)
{
typedef typename TranspositionType::StorageIndex StorageIndex;
const Index size = tr.size();
StorageIndex j = 0;
if(!(is_same<MatrixType,Dest>::value && extract_data(dst) == extract_data(mat)))
dst = mat;
for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k)
if(Index(j=tr.coeff(k))!=k)
{
if(Side==OnTheLeft) dst.row(k).swap(dst.row(j));
else if(Side==OnTheRight) dst.col(k).swap(dst.col(j));
}
}
};
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
struct generic_product_impl<Lhs, Rhs, TranspositionsShape, MatrixShape, ProductTag>
{
template<typename Dest>
static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
{
transposition_matrix_product<Rhs, OnTheLeft, false, MatrixShape>::run(dst, lhs, rhs);
}
};
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
struct generic_product_impl<Lhs, Rhs, MatrixShape, TranspositionsShape, ProductTag>
{
template<typename Dest>
static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
{
transposition_matrix_product<Lhs, OnTheRight, false, MatrixShape>::run(dst, rhs, lhs);
}
};
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
struct generic_product_impl<Transpose<Lhs>, Rhs, TranspositionsShape, MatrixShape, ProductTag>
{
template<typename Dest>
static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs)
{
transposition_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedPermutation(), rhs);
}
};
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, TranspositionsShape, ProductTag>
{
template<typename Dest>
static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs)
{
transposition_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedPermutation(), lhs);
}
};
} // end namespace internal
} // end namespace Eigen

View File

@ -41,10 +41,6 @@ namespace Eigen {
* \sa class PermutationMatrix
*/
namespace internal {
template<typename TranspositionType, typename MatrixType, int Side, bool Transposed=false> struct transposition_matrix_product_retval;
}
template<typename Derived>
class TranspositionsBase
{
@ -325,77 +321,32 @@ class TranspositionsWrapper
const typename IndicesType::Nested m_indices;
};
/** \returns the \a matrix with the \a transpositions applied to the columns.
*/
template<typename Derived, typename TranspositionsDerived>
inline const internal::transposition_matrix_product_retval<TranspositionsDerived, Derived, OnTheRight>
operator*(const MatrixBase<Derived>& matrix,
const TranspositionsBase<TranspositionsDerived> &transpositions)
template<typename MatrixDerived, typename TranspositionsDerived>
EIGEN_DEVICE_FUNC
const Product<MatrixDerived, TranspositionsDerived, DefaultProduct>
operator*(const MatrixBase<MatrixDerived> &matrix,
const TranspositionsBase<TranspositionsDerived>& transpositions)
{
return internal::transposition_matrix_product_retval
<TranspositionsDerived, Derived, OnTheRight>
(transpositions.derived(), matrix.derived());
return Product<MatrixDerived, TranspositionsDerived, DefaultProduct>
(matrix.derived(), transpositions.derived());
}
/** \returns the \a matrix with the \a transpositions applied to the rows.
*/
template<typename Derived, typename TranspositionDerived>
inline const internal::transposition_matrix_product_retval
<TranspositionDerived, Derived, OnTheLeft>
operator*(const TranspositionsBase<TranspositionDerived> &transpositions,
const MatrixBase<Derived>& matrix)
template<typename TranspositionsDerived, typename MatrixDerived>
EIGEN_DEVICE_FUNC
const Product<TranspositionsDerived, MatrixDerived, DefaultProduct>
operator*(const TranspositionsBase<TranspositionsDerived> &transpositions,
const MatrixBase<MatrixDerived>& matrix)
{
return internal::transposition_matrix_product_retval
<TranspositionDerived, Derived, OnTheLeft>
(transpositions.derived(), matrix.derived());
return Product<TranspositionsDerived, MatrixDerived, DefaultProduct>
(transpositions.derived(), matrix.derived());
}
namespace internal {
template<typename TranspositionType, typename MatrixType, int Side, bool Transposed>
struct traits<transposition_matrix_product_retval<TranspositionType, MatrixType, Side, Transposed> >
{
typedef typename MatrixType::PlainObject ReturnType;
};
template<typename TranspositionType, typename MatrixType, int Side, bool Transposed>
struct transposition_matrix_product_retval
: public ReturnByValue<transposition_matrix_product_retval<TranspositionType, MatrixType, Side, Transposed> >
{
typedef typename remove_all<typename MatrixType::Nested>::type MatrixTypeNestedCleaned;
typedef typename TranspositionType::StorageIndex StorageIndex;
transposition_matrix_product_retval(const TranspositionType& tr, const MatrixType& matrix)
: m_transpositions(tr), m_matrix(matrix)
{}
inline Index rows() const { return m_matrix.rows(); }
inline Index cols() const { return m_matrix.cols(); }
template<typename Dest> inline void evalTo(Dest& dst) const
{
const Index size = m_transpositions.size();
StorageIndex j = 0;
if(!(is_same<MatrixTypeNestedCleaned,Dest>::value && extract_data(dst) == extract_data(m_matrix)))
dst = m_matrix;
for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k)
if(Index(j=m_transpositions.coeff(k))!=k)
{
if(Side==OnTheLeft)
dst.row(k).swap(dst.row(j));
else if(Side==OnTheRight)
dst.col(k).swap(dst.col(j));
}
}
protected:
const TranspositionType& m_transpositions;
typename MatrixType::Nested m_matrix;
};
} // end namespace internal
/* Template partial specialization for transposed/inverse transpositions */
@ -412,26 +363,56 @@ class Transpose<TranspositionsBase<TranspositionsDerived> >
/** \returns the \a matrix with the inverse transpositions applied to the columns.
*/
template<typename Derived> friend
inline const internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheRight, true>
operator*(const MatrixBase<Derived>& matrix, const Transpose& trt)
template<typename OtherDerived> friend
const Product<OtherDerived, Transpose, DefaultProduct>
operator*(const MatrixBase<OtherDerived>& matrix, const Transpose& trt)
{
return internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheRight, true>(trt.m_transpositions, matrix.derived());
return Product<OtherDerived, Transpose, DefaultProduct>(matrix.derived(), trt.derived());
}
/** \returns the \a matrix with the inverse transpositions applied to the rows.
*/
template<typename Derived>
inline const internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheLeft, true>
operator*(const MatrixBase<Derived>& matrix) const
template<typename OtherDerived>
const Product<Transpose, OtherDerived, DefaultProduct>
operator*(const MatrixBase<OtherDerived>& matrix) const
{
return internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheLeft, true>(m_transpositions, matrix.derived());
return Product<Transpose, OtherDerived, DefaultProduct>(*this, matrix.derived());
}
protected:
const TranspositionType& m_transpositions;
};
namespace internal {
// TODO currently a Transpositions expression has the form Transpositions or TranspositionsWrapper
// or their transpose; in the future shape should be defined by the expression traits
template<int SizeAtCompileTime, int MaxSizeAtCompileTime, typename IndexType>
struct evaluator_traits<Transpositions<SizeAtCompileTime, MaxSizeAtCompileTime, IndexType> >
{
typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
typedef TranspositionsShape Shape;
static const int AssumeAliasing = 0;
};
template<typename IndicesType>
struct evaluator_traits<TranspositionsWrapper<IndicesType> >
{
typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
typedef TranspositionsShape Shape;
static const int AssumeAliasing = 0;
};
template<typename Derived>
struct evaluator_traits<Transpose<TranspositionsBase<Derived> > >
{
typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
typedef TranspositionsShape Shape;
static const int AssumeAliasing = 0;
};
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_TRANSPOSITIONS_H

View File

@ -482,6 +482,7 @@ struct BandShape { static std::string debugName() { return "BandSha
struct TriangularShape { static std::string debugName() { return "TriangularShape"; } };
struct SelfAdjointShape { static std::string debugName() { return "SelfAdjointShape"; } };
struct PermutationShape { static std::string debugName() { return "PermutationShape"; } };
struct TranspositionsShape { static std::string debugName() { return "TranspositionsShape"; } };
struct SparseShape { static std::string debugName() { return "SparseShape"; } };
namespace internal {