Fix performance issue with sparse matrix * dense vector products (SpMV)

This commit is contained in:
Gael Guennebaud 2011-11-11 06:04:31 +01:00
parent 9d82a7e204
commit c110abb7d2

View File

@ -166,42 +166,29 @@ class SparseTimeDenseProduct
typedef typename internal::remove_all<Lhs>::type _Lhs; typedef typename internal::remove_all<Lhs>::type _Lhs;
typedef typename internal::remove_all<Rhs>::type _Rhs; typedef typename internal::remove_all<Rhs>::type _Rhs;
typedef typename _Lhs::InnerIterator LhsInnerIterator; typedef typename _Lhs::InnerIterator LhsInnerIterator;
enum { LhsIsRowMajor = (_Lhs::Flags&RowMajorBit)==RowMajorBit }; enum {
LhsIsRowMajor = (_Lhs::Flags&RowMajorBit)==RowMajorBit,
RhsIsVector = Rhs::ColsAtCompileTime==1
};
Index j=0; Index j=0;
//#pragma omp parallel for private(j) schedule(static,4)
for(j=0; j<m_lhs.outerSize(); ++j) for(j=0; j<m_lhs.outerSize(); ++j)
{ {
//kernel(dest,alpha,j);
typename Rhs::Scalar rhs_j = alpha * m_rhs.coeff(LhsIsRowMajor ? 0 : j,0); typename Rhs::Scalar rhs_j = alpha * m_rhs.coeff(LhsIsRowMajor ? 0 : j,0);
typename Dest::RowXpr dest_j(dest.row(LhsIsRowMajor ? j : 0)); typename Dest::RowXpr dest_j(dest.row(LhsIsRowMajor ? j : 0));
typename Dest::Scalar tmp(0);
for(LhsInnerIterator it(m_lhs,j); it ;++it) for(LhsInnerIterator it(m_lhs,j); it ;++it)
{ {
if(LhsIsRowMajor) dest_j += (alpha*it.value()) * m_rhs.row(it.index()); if(LhsIsRowMajor && RhsIsVector) tmp += (it.value()) * m_rhs.coeff(it.index());
else if(Rhs::ColsAtCompileTime==1) dest.coeffRef(it.index()) += it.value() * rhs_j; else if(LhsIsRowMajor) dest_j += (alpha*it.value()) * m_rhs.row(it.index());
else dest.row(it.index()) += (alpha*it.value()) * m_rhs.row(j); else if(RhsIsVector) dest.coeffRef(it.index()) += it.value() * rhs_j;
else dest.row(it.index()) += (alpha*it.value()) * m_rhs.row(j);
} }
if(LhsIsRowMajor && RhsIsVector)
dest.coeffRef(LhsIsRowMajor ? j : 0) = alpha * tmp;
} }
} }
private: private:
// Adds to `dest` the contribution of outer vector `j` of the sparse left-hand
// side (m_lhs) multiplied by the dense right-hand side (m_rhs), scaled by `alpha`.
//   dest  - dense destination; only ever updated through `+=`
//   alpha - scalar factor applied to each accumulated product
//   j     - outer index of m_lhs to traverse; when the lhs is row-major it also
//           selects the destination row, otherwise it selects the rhs row/coefficient
template<typename Dest>
EIGEN_DONT_INLINE void kernel(Dest& dest, Scalar alpha, int j) const
{
  typedef typename internal::remove_all<Lhs>::type LhsCleaned;
  typedef typename internal::remove_all<Rhs>::type RhsCleaned;
  typedef typename LhsCleaned::InnerIterator SparseIterator;
  enum { RowMajorLhs = (LhsCleaned::Flags&RowMajorBit)==RowMajorBit };

  // Both helpers are built unconditionally; index 0 is used as a stand-in on
  // the storage order whose branch does not consume them.
  typename Rhs::Scalar scaledRhs = alpha * m_rhs.coeff(RowMajorLhs ? 0 : j,0);
  typename Dest::RowXpr destRow(dest.row(RowMajorLhs ? j : 0));

  for(SparseIterator entry(m_lhs,j); entry; ++entry)
  {
    if(RowMajorLhs)
    {
      // Row-major lhs: every nonzero of row j feeds destination row j.
      destRow += (alpha*entry.value()) * m_rhs.row(entry.index());
    }
    else if(Rhs::ColsAtCompileTime==1)
    {
      // Column-major lhs, single-column rhs: scatter using the pre-scaled coefficient.
      dest.coeffRef(entry.index()) += entry.value() * scaledRhs;
    }
    else
    {
      // Column-major lhs, general rhs: scatter a scaled copy of rhs row j.
      dest.row(entry.index()) += (alpha*entry.value()) * m_rhs.row(j);
    }
  }
}
SparseTimeDenseProduct& operator=(const SparseTimeDenseProduct&); SparseTimeDenseProduct& operator=(const SparseTimeDenseProduct&);
}; };