Fix bug #838: fix dense * sparse and sparse * dense outer products

This commit is contained in:
Gael Guennebaud 2014-07-11 16:31:41 +02:00
parent 51e2e93019
commit 0cc67589d3

View File

@ -19,7 +19,10 @@ template<typename Lhs, typename Rhs, int InnerSize> struct SparseDenseProductRet
template<typename Lhs, typename Rhs> struct SparseDenseProductReturnType<Lhs,Rhs,1> template<typename Lhs, typename Rhs> struct SparseDenseProductReturnType<Lhs,Rhs,1>
{ {
typedef SparseDenseOuterProduct<Lhs,Rhs,false> Type; typedef typename internal::conditional<
Lhs::IsRowMajor,
SparseDenseOuterProduct<Rhs,Lhs,true>,
SparseDenseOuterProduct<Lhs,Rhs,false> >::type Type;
}; };
template<typename Lhs, typename Rhs, int InnerSize> struct DenseSparseProductReturnType template<typename Lhs, typename Rhs, int InnerSize> struct DenseSparseProductReturnType
@ -29,7 +32,10 @@ template<typename Lhs, typename Rhs, int InnerSize> struct DenseSparseProductRet
template<typename Lhs, typename Rhs> struct DenseSparseProductReturnType<Lhs,Rhs,1> template<typename Lhs, typename Rhs> struct DenseSparseProductReturnType<Lhs,Rhs,1>
{ {
typedef SparseDenseOuterProduct<Rhs,Lhs,true> Type; typedef typename internal::conditional<
Rhs::IsRowMajor,
SparseDenseOuterProduct<Rhs,Lhs,true>,
SparseDenseOuterProduct<Lhs,Rhs,false> >::type Type;
}; };
namespace internal { namespace internal {
@ -114,17 +120,30 @@ class SparseDenseOuterProduct<Lhs,Rhs,Transpose>::InnerIterator : public _LhsNes
typedef typename SparseDenseOuterProduct::Index Index; typedef typename SparseDenseOuterProduct::Index Index;
public: public:
EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer) EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer)
: Base(prod.lhs(), 0), m_outer(outer), m_factor(prod.rhs().coeff(outer)) : Base(prod.lhs(), 0), m_outer(outer), m_factor(get(prod.rhs(), outer, typename internal::traits<Rhs>::StorageKind() ))
{ { }
}
inline Index outer() const { return m_outer; } inline Index outer() const { return m_outer; }
inline Index row() const { return Transpose ? Base::row() : m_outer; } inline Index row() const { return Transpose ? m_outer : Base::index(); }
inline Index col() const { return Transpose ? m_outer : Base::row(); } inline Index col() const { return Transpose ? Base::index() : m_outer; }
inline Scalar value() const { return Base::value() * m_factor; } inline Scalar value() const { return Base::value() * m_factor; }
protected: protected:
static Scalar get(const _RhsNested &rhs, Index outer, Dense = Dense())
{
return rhs.coeff(outer);
}
static Scalar get(const _RhsNested &rhs, Index outer, Sparse = Sparse())
{
typename Traits::_RhsNested::InnerIterator it(rhs, outer);
if (it && it.index()==0)
return it.value();
return Scalar(0);
}
Index m_outer; Index m_outer;
Scalar m_factor; Scalar m_factor;
}; };