mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-07 18:27:40 +08:00
Make Transpositions use evaluators
This commit is contained in:
parent
82b6ac0864
commit
3af4c6c1c9
@ -537,9 +537,6 @@ class PermutationWrapper : public PermutationBase<PermutationWrapper<_IndicesTyp
|
||||
};
|
||||
|
||||
|
||||
// TODO: Do we need to define these operator* functions? Would it be better to have them inherited
|
||||
// from MatrixBase?
|
||||
|
||||
/** \returns the matrix with the permutation applied to the columns.
|
||||
*/
|
||||
template<typename MatrixDerived, typename PermutationDerived>
|
||||
|
@ -929,6 +929,79 @@ struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, PermutationShape,
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/***************************************************************************
|
||||
* Products with transpositions matrices
|
||||
***************************************************************************/
|
||||
|
||||
// FIXME could we unify Transpositions and Permutation into a single "shape"??
|
||||
|
||||
/** \internal
|
||||
* \class transposition_matrix_product
|
||||
* Internal helper class implementing the product between a permutation matrix and a matrix.
|
||||
*/
|
||||
template<typename MatrixType, int Side, bool Transposed, typename MatrixShape>
|
||||
struct transposition_matrix_product
|
||||
{
|
||||
template<typename Dest, typename TranspositionType>
|
||||
static inline void evalTo(Dest& dst, const TranspositionType& tr, const MatrixType& mat)
|
||||
{
|
||||
typedef typename TranspositionType::StorageIndex StorageIndex;
|
||||
const Index size = tr.size();
|
||||
StorageIndex j = 0;
|
||||
|
||||
if(!(is_same<MatrixType,Dest>::value && extract_data(dst) == extract_data(mat)))
|
||||
dst = mat;
|
||||
|
||||
for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k)
|
||||
if(Index(j=tr.coeff(k))!=k)
|
||||
{
|
||||
if(Side==OnTheLeft) dst.row(k).swap(dst.row(j));
|
||||
else if(Side==OnTheRight) dst.col(k).swap(dst.col(j));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
|
||||
struct generic_product_impl<Lhs, Rhs, TranspositionsShape, MatrixShape, ProductTag>
|
||||
{
|
||||
template<typename Dest>
|
||||
static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
|
||||
{
|
||||
transposition_matrix_product<Rhs, OnTheLeft, false, MatrixShape>::run(dst, lhs, rhs);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
|
||||
struct generic_product_impl<Lhs, Rhs, MatrixShape, TranspositionsShape, ProductTag>
|
||||
{
|
||||
template<typename Dest>
|
||||
static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
|
||||
{
|
||||
transposition_matrix_product<Lhs, OnTheRight, false, MatrixShape>::run(dst, rhs, lhs);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
|
||||
struct generic_product_impl<Transpose<Lhs>, Rhs, TranspositionsShape, MatrixShape, ProductTag>
|
||||
{
|
||||
template<typename Dest>
|
||||
static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs)
|
||||
{
|
||||
transposition_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedPermutation(), rhs);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
|
||||
struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, TranspositionsShape, ProductTag>
|
||||
{
|
||||
template<typename Dest>
|
||||
static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs)
|
||||
{
|
||||
transposition_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedPermutation(), lhs);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
@ -41,10 +41,6 @@ namespace Eigen {
|
||||
* \sa class PermutationMatrix
|
||||
*/
|
||||
|
||||
namespace internal {
|
||||
template<typename TranspositionType, typename MatrixType, int Side, bool Transposed=false> struct transposition_matrix_product_retval;
|
||||
}
|
||||
|
||||
template<typename Derived>
|
||||
class TranspositionsBase
|
||||
{
|
||||
@ -325,77 +321,32 @@ class TranspositionsWrapper
|
||||
const typename IndicesType::Nested m_indices;
|
||||
};
|
||||
|
||||
|
||||
|
||||
/** \returns the \a matrix with the \a transpositions applied to the columns.
|
||||
*/
|
||||
template<typename Derived, typename TranspositionsDerived>
|
||||
inline const internal::transposition_matrix_product_retval<TranspositionsDerived, Derived, OnTheRight>
|
||||
operator*(const MatrixBase<Derived>& matrix,
|
||||
const TranspositionsBase<TranspositionsDerived> &transpositions)
|
||||
template<typename MatrixDerived, typename TranspositionsDerived>
|
||||
EIGEN_DEVICE_FUNC
|
||||
const Product<MatrixDerived, TranspositionsDerived, DefaultProduct>
|
||||
operator*(const MatrixBase<MatrixDerived> &matrix,
|
||||
const TranspositionsBase<TranspositionsDerived>& transpositions)
|
||||
{
|
||||
return internal::transposition_matrix_product_retval
|
||||
<TranspositionsDerived, Derived, OnTheRight>
|
||||
(transpositions.derived(), matrix.derived());
|
||||
return Product<MatrixDerived, TranspositionsDerived, DefaultProduct>
|
||||
(matrix.derived(), transpositions.derived());
|
||||
}
|
||||
|
||||
/** \returns the \a matrix with the \a transpositions applied to the rows.
|
||||
*/
|
||||
template<typename Derived, typename TranspositionDerived>
|
||||
inline const internal::transposition_matrix_product_retval
|
||||
<TranspositionDerived, Derived, OnTheLeft>
|
||||
operator*(const TranspositionsBase<TranspositionDerived> &transpositions,
|
||||
const MatrixBase<Derived>& matrix)
|
||||
template<typename TranspositionsDerived, typename MatrixDerived>
|
||||
EIGEN_DEVICE_FUNC
|
||||
const Product<TranspositionsDerived, MatrixDerived, DefaultProduct>
|
||||
operator*(const TranspositionsBase<TranspositionsDerived> &transpositions,
|
||||
const MatrixBase<MatrixDerived>& matrix)
|
||||
{
|
||||
return internal::transposition_matrix_product_retval
|
||||
<TranspositionDerived, Derived, OnTheLeft>
|
||||
(transpositions.derived(), matrix.derived());
|
||||
return Product<TranspositionsDerived, MatrixDerived, DefaultProduct>
|
||||
(transpositions.derived(), matrix.derived());
|
||||
}
|
||||
|
||||
namespace internal {
|
||||
|
||||
template<typename TranspositionType, typename MatrixType, int Side, bool Transposed>
|
||||
struct traits<transposition_matrix_product_retval<TranspositionType, MatrixType, Side, Transposed> >
|
||||
{
|
||||
typedef typename MatrixType::PlainObject ReturnType;
|
||||
};
|
||||
|
||||
template<typename TranspositionType, typename MatrixType, int Side, bool Transposed>
|
||||
struct transposition_matrix_product_retval
|
||||
: public ReturnByValue<transposition_matrix_product_retval<TranspositionType, MatrixType, Side, Transposed> >
|
||||
{
|
||||
typedef typename remove_all<typename MatrixType::Nested>::type MatrixTypeNestedCleaned;
|
||||
typedef typename TranspositionType::StorageIndex StorageIndex;
|
||||
|
||||
transposition_matrix_product_retval(const TranspositionType& tr, const MatrixType& matrix)
|
||||
: m_transpositions(tr), m_matrix(matrix)
|
||||
{}
|
||||
|
||||
inline Index rows() const { return m_matrix.rows(); }
|
||||
inline Index cols() const { return m_matrix.cols(); }
|
||||
|
||||
template<typename Dest> inline void evalTo(Dest& dst) const
|
||||
{
|
||||
const Index size = m_transpositions.size();
|
||||
StorageIndex j = 0;
|
||||
|
||||
if(!(is_same<MatrixTypeNestedCleaned,Dest>::value && extract_data(dst) == extract_data(m_matrix)))
|
||||
dst = m_matrix;
|
||||
|
||||
for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k)
|
||||
if(Index(j=m_transpositions.coeff(k))!=k)
|
||||
{
|
||||
if(Side==OnTheLeft)
|
||||
dst.row(k).swap(dst.row(j));
|
||||
else if(Side==OnTheRight)
|
||||
dst.col(k).swap(dst.col(j));
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
const TranspositionType& m_transpositions;
|
||||
typename MatrixType::Nested m_matrix;
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
/* Template partial specialization for transposed/inverse transpositions */
|
||||
|
||||
@ -412,26 +363,56 @@ class Transpose<TranspositionsBase<TranspositionsDerived> >
|
||||
|
||||
/** \returns the \a matrix with the inverse transpositions applied to the columns.
|
||||
*/
|
||||
template<typename Derived> friend
|
||||
inline const internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheRight, true>
|
||||
operator*(const MatrixBase<Derived>& matrix, const Transpose& trt)
|
||||
template<typename OtherDerived> friend
|
||||
const Product<OtherDerived, Transpose, DefaultProduct>
|
||||
operator*(const MatrixBase<OtherDerived>& matrix, const Transpose& trt)
|
||||
{
|
||||
return internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheRight, true>(trt.m_transpositions, matrix.derived());
|
||||
return Product<OtherDerived, Transpose, DefaultProduct>(matrix.derived(), trt.derived());
|
||||
}
|
||||
|
||||
/** \returns the \a matrix with the inverse transpositions applied to the rows.
|
||||
*/
|
||||
template<typename Derived>
|
||||
inline const internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheLeft, true>
|
||||
operator*(const MatrixBase<Derived>& matrix) const
|
||||
template<typename OtherDerived>
|
||||
const Product<Transpose, OtherDerived, DefaultProduct>
|
||||
operator*(const MatrixBase<OtherDerived>& matrix) const
|
||||
{
|
||||
return internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheLeft, true>(m_transpositions, matrix.derived());
|
||||
return Product<Transpose, OtherDerived, DefaultProduct>(*this, matrix.derived());
|
||||
}
|
||||
|
||||
protected:
|
||||
const TranspositionType& m_transpositions;
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
|
||||
// TODO currently a Transpositions expression has the form Transpositions or TranspositionsWrapper
|
||||
// or their transpose; in the future shape should be defined by the expression traits
|
||||
template<int SizeAtCompileTime, int MaxSizeAtCompileTime, typename IndexType>
|
||||
struct evaluator_traits<Transpositions<SizeAtCompileTime, MaxSizeAtCompileTime, IndexType> >
|
||||
{
|
||||
typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
|
||||
typedef TranspositionsShape Shape;
|
||||
static const int AssumeAliasing = 0;
|
||||
};
|
||||
|
||||
template<typename IndicesType>
|
||||
struct evaluator_traits<TranspositionsWrapper<IndicesType> >
|
||||
{
|
||||
typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
|
||||
typedef TranspositionsShape Shape;
|
||||
static const int AssumeAliasing = 0;
|
||||
};
|
||||
|
||||
template<typename Derived>
|
||||
struct evaluator_traits<Transpose<TranspositionsBase<Derived> > >
|
||||
{
|
||||
typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
|
||||
typedef TranspositionsShape Shape;
|
||||
static const int AssumeAliasing = 0;
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_TRANSPOSITIONS_H
|
||||
|
@ -482,6 +482,7 @@ struct BandShape { static std::string debugName() { return "BandSha
|
||||
struct TriangularShape { static std::string debugName() { return "TriangularShape"; } };
|
||||
struct SelfAdjointShape { static std::string debugName() { return "SelfAdjointShape"; } };
|
||||
struct PermutationShape { static std::string debugName() { return "PermutationShape"; } };
|
||||
struct TranspositionsShape { static std::string debugName() { return "TranspositionsShape"; } };
|
||||
struct SparseShape { static std::string debugName() { return "SparseShape"; } };
|
||||
|
||||
namespace internal {
|
||||
|
Loading…
Reference in New Issue
Block a user