Make KroneckerProductSparse inherit EigenBase instead of SparseMatrixBase, for it does not provide an InnerIterator.

This commit is contained in:
Chen-Pang He 2012-10-25 02:09:48 +08:00
parent 204a09cb82
commit 8321b7ae74
4 changed files with 23 additions and 31 deletions

View File

@ -282,8 +282,6 @@ struct stem_function
};
}
// KroneckerProduct module
template<typename Lhs, typename Rhs> class KroneckerProductSparse;
#ifdef EIGEN2_SUPPORT
template<typename ExpressionType> class Cwise;

View File

@ -259,9 +259,6 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
template<typename Lhs, typename Rhs>
inline Derived& operator=(const SparseSparseProduct<Lhs,Rhs>& product);
template<typename Lhs, typename Rhs>
inline Derived& operator=(const KroneckerProductSparse<Lhs,Rhs>& product);
friend std::ostream & operator << (std::ostream & s, const SparseMatrixBase& m)
{
typedef typename Derived::Nested Nested;

View File

@ -2,7 +2,6 @@
#define EIGEN_KRONECKER_PRODUCT_MODULE_H
#include "../../Eigen/Core"
#include "../../Eigen/SparseCore"
#include "../../Eigen/src/Core/util/DisableStupidWarnings.h"

View File

@ -12,12 +12,10 @@
#ifndef KRONECKER_TENSOR_PRODUCT_H
#define KRONECKER_TENSOR_PRODUCT_H
#define EIGEN_SIZE_PRODUCT(a,b) (!((int)a && (int)b) ? 0 \
: ((int)a == Dynamic || (int)b == Dynamic) ? Dynamic \
: (int)a * (int)b)
namespace Eigen {
template<typename Scalar, int Options, typename Index> class SparseMatrix;
/*!
* \brief Kronecker tensor product helper class for dense matrices
*
@ -79,12 +77,12 @@ class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> >
* \tparam Rhs Type of the rignt-hand side, a matrix expression.
*/
template<typename Lhs, typename Rhs>
class KroneckerProductSparse : public SparseMatrixBase<KroneckerProductSparse<Lhs,Rhs> >
class KroneckerProductSparse : public EigenBase<KroneckerProductSparse<Lhs,Rhs> >
{
public:
typedef SparseMatrixBase<KroneckerProductSparse> Base;
EIGEN_DENSE_PUBLIC_INTERFACE(KroneckerProductSparse)
private:
typedef typename internal::traits<KroneckerProductSparse>::Index Index;
public:
/*! \brief Constructor. */
KroneckerProductSparse(const Lhs& A, const Rhs& B)
: m_A(A), m_B(B)
@ -96,6 +94,14 @@ class KroneckerProductSparse : public SparseMatrixBase<KroneckerProductSparse<Lh
inline Index rows() const { return m_A.rows() * m_B.rows(); }
inline Index cols() const { return m_A.cols() * m_B.cols(); }
template<typename Scalar, int Options, typename Index>
operator SparseMatrix<Scalar, Options, Index>()
{
SparseMatrix<Scalar, Options, Index> result;
evalTo(result.derived());
return result;
}
private:
typename Lhs::Nested m_A;
typename Rhs::Nested m_B;
@ -151,10 +157,10 @@ struct traits<KroneckerProduct<_Lhs,_Rhs> >
typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
enum {
Rows = EIGEN_SIZE_PRODUCT(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime),
Cols = EIGEN_SIZE_PRODUCT(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime),
MaxRows = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime),
MaxCols = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime),
Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
Cols = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret,
MaxRows = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret,
MaxCols = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret,
CoeffReadCost = Lhs::CoeffReadCost + Rhs::CoeffReadCost + NumTraits<Scalar>::MulCost
};
@ -168,17 +174,17 @@ struct traits<KroneckerProductSparse<_Lhs,_Rhs> >
typedef typename remove_all<_Lhs>::type Lhs;
typedef typename remove_all<_Rhs>::type Rhs;
typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
typedef Sparse StorageKind;
typedef typename promote_storage_type<typename traits<Lhs>::StorageKind, typename traits<Rhs>::StorageKind>::ret StorageKind;
typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index;
enum {
LhsFlags = Lhs::Flags,
RhsFlags = Rhs::Flags,
RowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime),
ColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime),
MaxRowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime),
MaxColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime),
RowsAtCompileTime = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
ColsAtCompileTime = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret,
MaxRowsAtCompileTime = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret,
MaxColsAtCompileTime = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret,
EvalToRowMajor = (LhsFlags & RhsFlags & RowMajorBit),
RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
@ -233,14 +239,6 @@ KroneckerProductSparse<A,B> kroneckerProduct(const EigenBase<A>& a, const EigenB
return KroneckerProductSparse<A,B>(a.derived(), b.derived());
}
template<typename Derived>
template<typename Lhs, typename Rhs>
Derived& SparseMatrixBase<Derived>::operator=(const KroneckerProductSparse<Lhs,Rhs>& product)
{
product.evalTo(derived());
return derived();
}
} // end namespace Eigen
#endif // KRONECKER_TENSOR_PRODUCT_H