Update custom setFromTripplets API to allow passing a functor object, and add a collapseDuplicates method to cleanup the API. Also add respective unit test

This commit is contained in:
Gael Guennebaud 2015-10-13 11:30:41 +02:00
parent b9d81c9150
commit b4c79ee1d3
2 changed files with 38 additions and 19 deletions

View File

@ -437,11 +437,13 @@ class SparseMatrix
template<typename InputIterators>
void setFromTriplets(const InputIterators& begin, const InputIterators& end);
template<typename DupFunctor, typename InputIterators>
void setFromTriplets(const InputIterators& begin, const InputIterators& end);
template<typename InputIterators,typename DupFunctor>
void setFromTriplets(const InputIterators& begin, const InputIterators& end, DupFunctor dup_func);
void sumupDuplicates() { collapseDuplicates(internal::scalar_sum_op<Scalar>()); }
template<typename DupFunctor>
void sumupDuplicates();
void collapseDuplicates(DupFunctor dup_func = DupFunctor());
//---
@ -894,9 +896,8 @@ private:
namespace internal {
template<typename InputIterator, typename SparseMatrixType, typename DupFunctor>
void set_from_triplets(const InputIterator& begin, const InputIterator& end, SparseMatrixType& mat, int Options = 0)
void set_from_triplets(const InputIterator& begin, const InputIterator& end, SparseMatrixType& mat, DupFunctor dup_func)
{
EIGEN_UNUSED_VARIABLE(Options);
enum { IsRowMajor = SparseMatrixType::IsRowMajor };
typedef typename SparseMatrixType::Scalar Scalar;
typedef typename SparseMatrixType::StorageIndex StorageIndex;
@ -919,7 +920,7 @@ void set_from_triplets(const InputIterator& begin, const InputIterator& end, Spa
trMat.insertBackUncompressed(it->row(),it->col()) = it->value();
// pass 3:
trMat.template sumupDuplicates<DupFunctor>();
trMat.collapseDuplicates(dup_func);
}
// pass 4: transposed copy -> implicit sorting
@ -970,25 +971,29 @@ template<typename Scalar, int _Options, typename _Index>
template<typename InputIterators>
void SparseMatrix<Scalar,_Options,_Index>::setFromTriplets(const InputIterators& begin, const InputIterators& end)
{
internal::set_from_triplets<InputIterators, SparseMatrix<Scalar,_Options,_Index>, internal::scalar_sum_op<Scalar> >(begin, end, *this);
internal::set_from_triplets<InputIterators, SparseMatrix<Scalar,_Options,_Index> >(begin, end, *this, internal::scalar_sum_op<Scalar>());
}
/** The same as setFromTriplets but when duplicates are met the functor \a DupFunctor is applied:
/** The same as setFromTriplets but when duplicates are met the functor \a dup_func is applied:
* \code
* value = DupFunctor()(OldValue, NewValue)
* value = dup_func(OldValue, NewValue)
* \endcode
*/
* Here is a C++11 example keeping the latest entry only:
* \code
* mat.setFromTriplets(triplets.begin(), triplets.end(), [] (const Scalar&,const Scalar &b) { return b; });
* \endcode
*/
template<typename Scalar, int _Options, typename _Index>
template<typename DupFunctor, typename InputIterators>
void SparseMatrix<Scalar,_Options,_Index>::setFromTriplets(const InputIterators& begin, const InputIterators& end)
template<typename InputIterators,typename DupFunctor>
void SparseMatrix<Scalar,_Options,_Index>::setFromTriplets(const InputIterators& begin, const InputIterators& end, DupFunctor dup_func)
{
internal::set_from_triplets<InputIterators, SparseMatrix<Scalar,_Options,_Index>, DupFunctor>(begin, end, *this);
internal::set_from_triplets<InputIterators, SparseMatrix<Scalar,_Options,_Index>, DupFunctor>(begin, end, *this, dup_func);
}
/** \internal */
template<typename Scalar, int _Options, typename _Index>
template<typename DupFunctor>
void SparseMatrix<Scalar,_Options,_Index>::sumupDuplicates()
void SparseMatrix<Scalar,_Options,_Index>::collapseDuplicates(DupFunctor dup_func)
{
eigen_assert(!isCompressed());
// TODO, in practice we should be able to use m_innerNonZeros for that task
@ -1006,7 +1011,7 @@ void SparseMatrix<Scalar,_Options,_Index>::sumupDuplicates()
if(wi(i)>=start)
{
// we already meet this entry => accumulate it
m_data.value(wi(i)) = DupFunctor()(m_data.value(wi(i)), m_data.value(k));
m_data.value(wi(i)) = dup_func(m_data.value(wi(i)), m_data.value(k));
}
else
{

View File

@ -258,19 +258,33 @@ template<typename SparseMatrixType> void sparse_basic(const SparseMatrixType& re
std::vector<TripletType> triplets;
Index ntriplets = rows*cols;
triplets.reserve(ntriplets);
DenseMatrix refMat(rows,cols);
refMat.setZero();
DenseMatrix refMat_sum = DenseMatrix::Zero(rows,cols);
DenseMatrix refMat_prod = DenseMatrix::Zero(rows,cols);
DenseMatrix refMat_last = DenseMatrix::Zero(rows,cols);
for(Index i=0;i<ntriplets;++i)
{
StorageIndex r = internal::random<StorageIndex>(0,StorageIndex(rows-1));
StorageIndex c = internal::random<StorageIndex>(0,StorageIndex(cols-1));
Scalar v = internal::random<Scalar>();
triplets.push_back(TripletType(r,c,v));
refMat(r,c) += v;
refMat_sum(r,c) += v;
if(std::abs(refMat_prod(r,c))==0)
refMat_prod(r,c) = v;
else
refMat_prod(r,c) *= v;
refMat_last(r,c) = v;
}
SparseMatrixType m(rows,cols);
m.setFromTriplets(triplets.begin(), triplets.end());
VERIFY_IS_APPROX(m, refMat);
VERIFY_IS_APPROX(m, refMat_sum);
m.setFromTriplets(triplets.begin(), triplets.end(), std::multiplies<Scalar>());
VERIFY_IS_APPROX(m, refMat_prod);
#if (defined(__cplusplus) && __cplusplus >= 201103L)
m.setFromTriplets(triplets.begin(), triplets.end(), [] (Scalar,Scalar b) { return b; });
VERIFY_IS_APPROX(m, refMat_last);
#endif
}
// test Map