evaluate 1D sparse expressions into SparseVector and make the sparse operator<< and dot honor nested types

This commit is contained in:
Gael Guennebaud 2011-12-22 14:01:06 +01:00
parent 7f04845023
commit 2c03e6fccc
5 changed files with 56 additions and 16 deletions

View File

@ -62,8 +62,16 @@ SparseMatrixBase<Derived>::dot(const SparseMatrixBase<OtherDerived>& other) cons
eigen_assert(size() == other.size());
typename Derived::InnerIterator i(derived(),0);
typename OtherDerived::InnerIterator j(other.derived(),0);
typedef typename Derived::Nested Nested;
typedef typename OtherDerived::Nested OtherNested;
typedef typename internal::remove_all<Nested>::type NestedCleaned;
typedef typename internal::remove_all<OtherNested>::type OtherNestedCleaned;
const Nested nthis(derived());
const OtherNested nother(other.derived());
typename NestedCleaned::InnerIterator i(nthis,0);
typename OtherNestedCleaned::InnerIterator j(nother,0);
Scalar res(0);
while (i && j)
{

View File

@ -274,12 +274,16 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
friend std::ostream & operator << (std::ostream & s, const SparseMatrixBase& m)
{
typedef typename Derived::Nested Nested;
typedef typename internal::remove_all<Nested>::type NestedCleaned;
if (Flags&RowMajorBit)
{
for (Index row=0; row<m.outerSize(); ++row)
const Nested nm(m.derived());
for (Index row=0; row<nm.outerSize(); ++row)
{
Index col = 0;
for (typename Derived::InnerIterator it(m.derived(), row); it; ++it)
for (typename NestedCleaned::InnerIterator it(nm.derived(), row); it; ++it)
{
for ( ; col<it.index(); ++col)
s << "0 ";
@ -293,9 +297,10 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
}
else
{
const Nested nm(m.derived());
if (m.cols() == 1) {
Index row = 0;
for (typename Derived::InnerIterator it(m.derived(), 0); it; ++it)
for (typename NestedCleaned::InnerIterator it(nm.derived(), 0); it; ++it)
{
for ( ; row<it.index(); ++row)
s << "0" << std::endl;
@ -307,7 +312,7 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
}
else
{
SparseMatrix<Scalar, RowMajorBit> trans = m.derived();
SparseMatrix<Scalar, RowMajorBit> trans = m;
s << static_cast<const SparseMatrixBase<SparseMatrix<Scalar, RowMajorBit> >&>(trans);
}
}

View File

@ -103,17 +103,39 @@ template<typename Lhs, typename Rhs, int InnerSize = internal::traits<Lhs>::Cols
namespace internal {
template<typename T> struct eval<T,Sparse>
{
typedef typename traits<T>::Scalar _Scalar;
enum {
_Flags = traits<T>::Flags
};
template<typename T,int Rows,int Cols> struct sparse_eval;
template<typename T> struct eval<T,Sparse>
: public sparse_eval<T, traits<T>::RowsAtCompileTime,traits<T>::ColsAtCompileTime>
{};
template<typename T,int Cols> struct sparse_eval<T,1,Cols> {
typedef typename traits<T>::Scalar _Scalar;
enum { _Flags = traits<T>::Flags| RowMajorBit };
public:
typedef SparseVector<_Scalar, _Flags> type;
};
template<typename T,int Rows> struct sparse_eval<T,Rows,1> {
typedef typename traits<T>::Scalar _Scalar;
enum { _Flags = traits<T>::Flags & (~RowMajorBit) };
public:
typedef SparseVector<_Scalar, _Flags> type;
};
template<typename T,int Rows,int Cols> struct sparse_eval {
typedef typename traits<T>::Scalar _Scalar;
enum { _Flags = traits<T>::Flags };
public:
typedef SparseMatrix<_Scalar, _Flags> type;
};
template<typename T> struct sparse_eval<T,1,1> {
typedef typename traits<T>::Scalar _Scalar;
public:
typedef Matrix<_Scalar, 1, 1> type;
};
template<typename T> struct plain_matrix_type<T,Sparse>
{
typedef typename traits<T>::Scalar _Scalar;

View File

@ -47,7 +47,7 @@ struct traits<SparseVector<_Scalar, _Options, _Index> >
typedef Sparse StorageKind;
typedef MatrixXpr XprKind;
enum {
IsColVector = _Options & RowMajorBit ? 0 : 1,
IsColVector = (_Options & RowMajorBit) ? 0 : 1,
RowsAtCompileTime = IsColVector ? Dynamic : 1,
ColsAtCompileTime = IsColVector ? 1 : Dynamic,
@ -320,7 +320,7 @@ protected:
const bool needToTranspose = (Flags & RowMajorBit) != (OtherDerived::Flags & RowMajorBit);
if(needToTranspose)
{
Index size = other.innerSize();
Index size = other.size();
Index nnz = other.nonZeros();
resize(size);
reserve(nnz);

View File

@ -34,9 +34,9 @@ template<typename Scalar> void sparse_vector(int rows, int cols)
typedef SparseMatrix<Scalar> SparseMatrixType;
Scalar eps = 1e-6;
SparseMatrixType m1(rows,cols);
SparseMatrixType m1(rows,rows);
SparseVectorType v1(rows), v2(rows), v3(rows);
DenseMatrix refM1 = DenseMatrix::Zero(rows, cols);
DenseMatrix refM1 = DenseMatrix::Zero(rows, rows);
DenseVector refV1 = DenseVector::Random(rows),
refV2 = DenseVector::Random(rows),
refV3 = DenseVector::Random(rows);
@ -86,6 +86,11 @@ template<typename Scalar> void sparse_vector(int rows, int cols)
VERIFY_IS_APPROX(v1.dot(v2), refV1.dot(refV2));
VERIFY_IS_APPROX(v1.dot(refV2), refV1.dot(refV2));
VERIFY_IS_APPROX(v1.dot(m1*v2), refV1.dot(refM1*refV2));
int i = internal::random<int>(0,rows-1);
VERIFY_IS_APPROX(v1.dot(m1.col(i)), refV1.dot(refM1.col(i)));
VERIFY_IS_APPROX(v1.squaredNorm(), refV1.squaredNorm());
}