Fix permutation/transposiitons products wrt nested_eval

This commit is contained in:
Gael Guennebaud 2015-06-19 16:37:04 +02:00
parent 0c8b0e007b
commit 5c84dd5665

View File

@ -842,17 +842,19 @@ struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DenseShape,
* Internal helper class implementing the product between a permutation matrix and a matrix.
* This class is specialized for DenseShape below and for SparseShape in SparseCore/SparsePermutation.h
*/
template<typename MatrixType, int Side, bool Transposed, typename MatrixShape>
template<typename ExpressionType, int Side, bool Transposed, typename ExpressionShape>
struct permutation_matrix_product;
template<typename MatrixType, int Side, bool Transposed>
struct permutation_matrix_product<MatrixType, Side, Transposed, DenseShape>
template<typename ExpressionType, int Side, bool Transposed>
struct permutation_matrix_product<ExpressionType, Side, Transposed, DenseShape>
{
typedef typename nested_eval<ExpressionType, 1>::type MatrixType;
typedef typename remove_all<MatrixType>::type MatrixTypeCleaned;
template<typename Dest, typename PermutationType>
static inline void run(Dest& dst, const PermutationType& perm, const MatrixType& mat)
static inline void run(Dest& dst, const PermutationType& perm, const ExpressionType& xpr)
{
MatrixType mat(xpr);
const Index n = Side==OnTheLeft ? mat.rows() : mat.cols();
// FIXME we need an is_same for expression that is not sensitive to constness. For instance
// is_same_xpr<Block<const Matrix>, Block<Matrix> >::value should be true.
@ -893,7 +895,7 @@ struct permutation_matrix_product<MatrixType, Side, Transposed, DenseShape>
=
Block<const MatrixTypeCleaned,Side==OnTheLeft ? 1 : MatrixType::RowsAtCompileTime,Side==OnTheRight ? 1 : MatrixType::ColsAtCompileTime>
Block<const MatrixTypeCleaned,Side==OnTheLeft ? 1 : MatrixTypeCleaned::RowsAtCompileTime,Side==OnTheRight ? 1 : MatrixTypeCleaned::ColsAtCompileTime>
(mat, ((Side==OnTheRight) ^ Transposed) ? perm.indices().coeff(i) : i);
}
}
@ -951,26 +953,30 @@ struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, PermutationShape,
* \class transposition_matrix_product
* Internal helper class implementing the product between a permutation matrix and a matrix.
*/
template<typename MatrixType, int Side, bool Transposed, typename MatrixShape>
template<typename ExpressionType, int Side, bool Transposed, typename ExpressionShape>
struct transposition_matrix_product
{
template<typename Dest, typename TranspositionType>
static inline void run(Dest& dst, const TranspositionType& tr, const MatrixType& mat)
{
typedef typename TranspositionType::StorageIndex StorageIndex;
const Index size = tr.size();
StorageIndex j = 0;
typedef typename nested_eval<ExpressionType, 1>::type MatrixType;
typedef typename remove_all<MatrixType>::type MatrixTypeCleaned;
template<typename Dest, typename TranspositionType>
static inline void run(Dest& dst, const TranspositionType& tr, const ExpressionType& xpr)
{
MatrixType mat(xpr);
typedef typename TranspositionType::StorageIndex StorageIndex;
const Index size = tr.size();
StorageIndex j = 0;
if(!(is_same<MatrixType,Dest>::value && extract_data(dst) == extract_data(mat)))
dst = mat;
if(!(is_same<MatrixTypeCleaned,Dest>::value && extract_data(dst) == extract_data(mat)))
dst = mat;
for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k)
if(Index(j=tr.coeff(k))!=k)
{
if(Side==OnTheLeft) dst.row(k).swap(dst.row(j));
else if(Side==OnTheRight) dst.col(k).swap(dst.col(j));
}
}
for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k)
if(Index(j=tr.coeff(k))!=k)
{
if(Side==OnTheLeft) dst.row(k).swap(dst.row(j));
else if(Side==OnTheRight) dst.col(k).swap(dst.col(j));
}
}
};
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>