Add support for dense.cwiseProduct(sparse)

This also fixes a regression regarding (dense*sparse).diagonal()
This commit is contained in:
Gael Guennebaud 2015-11-04 17:42:07 +01:00
parent fd074be1a0
commit c030925a66
6 changed files with 26 additions and 15 deletions

View File

@ -440,6 +440,15 @@ template<typename Derived> class MatrixBase
template<typename OtherScalar> template<typename OtherScalar>
void applyOnTheRight(Index p, Index q, const JacobiRotation<OtherScalar>& j); void applyOnTheRight(Index p, Index q, const JacobiRotation<OtherScalar>& j);
///////// SparseCore module /////////
template<typename OtherDerived>
EIGEN_STRONG_INLINE const typename SparseMatrixBase<OtherDerived>::template CwiseProductDenseReturnType<Derived>::Type
cwiseProduct(const SparseMatrixBase<OtherDerived> &other) const
{
return other.cwiseProduct(derived());
}
///////// MatrixFunctions module ///////// ///////// MatrixFunctions module /////////
typedef typename internal::stem_function<Scalar>::type StemFunction; typedef typename internal::stem_function<Scalar>::type StemFunction;

View File

@ -235,6 +235,9 @@ template<typename Scalar> class Rotation2D;
template<typename Scalar> class AngleAxis; template<typename Scalar> class AngleAxis;
template<typename Scalar,int Dim> class Translation; template<typename Scalar,int Dim> class Translation;
// Sparse module:
template<typename Derived> class SparseMatrixBase;
#ifdef EIGEN2_SUPPORT #ifdef EIGEN2_SUPPORT
template<typename Derived, int _Dim> class eigen2_RotationBase; template<typename Derived, int _Dim> class eigen2_RotationBase;
template<typename Lhs, typename Rhs> class eigen2_Cross; template<typename Lhs, typename Rhs> class eigen2_Cross;

View File

@ -314,10 +314,10 @@ SparseMatrixBase<Derived>::operator+=(const SparseMatrixBase<OtherDerived>& othe
template<typename Derived> template<typename Derived>
template<typename OtherDerived> template<typename OtherDerived>
EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE EIGEN_STRONG_INLINE const typename SparseMatrixBase<Derived>::template CwiseProductDenseReturnType<OtherDerived>::Type
SparseMatrixBase<Derived>::cwiseProduct(const MatrixBase<OtherDerived> &other) const SparseMatrixBase<Derived>::cwiseProduct(const MatrixBase<OtherDerived> &other) const
{ {
return EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE(derived(), other.derived()); return typename CwiseProductDenseReturnType<OtherDerived>::Type(derived(), other.derived());
} }
} // end namespace Eigen } // end namespace Eigen

View File

@ -317,20 +317,18 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
Derived& operator*=(const Scalar& other); Derived& operator*=(const Scalar& other);
Derived& operator/=(const Scalar& other); Derived& operator/=(const Scalar& other);
#define EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE \ template<typename OtherDerived> struct CwiseProductDenseReturnType {
CwiseBinaryOp< \ typedef CwiseBinaryOp<internal::scalar_product_op<typename internal::scalar_product_traits<
internal::scalar_product_op< \ typename internal::traits<Derived>::Scalar,
typename internal::scalar_product_traits< \ typename internal::traits<OtherDerived>::Scalar
typename internal::traits<Derived>::Scalar, \ >::ReturnType>,
typename internal::traits<OtherDerived>::Scalar \ const Derived,
>::ReturnType \ const OtherDerived
>, \ > Type;
const Derived, \ };
const OtherDerived \
>
template<typename OtherDerived> template<typename OtherDerived>
EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE EIGEN_STRONG_INLINE const typename CwiseProductDenseReturnType<OtherDerived>::Type
cwiseProduct(const MatrixBase<OtherDerived> &other) const; cwiseProduct(const MatrixBase<OtherDerived> &other) const;
// sparse * sparse // sparse * sparse

View File

@ -67,7 +67,6 @@ const int InnerRandomAccessPattern = 0x2 | CoherentAccessPattern;
const int OuterRandomAccessPattern = 0x4 | CoherentAccessPattern; const int OuterRandomAccessPattern = 0x4 | CoherentAccessPattern;
const int RandomAccessPattern = 0x8 | OuterRandomAccessPattern | InnerRandomAccessPattern; const int RandomAccessPattern = 0x8 | OuterRandomAccessPattern | InnerRandomAccessPattern;
template<typename Derived> class SparseMatrixBase;
template<typename _Scalar, int _Flags = 0, typename _Index = int> class SparseMatrix; template<typename _Scalar, int _Flags = 0, typename _Index = int> class SparseMatrix;
template<typename _Scalar, int _Flags = 0, typename _Index = int> class DynamicSparseMatrix; template<typename _Scalar, int _Flags = 0, typename _Index = int> class DynamicSparseMatrix;
template<typename _Scalar, int _Flags = 0, typename _Index = int> class SparseVector; template<typename _Scalar, int _Flags = 0, typename _Index = int> class SparseVector;

View File

@ -306,6 +306,8 @@ template<typename SparseMatrixType> void sparse_basic(const SparseMatrixType& re
refM4.setRandom(); refM4.setRandom();
// sparse cwise* dense // sparse cwise* dense
VERIFY_IS_APPROX(m3.cwiseProduct(refM4), refM3.cwiseProduct(refM4)); VERIFY_IS_APPROX(m3.cwiseProduct(refM4), refM3.cwiseProduct(refM4));
// dense cwise* sparse
VERIFY_IS_APPROX(refM4.cwiseProduct(m3), refM4.cwiseProduct(refM3));
// VERIFY_IS_APPROX(m3.cwise()/refM4, refM3.cwise()/refM4); // VERIFY_IS_APPROX(m3.cwise()/refM4, refM3.cwise()/refM4);
// test aliasing // test aliasing