mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-11-27 06:30:28 +08:00
Make KroneckerProductSparse inherit EigenBase instead of SparseMatrixBase, for it does not provide an InnerIterator.
This commit is contained in:
parent
204a09cb82
commit
8321b7ae74
@ -282,8 +282,6 @@ struct stem_function
|
||||
};
|
||||
}
|
||||
|
||||
// KroneckerProduct module
|
||||
template<typename Lhs, typename Rhs> class KroneckerProductSparse;
|
||||
|
||||
#ifdef EIGEN2_SUPPORT
|
||||
template<typename ExpressionType> class Cwise;
|
||||
|
@ -259,9 +259,6 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
|
||||
template<typename Lhs, typename Rhs>
|
||||
inline Derived& operator=(const SparseSparseProduct<Lhs,Rhs>& product);
|
||||
|
||||
template<typename Lhs, typename Rhs>
|
||||
inline Derived& operator=(const KroneckerProductSparse<Lhs,Rhs>& product);
|
||||
|
||||
friend std::ostream & operator << (std::ostream & s, const SparseMatrixBase& m)
|
||||
{
|
||||
typedef typename Derived::Nested Nested;
|
||||
|
@ -2,7 +2,6 @@
|
||||
#define EIGEN_KRONECKER_PRODUCT_MODULE_H
|
||||
|
||||
#include "../../Eigen/Core"
|
||||
#include "../../Eigen/SparseCore"
|
||||
|
||||
#include "../../Eigen/src/Core/util/DisableStupidWarnings.h"
|
||||
|
||||
|
@ -12,12 +12,10 @@
|
||||
#ifndef KRONECKER_TENSOR_PRODUCT_H
|
||||
#define KRONECKER_TENSOR_PRODUCT_H
|
||||
|
||||
#define EIGEN_SIZE_PRODUCT(a,b) (!((int)a && (int)b) ? 0 \
|
||||
: ((int)a == Dynamic || (int)b == Dynamic) ? Dynamic \
|
||||
: (int)a * (int)b)
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
template<typename Scalar, int Options, typename Index> class SparseMatrix;
|
||||
|
||||
/*!
|
||||
* \brief Kronecker tensor product helper class for dense matrices
|
||||
*
|
||||
@ -79,12 +77,12 @@ class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> >
|
||||
* \tparam Rhs Type of the rignt-hand side, a matrix expression.
|
||||
*/
|
||||
template<typename Lhs, typename Rhs>
|
||||
class KroneckerProductSparse : public SparseMatrixBase<KroneckerProductSparse<Lhs,Rhs> >
|
||||
class KroneckerProductSparse : public EigenBase<KroneckerProductSparse<Lhs,Rhs> >
|
||||
{
|
||||
public:
|
||||
typedef SparseMatrixBase<KroneckerProductSparse> Base;
|
||||
EIGEN_DENSE_PUBLIC_INTERFACE(KroneckerProductSparse)
|
||||
private:
|
||||
typedef typename internal::traits<KroneckerProductSparse>::Index Index;
|
||||
|
||||
public:
|
||||
/*! \brief Constructor. */
|
||||
KroneckerProductSparse(const Lhs& A, const Rhs& B)
|
||||
: m_A(A), m_B(B)
|
||||
@ -96,6 +94,14 @@ class KroneckerProductSparse : public SparseMatrixBase<KroneckerProductSparse<Lh
|
||||
inline Index rows() const { return m_A.rows() * m_B.rows(); }
|
||||
inline Index cols() const { return m_A.cols() * m_B.cols(); }
|
||||
|
||||
template<typename Scalar, int Options, typename Index>
|
||||
operator SparseMatrix<Scalar, Options, Index>()
|
||||
{
|
||||
SparseMatrix<Scalar, Options, Index> result;
|
||||
evalTo(result.derived());
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
typename Lhs::Nested m_A;
|
||||
typename Rhs::Nested m_B;
|
||||
@ -151,10 +157,10 @@ struct traits<KroneckerProduct<_Lhs,_Rhs> >
|
||||
typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
|
||||
|
||||
enum {
|
||||
Rows = EIGEN_SIZE_PRODUCT(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime),
|
||||
Cols = EIGEN_SIZE_PRODUCT(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime),
|
||||
MaxRows = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime),
|
||||
MaxCols = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime),
|
||||
Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
|
||||
Cols = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret,
|
||||
MaxRows = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret,
|
||||
MaxCols = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret,
|
||||
CoeffReadCost = Lhs::CoeffReadCost + Rhs::CoeffReadCost + NumTraits<Scalar>::MulCost
|
||||
};
|
||||
|
||||
@ -168,17 +174,17 @@ struct traits<KroneckerProductSparse<_Lhs,_Rhs> >
|
||||
typedef typename remove_all<_Lhs>::type Lhs;
|
||||
typedef typename remove_all<_Rhs>::type Rhs;
|
||||
typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
|
||||
typedef Sparse StorageKind;
|
||||
typedef typename promote_storage_type<typename traits<Lhs>::StorageKind, typename traits<Rhs>::StorageKind>::ret StorageKind;
|
||||
typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index;
|
||||
|
||||
enum {
|
||||
LhsFlags = Lhs::Flags,
|
||||
RhsFlags = Rhs::Flags,
|
||||
|
||||
RowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime),
|
||||
ColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime),
|
||||
MaxRowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime),
|
||||
MaxColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime),
|
||||
RowsAtCompileTime = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
|
||||
ColsAtCompileTime = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret,
|
||||
MaxRowsAtCompileTime = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret,
|
||||
MaxColsAtCompileTime = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret,
|
||||
|
||||
EvalToRowMajor = (LhsFlags & RhsFlags & RowMajorBit),
|
||||
RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
|
||||
@ -233,14 +239,6 @@ KroneckerProductSparse<A,B> kroneckerProduct(const EigenBase<A>& a, const EigenB
|
||||
return KroneckerProductSparse<A,B>(a.derived(), b.derived());
|
||||
}
|
||||
|
||||
template<typename Derived>
|
||||
template<typename Lhs, typename Rhs>
|
||||
Derived& SparseMatrixBase<Derived>::operator=(const KroneckerProductSparse<Lhs,Rhs>& product)
|
||||
{
|
||||
product.evalTo(derived());
|
||||
return derived();
|
||||
}
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // KRONECKER_TENSOR_PRODUCT_H
|
||||
|
Loading…
Reference in New Issue
Block a user