Fixed Sparse-Sparse Product in case of mixed StorageIndex types

This commit is contained in:
Erik Schultheis 2021-11-18 18:33:31 +00:00 committed by Rasmus Munk Larsen
parent 96aeffb013
commit b0fb5417d3
3 changed files with 88 additions and 25 deletions

View File

@ -126,6 +126,11 @@ static void conservative_sparse_sparse_product_impl(const Lhs& lhs, const Rhs& r
namespace internal {
// Helper template to generate new sparse matrix types
template<class Source, int Order>
using WithStorageOrder = SparseMatrix<typename Source::Scalar, Order, typename Source::StorageIndex>;
template<typename Lhs, typename Rhs, typename ResultType,
int LhsStorageOrder = (traits<Lhs>::Flags&RowMajorBit) ? RowMajor : ColMajor,
int RhsStorageOrder = (traits<Rhs>::Flags&RowMajorBit) ? RowMajor : ColMajor,
@ -140,15 +145,15 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,C
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorMatrix;
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrixAux;
typedef typename sparse_eval<ColMajorMatrixAux,ResultType::RowsAtCompileTime,ResultType::ColsAtCompileTime,ColMajorMatrixAux::Flags>::type ColMajorMatrix;
using RowMajorMatrix = WithStorageOrder<ResultType, RowMajor>;
using ColMajorMatrixAux = WithStorageOrder<ResultType, ColMajor>;
// If the result is tall and thin (in the extreme case a column vector)
// then it is faster to sort the coefficients inplace instead of transposing twice.
// FIXME, the following heuristic is probably not very good.
if(lhs.rows()>rhs.cols())
{
using ColMajorMatrix = typename sparse_eval<ColMajorMatrixAux,ResultType::RowsAtCompileTime,ResultType::ColsAtCompileTime,ColMajorMatrixAux::Flags>::type;
ColMajorMatrix resCol(lhs.rows(),rhs.cols());
// perform sorted insertion
internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol, true);
@ -170,8 +175,8 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,C
{
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
typedef SparseMatrix<typename Rhs::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorRhs;
typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorRes;
using RowMajorRhs = WithStorageOrder<Rhs, RowMajor>;
using RowMajorRes = WithStorageOrder<ResultType, RowMajor>;
RowMajorRhs rhsRow = rhs;
RowMajorRes resRow(lhs.rows(), rhs.cols());
internal::conservative_sparse_sparse_product_impl<RowMajorRhs,Lhs,RowMajorRes>(rhsRow, lhs, resRow);
@ -184,8 +189,8 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,R
{
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
typedef SparseMatrix<typename Lhs::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorLhs;
typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorRes;
using RowMajorLhs = WithStorageOrder<Lhs, RowMajor>;
using RowMajorRes = WithStorageOrder<ResultType, RowMajor>;
RowMajorLhs lhsRow = lhs;
RowMajorRes resRow(lhs.rows(), rhs.cols());
internal::conservative_sparse_sparse_product_impl<Rhs,RowMajorLhs,RowMajorRes>(rhs, lhsRow, resRow);
@ -198,9 +203,9 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,R
{
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorMatrix;
RowMajorMatrix resRow(lhs.rows(), rhs.cols());
internal::conservative_sparse_sparse_product_impl<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow);
using RowMajorRes = WithStorageOrder<ResultType, RowMajor>;
RowMajorRes resRow(lhs.rows(), rhs.cols());
internal::conservative_sparse_sparse_product_impl<Rhs,Lhs,RowMajorRes>(rhs, lhs, resRow);
res = resRow;
}
};
@ -213,9 +218,9 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,C
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix;
ColMajorMatrix resCol(lhs.rows(), rhs.cols());
internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol);
using ColMajorRes = WithStorageOrder<ResultType, ColMajor>;
ColMajorRes resCol(lhs.rows(), rhs.cols());
internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorRes>(lhs, rhs, resCol);
res = resCol;
}
};
@ -225,8 +230,8 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,C
{
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorLhs;
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorRes;
using ColMajorLhs = WithStorageOrder<Lhs, ColMajor>;
using ColMajorRes = WithStorageOrder<ResultType, ColMajor>;
ColMajorLhs lhsCol = lhs;
ColMajorRes resCol(lhs.rows(), rhs.cols());
internal::conservative_sparse_sparse_product_impl<ColMajorLhs,Rhs,ColMajorRes>(lhsCol, rhs, resCol);
@ -239,8 +244,8 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,R
{
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorRhs;
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorRes;
using ColMajorRhs = WithStorageOrder<Rhs, ColMajor>;
using ColMajorRes = WithStorageOrder<ResultType, ColMajor>;
ColMajorRhs rhsCol = rhs;
ColMajorRes resCol(lhs.rows(), rhs.cols());
internal::conservative_sparse_sparse_product_impl<Lhs,ColMajorRhs,ColMajorRes>(lhs, rhsCol, resCol);
@ -253,12 +258,12 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,R
{
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorMatrix;
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix;
RowMajorMatrix resRow(lhs.rows(),rhs.cols());
internal::conservative_sparse_sparse_product_impl<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow);
using ColMajorRes = WithStorageOrder<ResultType, ColMajor>;
using RowMajorRes = WithStorageOrder<ResultType, RowMajor>;
RowMajorRes resRow(lhs.rows(),rhs.cols());
internal::conservative_sparse_sparse_product_impl<Rhs,Lhs,RowMajorRes>(rhs, lhs, resRow);
// sort the non zeros:
ColMajorMatrix resCol(resRow);
ColMajorRes resCol(resRow);
res = resCol;
}
};
@ -319,7 +324,7 @@ struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMa
{
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorLhs;
using ColMajorLhs = WithStorageOrder<Lhs, ColMajor>;
ColMajorLhs lhsCol(lhs);
internal::sparse_sparse_to_dense_product_impl<ColMajorLhs,Rhs,ResultType>(lhsCol, rhs, res);
}
@ -330,7 +335,7 @@ struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMa
{
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorRhs;
using ColMajorRhs = WithStorageOrder<Rhs, ColMajor>;
ColMajorRhs rhsCol(rhs);
internal::sparse_sparse_to_dense_product_impl<Lhs,ColMajorRhs,ResultType>(lhs, rhsCol, res);
}

View File

@ -59,7 +59,8 @@ initSparse(double density,
sparseMat.setZero();
//sparseMat.reserve(int(refMat.rows()*refMat.cols()*density));
sparseMat.reserve(VectorXi::Constant(IsRowMajor ? refMat.rows() : refMat.cols(), int((1.5*density)*(IsRowMajor?refMat.cols():refMat.rows()))));
Index insert_count = 0;
for(Index j=0; j<sparseMat.outerSize(); j++)
{
//sparseMat.startVec(j);
@ -89,6 +90,7 @@ initSparse(double density,
{
//sparseMat.insertBackByOuterInner(j,i) = v;
sparseMat.insertByOuterInner(j,i) = v;
++insert_count;
if (nonzeroCoords)
nonzeroCoords->push_back(Matrix<StorageIndex,2,1> (ai,aj));
}
@ -97,6 +99,9 @@ initSparse(double density,
zeroCoords->push_back(Matrix<StorageIndex,2,1> (ai,aj));
}
refMat(ai,aj) = v;
// make sure we only insert as many as the sparse matrix supports
if(insert_count == NumTraits<StorageIndex>::highest()) return;
}
}
//sparseMat.finalize();

View File

@ -461,6 +461,58 @@ void test_mixing_types()
VERIFY_IS_APPROX( dC2 = sC1 * dR1.col(0), dC3 = sC1 * dR1.template cast<Cplx>().col(0) );
}
// Test mixed storage types
template<int OrderA, int OrderB, int OrderC>
void test_mixed_storage_imp() {
typedef float Real;
typedef Matrix<Real,Dynamic,Dynamic> DenseMat;
// Case: Large inputs but small result
{
SparseMatrix<Real, OrderA> A(8, 512);
SparseMatrix<Real, OrderB> B(512, 8);
DenseMat refA(8, 512);
DenseMat refB(512, 8);
initSparse<Real>(0.1, refA, A);
initSparse<Real>(0.1, refB, B);
SparseMatrix<Real, OrderC, std::int8_t> result;
SparseMatrix<Real, OrderC> result_large;
DenseMat refResult;
VERIFY_IS_APPROX( result = (A * B), refResult = refA * refB );
}
// Case: Small input but large result
{
SparseMatrix<Real, OrderA, std::int8_t> A(127, 8);
SparseMatrix<Real, OrderB, std::int8_t> B(8, 127);
DenseMat refA(127, 8);
DenseMat refB(8, 127);
initSparse<Real>(0.01, refA, A);
initSparse<Real>(0.01, refB, B);
SparseMatrix<Real, OrderC> result;
SparseMatrix<Real, OrderC> result_large;
DenseMat refResult;
VERIFY_IS_APPROX( result = (A * B), refResult = refA * refB );
}
}
void test_mixed_storage() {
test_mixed_storage_imp<RowMajor, RowMajor, RowMajor>();
test_mixed_storage_imp<RowMajor, RowMajor, ColMajor>();
test_mixed_storage_imp<RowMajor, ColMajor, RowMajor>();
test_mixed_storage_imp<RowMajor, ColMajor, ColMajor>();
test_mixed_storage_imp<ColMajor, RowMajor, RowMajor>();
test_mixed_storage_imp<ColMajor, RowMajor, ColMajor>();
test_mixed_storage_imp<ColMajor, ColMajor, RowMajor>();
test_mixed_storage_imp<ColMajor, ColMajor, ColMajor>();
}
EIGEN_DECLARE_TEST(sparse_product)
{
for(int i = 0; i < g_repeat; i++) {
@ -473,5 +525,6 @@ EIGEN_DECLARE_TEST(sparse_product)
CALL_SUBTEST_4( (sparse_product_regression_test<SparseMatrix<double,RowMajor>, Matrix<double, Dynamic, Dynamic, RowMajor> >()) );
CALL_SUBTEST_5( (test_mixing_types<float>()) );
CALL_SUBTEST_5( (test_mixed_storage()) );
}
}