mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-13 18:37:27 +08:00
[Sparse] Added regression tests for the two bugfixes, the code passes all sparse_product tests
This commit is contained in:
parent
13867c15cc
commit
11e253bc10
@ -169,14 +169,31 @@ class SparseTimeDenseProduct
|
||||
enum { LhsIsRowMajor = (_Lhs::Flags&RowMajorBit)==RowMajorBit };
|
||||
for(Index j=0; j<m_lhs.outerSize(); ++j)
|
||||
{
|
||||
typename Rhs::Scalar rhs_j = alpha * m_rhs.coeff(j,0);
|
||||
Block<Dest,1,Dest::ColsAtCompileTime> dest_j(dest.row(LhsIsRowMajor ? j : 0));
|
||||
for(LhsInnerIterator it(m_lhs,j); it ;++it)
|
||||
{
|
||||
if(LhsIsRowMajor) dest_j += (alpha*it.value()) * m_rhs.row(it.index());
|
||||
else if(Rhs::ColsAtCompileTime==1) dest.coeffRef(it.index()) += it.value() * rhs_j;
|
||||
else dest.row(it.index()) += (alpha*it.value()) * m_rhs.row(j);
|
||||
}
|
||||
if(LhsIsRowMajor)
|
||||
{
|
||||
// Block<Dest,1,Dest::ColsAtCompileTime> dest_j(dest.row(LhsIsRowMajor ? j : 0)); // this does not work in all cases. Why?
|
||||
Block<Dest,1,Dest::ColsAtCompileTime> dest_j(dest, LhsIsRowMajor ? j : 0);
|
||||
for(LhsInnerIterator it(m_lhs,j); it ;++it)
|
||||
{
|
||||
dest_j += (alpha*it.value()) * m_rhs.row(it.index());
|
||||
}
|
||||
}
|
||||
else if(Rhs::ColsAtCompileTime==1)
|
||||
{
|
||||
typename Rhs::Scalar rhs_j = alpha * m_rhs.coeff(j,0);
|
||||
for(LhsInnerIterator it(m_lhs,j); it ;++it)
|
||||
{
|
||||
dest.coeffRef(it.index()) += it.value() * rhs_j;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(LhsInnerIterator it(m_lhs,j); it ;++it)
|
||||
{
|
||||
dest.row(it.index()) += (alpha*it.value()) * m_rhs.row(j);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@ -187,8 +204,8 @@ class SparseTimeDenseProduct
|
||||
|
||||
// dense = dense * sparse
|
||||
namespace internal {
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct traits<DenseTimeSparseProduct<Lhs,Rhs> >
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct traits<DenseTimeSparseProduct<Lhs,Rhs> >
|
||||
: traits<ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs> >
|
||||
{
|
||||
typedef Dense StorageKind;
|
||||
|
@ -137,6 +137,31 @@ template<typename SparseMatrixType> void sparse_product(const SparseMatrixType&
|
||||
}
|
||||
}
|
||||
|
||||
// New test for Bug in SparseTimeDenseProduct
|
||||
template<typename SparseMatrixType, typename DenseMatrixType> void sparse_product_regression_test()
|
||||
{
|
||||
// This code does not compile with afflicted versions of the bug
|
||||
/* SparseMatrixType sm1(3,2);
|
||||
DenseMatrixType m2(2,2);
|
||||
sm1.setZero();
|
||||
m2.setZero();
|
||||
|
||||
DenseMatrixType m3 = sm1*m2;
|
||||
*/
|
||||
|
||||
|
||||
// This code produces a segfault with afflicted versions of another SparseTimeDenseProduct
|
||||
// bug
|
||||
|
||||
SparseMatrixType sm2(20000,2);
|
||||
DenseMatrixType m3(2,2);
|
||||
sm2.setZero();
|
||||
m3.setZero();
|
||||
DenseMatrixType m4(sm2*m3);
|
||||
|
||||
VERIFY_IS_APPROX( m4(0,0), 0.0 );
|
||||
}
|
||||
|
||||
void test_sparse_product()
|
||||
{
|
||||
for(int i = 0; i < g_repeat; i++) {
|
||||
@ -145,5 +170,7 @@ void test_sparse_product()
|
||||
CALL_SUBTEST_1( sparse_product(SparseMatrix<double>(33, 33)) );
|
||||
|
||||
CALL_SUBTEST_3( sparse_product(DynamicSparseMatrix<double>(8, 8)) );
|
||||
|
||||
CALL_SUBTEST_4( (sparse_product_regression_test<SparseMatrix<double,RowMajor>, Matrix<double, Dynamic, Dynamic, RowMajor> >()) );
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user