mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-02-05 17:50:26 +08:00
Fix permutation/transposiitons products wrt nested_eval
This commit is contained in:
parent
0c8b0e007b
commit
5c84dd5665
@ -842,17 +842,19 @@ struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DenseShape,
|
||||
* Internal helper class implementing the product between a permutation matrix and a matrix.
|
||||
* This class is specialized for DenseShape below and for SparseShape in SparseCore/SparsePermutation.h
|
||||
*/
|
||||
template<typename MatrixType, int Side, bool Transposed, typename MatrixShape>
|
||||
template<typename ExpressionType, int Side, bool Transposed, typename ExpressionShape>
|
||||
struct permutation_matrix_product;
|
||||
|
||||
template<typename MatrixType, int Side, bool Transposed>
|
||||
struct permutation_matrix_product<MatrixType, Side, Transposed, DenseShape>
|
||||
template<typename ExpressionType, int Side, bool Transposed>
|
||||
struct permutation_matrix_product<ExpressionType, Side, Transposed, DenseShape>
|
||||
{
|
||||
typedef typename nested_eval<ExpressionType, 1>::type MatrixType;
|
||||
typedef typename remove_all<MatrixType>::type MatrixTypeCleaned;
|
||||
|
||||
template<typename Dest, typename PermutationType>
|
||||
static inline void run(Dest& dst, const PermutationType& perm, const MatrixType& mat)
|
||||
static inline void run(Dest& dst, const PermutationType& perm, const ExpressionType& xpr)
|
||||
{
|
||||
MatrixType mat(xpr);
|
||||
const Index n = Side==OnTheLeft ? mat.rows() : mat.cols();
|
||||
// FIXME we need an is_same for expression that is not sensitive to constness. For instance
|
||||
// is_same_xpr<Block<const Matrix>, Block<Matrix> >::value should be true.
|
||||
@ -893,7 +895,7 @@ struct permutation_matrix_product<MatrixType, Side, Transposed, DenseShape>
|
||||
|
||||
=
|
||||
|
||||
Block<const MatrixTypeCleaned,Side==OnTheLeft ? 1 : MatrixType::RowsAtCompileTime,Side==OnTheRight ? 1 : MatrixType::ColsAtCompileTime>
|
||||
Block<const MatrixTypeCleaned,Side==OnTheLeft ? 1 : MatrixTypeCleaned::RowsAtCompileTime,Side==OnTheRight ? 1 : MatrixTypeCleaned::ColsAtCompileTime>
|
||||
(mat, ((Side==OnTheRight) ^ Transposed) ? perm.indices().coeff(i) : i);
|
||||
}
|
||||
}
|
||||
@ -951,26 +953,30 @@ struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, PermutationShape,
|
||||
* \class transposition_matrix_product
|
||||
* Internal helper class implementing the product between a permutation matrix and a matrix.
|
||||
*/
|
||||
template<typename MatrixType, int Side, bool Transposed, typename MatrixShape>
|
||||
template<typename ExpressionType, int Side, bool Transposed, typename ExpressionShape>
|
||||
struct transposition_matrix_product
|
||||
{
|
||||
template<typename Dest, typename TranspositionType>
|
||||
static inline void run(Dest& dst, const TranspositionType& tr, const MatrixType& mat)
|
||||
{
|
||||
typedef typename TranspositionType::StorageIndex StorageIndex;
|
||||
const Index size = tr.size();
|
||||
StorageIndex j = 0;
|
||||
typedef typename nested_eval<ExpressionType, 1>::type MatrixType;
|
||||
typedef typename remove_all<MatrixType>::type MatrixTypeCleaned;
|
||||
|
||||
template<typename Dest, typename TranspositionType>
|
||||
static inline void run(Dest& dst, const TranspositionType& tr, const ExpressionType& xpr)
|
||||
{
|
||||
MatrixType mat(xpr);
|
||||
typedef typename TranspositionType::StorageIndex StorageIndex;
|
||||
const Index size = tr.size();
|
||||
StorageIndex j = 0;
|
||||
|
||||
if(!(is_same<MatrixType,Dest>::value && extract_data(dst) == extract_data(mat)))
|
||||
dst = mat;
|
||||
if(!(is_same<MatrixTypeCleaned,Dest>::value && extract_data(dst) == extract_data(mat)))
|
||||
dst = mat;
|
||||
|
||||
for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k)
|
||||
if(Index(j=tr.coeff(k))!=k)
|
||||
{
|
||||
if(Side==OnTheLeft) dst.row(k).swap(dst.row(j));
|
||||
else if(Side==OnTheRight) dst.col(k).swap(dst.col(j));
|
||||
}
|
||||
}
|
||||
for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k)
|
||||
if(Index(j=tr.coeff(k))!=k)
|
||||
{
|
||||
if(Side==OnTheLeft) dst.row(k).swap(dst.row(j));
|
||||
else if(Side==OnTheRight) dst.col(k).swap(dst.col(j));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
|
||||
|
Loading…
Reference in New Issue
Block a user