mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Fix bug #838: fix dense * sparse and sparse * dense outer products
This commit is contained in:
parent
df604e4f49
commit
c0f76ce2cf
@ -19,7 +19,10 @@ template<typename Lhs, typename Rhs, int InnerSize> struct SparseDenseProductRet
|
|||||||
|
|
||||||
template<typename Lhs, typename Rhs> struct SparseDenseProductReturnType<Lhs,Rhs,1>
|
template<typename Lhs, typename Rhs> struct SparseDenseProductReturnType<Lhs,Rhs,1>
|
||||||
{
|
{
|
||||||
typedef SparseDenseOuterProduct<Lhs,Rhs,false> Type;
|
typedef typename internal::conditional<
|
||||||
|
Lhs::IsRowMajor,
|
||||||
|
SparseDenseOuterProduct<Rhs,Lhs,true>,
|
||||||
|
SparseDenseOuterProduct<Lhs,Rhs,false> >::type Type;
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs, int InnerSize> struct DenseSparseProductReturnType
|
template<typename Lhs, typename Rhs, int InnerSize> struct DenseSparseProductReturnType
|
||||||
@ -29,7 +32,10 @@ template<typename Lhs, typename Rhs, int InnerSize> struct DenseSparseProductRet
|
|||||||
|
|
||||||
template<typename Lhs, typename Rhs> struct DenseSparseProductReturnType<Lhs,Rhs,1>
|
template<typename Lhs, typename Rhs> struct DenseSparseProductReturnType<Lhs,Rhs,1>
|
||||||
{
|
{
|
||||||
typedef SparseDenseOuterProduct<Rhs,Lhs,true> Type;
|
typedef typename internal::conditional<
|
||||||
|
Rhs::IsRowMajor,
|
||||||
|
SparseDenseOuterProduct<Rhs,Lhs,true>,
|
||||||
|
SparseDenseOuterProduct<Lhs,Rhs,false> >::type Type;
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
@ -114,17 +120,30 @@ class SparseDenseOuterProduct<Lhs,Rhs,Transpose>::InnerIterator : public _LhsNes
|
|||||||
typedef typename SparseDenseOuterProduct::Index Index;
|
typedef typename SparseDenseOuterProduct::Index Index;
|
||||||
public:
|
public:
|
||||||
EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer)
|
EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer)
|
||||||
: Base(prod.lhs(), 0), m_outer(outer), m_factor(prod.rhs().coeff(outer))
|
: Base(prod.lhs(), 0), m_outer(outer), m_factor(get(prod.rhs(), outer, typename internal::traits<Rhs>::StorageKind() ))
|
||||||
{
|
{ }
|
||||||
}
|
|
||||||
|
|
||||||
inline Index outer() const { return m_outer; }
|
inline Index outer() const { return m_outer; }
|
||||||
inline Index row() const { return Transpose ? Base::row() : m_outer; }
|
inline Index row() const { return Transpose ? m_outer : Base::index(); }
|
||||||
inline Index col() const { return Transpose ? m_outer : Base::row(); }
|
inline Index col() const { return Transpose ? Base::index() : m_outer; }
|
||||||
|
|
||||||
inline Scalar value() const { return Base::value() * m_factor; }
|
inline Scalar value() const { return Base::value() * m_factor; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
static Scalar get(const _RhsNested &rhs, Index outer, Dense = Dense())
|
||||||
|
{
|
||||||
|
return rhs.coeff(outer);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Scalar get(const _RhsNested &rhs, Index outer, Sparse = Sparse())
|
||||||
|
{
|
||||||
|
typename Traits::_RhsNested::InnerIterator it(rhs, outer);
|
||||||
|
if (it && it.index()==0)
|
||||||
|
return it.value();
|
||||||
|
|
||||||
|
return Scalar(0);
|
||||||
|
}
|
||||||
|
|
||||||
Index m_outer;
|
Index m_outer;
|
||||||
Scalar m_factor;
|
Scalar m_factor;
|
||||||
};
|
};
|
||||||
|
@ -9,32 +9,6 @@
|
|||||||
|
|
||||||
#include "sparse.h"
|
#include "sparse.h"
|
||||||
|
|
||||||
template<typename SparseMatrixType, typename DenseMatrix, bool IsRowMajor=SparseMatrixType::IsRowMajor> struct test_outer;
|
|
||||||
|
|
||||||
template<typename SparseMatrixType, typename DenseMatrix> struct test_outer<SparseMatrixType,DenseMatrix,false> {
|
|
||||||
static void run(SparseMatrixType& m2, SparseMatrixType& m4, DenseMatrix& refMat2, DenseMatrix& refMat4) {
|
|
||||||
typedef typename SparseMatrixType::Index Index;
|
|
||||||
Index c = internal::random<Index>(0,m2.cols()-1);
|
|
||||||
Index c1 = internal::random<Index>(0,m2.cols()-1);
|
|
||||||
VERIFY_IS_APPROX(m4=m2.col(c)*refMat2.col(c1).transpose(), refMat4=refMat2.col(c)*refMat2.col(c1).transpose());
|
|
||||||
VERIFY_IS_APPROX(m4=refMat2.col(c1)*m2.col(c).transpose(), refMat4=refMat2.col(c1)*refMat2.col(c).transpose());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename SparseMatrixType, typename DenseMatrix> struct test_outer<SparseMatrixType,DenseMatrix,true> {
|
|
||||||
static void run(SparseMatrixType& m2, SparseMatrixType& m4, DenseMatrix& refMat2, DenseMatrix& refMat4) {
|
|
||||||
typedef typename SparseMatrixType::Index Index;
|
|
||||||
Index r = internal::random<Index>(0,m2.rows()-1);
|
|
||||||
Index c1 = internal::random<Index>(0,m2.cols()-1);
|
|
||||||
VERIFY_IS_APPROX(m4=m2.row(r).transpose()*refMat2.col(c1).transpose(), refMat4=refMat2.row(r).transpose()*refMat2.col(c1).transpose());
|
|
||||||
VERIFY_IS_APPROX(m4=refMat2.col(c1)*m2.row(r), refMat4=refMat2.col(c1)*refMat2.row(r));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// (m2,m4,refMat2,refMat4,dv1);
|
|
||||||
// VERIFY_IS_APPROX(m4=m2.innerVector(c)*dv1.transpose(), refMat4=refMat2.colVector(c)*dv1.transpose());
|
|
||||||
// VERIFY_IS_APPROX(m4=dv1*mcm.col(c).transpose(), refMat4=dv1*refMat2.col(c).transpose());
|
|
||||||
|
|
||||||
template<typename SparseMatrixType> void sparse_product()
|
template<typename SparseMatrixType> void sparse_product()
|
||||||
{
|
{
|
||||||
typedef typename SparseMatrixType::Index Index;
|
typedef typename SparseMatrixType::Index Index;
|
||||||
@ -119,7 +93,30 @@ template<typename SparseMatrixType> void sparse_product()
|
|||||||
VERIFY_IS_APPROX(dm4=refMat2t.transpose()*m3t.transpose(), refMat4=refMat2t.transpose()*refMat3t.transpose());
|
VERIFY_IS_APPROX(dm4=refMat2t.transpose()*m3t.transpose(), refMat4=refMat2t.transpose()*refMat3t.transpose());
|
||||||
|
|
||||||
// sparse * dense and dense * sparse outer product
|
// sparse * dense and dense * sparse outer product
|
||||||
test_outer<SparseMatrixType,DenseMatrix>::run(m2,m4,refMat2,refMat4);
|
{
|
||||||
|
Index c = internal::random<Index>(0,depth-1);
|
||||||
|
Index r = internal::random<Index>(0,rows-1);
|
||||||
|
Index c1 = internal::random<Index>(0,cols-1);
|
||||||
|
Index r1 = internal::random<Index>(0,depth-1);
|
||||||
|
|
||||||
|
VERIFY_IS_APPROX( m4=m2.col(c)*refMat3.col(c1).transpose(), refMat4=refMat2.col(c)*refMat3.col(c1).transpose());
|
||||||
|
VERIFY_IS_APPROX(dm4=m2.col(c)*refMat3.col(c1).transpose(), refMat4=refMat2.col(c)*refMat3.col(c1).transpose());
|
||||||
|
|
||||||
|
VERIFY_IS_APPROX(m4=refMat3.col(c1)*m2.col(c).transpose(), refMat4=refMat3.col(c1)*refMat2.col(c).transpose());
|
||||||
|
VERIFY_IS_APPROX(dm4=refMat3.col(c1)*m2.col(c).transpose(), refMat4=refMat3.col(c1)*refMat2.col(c).transpose());
|
||||||
|
|
||||||
|
VERIFY_IS_APPROX( m4=refMat3.row(r1).transpose()*m2.col(c).transpose(), refMat4=refMat3.row(r1).transpose()*refMat2.col(c).transpose());
|
||||||
|
VERIFY_IS_APPROX(dm4=refMat3.row(r1).transpose()*m2.col(c).transpose(), refMat4=refMat3.row(r1).transpose()*refMat2.col(c).transpose());
|
||||||
|
|
||||||
|
VERIFY_IS_APPROX( m4=m2.row(r).transpose()*refMat3.col(c1).transpose(), refMat4=refMat2.row(r).transpose()*refMat3.col(c1).transpose());
|
||||||
|
VERIFY_IS_APPROX(dm4=m2.row(r).transpose()*refMat3.col(c1).transpose(), refMat4=refMat2.row(r).transpose()*refMat3.col(c1).transpose());
|
||||||
|
|
||||||
|
VERIFY_IS_APPROX( m4=refMat3.col(c1)*m2.row(r), refMat4=refMat3.col(c1)*refMat2.row(r));
|
||||||
|
VERIFY_IS_APPROX(dm4=refMat3.col(c1)*m2.row(r), refMat4=refMat3.col(c1)*refMat2.row(r));
|
||||||
|
|
||||||
|
VERIFY_IS_APPROX( m4=refMat3.row(r1).transpose()*m2.row(r), refMat4=refMat3.row(r1).transpose()*refMat2.row(r));
|
||||||
|
VERIFY_IS_APPROX(dm4=refMat3.row(r1).transpose()*m2.row(r), refMat4=refMat3.row(r1).transpose()*refMat2.row(r));
|
||||||
|
}
|
||||||
|
|
||||||
VERIFY_IS_APPROX(m6=m6*m6, refMat6=refMat6*refMat6);
|
VERIFY_IS_APPROX(m6=m6*m6, refMat6=refMat6*refMat6);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user