Extend sparse*sparse product unit test to check that the expected implementation is used (conservative vs auto pruning).

This commit is contained in:
Gael Guennebaud 2016-05-18 16:50:54 +02:00
parent 548a487800
commit 1fa15ceee6

View File

@ -7,8 +7,26 @@
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
static long int nb_temporaries;
inline void on_temporary_creation() {
// here's a great place to set a breakpoint when debugging failures in this test!
nb_temporaries++;
}
#define EIGEN_SPARSE_CREATE_TEMPORARY_PLUGIN { on_temporary_creation(); }
#include "sparse.h"
#define VERIFY_EVALUATION_COUNT(XPR,N) {\
nb_temporaries = 0; \
CALL_SUBTEST( XPR ); \
if(nb_temporaries!=N) std::cerr << "nb_temporaries == " << nb_temporaries << "\n"; \
VERIFY( (#XPR) && nb_temporaries==N ); \
}
template<typename SparseMatrixType> void sparse_product()
{
typedef typename SparseMatrixType::StorageIndex StorageIndex;
@ -76,6 +94,24 @@ template<typename SparseMatrixType> void sparse_product()
VERIFY_IS_APPROX(m4=(m2t.transpose()*m3t.transpose()).pruned(0), refMat4=refMat2t.transpose()*refMat3t.transpose());
VERIFY_IS_APPROX(m4=(m2*m3t.transpose()).pruned(0), refMat4=refMat2*refMat3t.transpose());
// make sure the right product implementation is called:
if((!SparseMatrixType::IsRowMajor) && m2.rows()<=m3.cols())
{
VERIFY_EVALUATION_COUNT(m4 = m2*m3, 3); // 1 temp for the result + 2 for transposing and get a sorted result.
VERIFY_EVALUATION_COUNT(m4 = (m2*m3).pruned(0), 1);
VERIFY_EVALUATION_COUNT(m4 = (m2*m3).eval().pruned(0), 4);
}
// and that pruning is effective:
{
DenseMatrix Ad(2,2);
Ad << -1, 1, 1, 1;
SparseMatrixType As(Ad.sparseView()), B(2,2);
VERIFY_IS_EQUAL( (As*As.transpose()).eval().nonZeros(), 4);
VERIFY_IS_EQUAL( (Ad*Ad.transpose()).eval().sparseView().eval().nonZeros(), 2);
VERIFY_IS_EQUAL( (As*As.transpose()).pruned(1e-6).eval().nonZeros(), 2);
}
// dense ?= sparse * sparse
VERIFY_IS_APPROX(dm4 =m2*m3, refMat4 =refMat2*refMat3);
VERIFY_IS_APPROX(dm4+=m2*m3, refMat4+=refMat2*refMat3);