Refactor TriangularView to handle both dense and sparse objects. Introduce a glu_shape<S1,S2> helper to assemble sparse/dense shapes with triagular/seladjoint views.

This commit is contained in:
Gael Guennebaud 2014-07-22 11:35:56 +02:00
parent 2a251ffab0
commit 6daa6a0d16
15 changed files with 520 additions and 284 deletions

View File

@ -54,12 +54,12 @@ struct Sparse {};
#include "src/SparseCore/SparseProduct.h"
#include "src/SparseCore/SparseDenseProduct.h"
#include "src/SparseCore/SparseSelfAdjointView.h"
#include "src/SparseCore/SparseTriangularView.h"
#include "src/SparseCore/TriangularSolver.h"
#ifndef EIGEN_TEST_EVALUATORS
#include "src/SparseCore/SparsePermutation.h"
#include "src/SparseCore/SparseFuzzy.h"
#include "src/SparseCore/SparseTriangularView.h"
#include "src/SparseCore/TriangularSolver.h"
#endif
#include "src/Core/util/ReenableStupidWarnings.h"

View File

@ -171,10 +171,10 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,CompleteUnrolling,1> {
*/
template<typename MatrixType, unsigned int Mode>
template<int Side, typename OtherDerived>
void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<OtherDerived>& _other) const
void TriangularViewImpl<MatrixType,Mode,Dense>::solveInPlace(const MatrixBase<OtherDerived>& _other) const
{
OtherDerived& other = _other.const_cast_derived();
eigen_assert( cols() == rows() && ((Side==OnTheLeft && cols() == other.rows()) || (Side==OnTheRight && cols() == other.cols())) );
eigen_assert( derived().cols() == derived().rows() && ((Side==OnTheLeft && derived().cols() == other.rows()) || (Side==OnTheRight && derived().cols() == other.cols())) );
eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit && OtherDerived::IsVectorAtCompileTime };
@ -183,7 +183,7 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<OtherDerived
OtherCopy otherCopy(other);
internal::triangular_solver_selector<MatrixType, typename internal::remove_reference<OtherCopy>::type,
Side, Mode>::run(nestedExpression(), otherCopy);
Side, Mode>::run(derived().nestedExpression(), otherCopy);
if (copy)
other = otherCopy;
@ -213,9 +213,9 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<OtherDerived
template<typename Derived, unsigned int Mode>
template<int Side, typename Other>
const internal::triangular_solve_retval<Side,TriangularView<Derived,Mode>,Other>
TriangularView<Derived,Mode>::solve(const MatrixBase<Other>& other) const
TriangularViewImpl<Derived,Mode,Dense>::solve(const MatrixBase<Other>& other) const
{
return internal::triangular_solve_retval<Side,TriangularView,Other>(*this, other.derived());
return internal::triangular_solve_retval<Side,TriangularViewType,Other>(derived(), other.derived());
}
namespace internal {

View File

@ -49,7 +49,7 @@ template<typename Derived> class TriangularBase : public EigenBase<Derived>
typedef typename internal::traits<Derived>::Scalar Scalar;
typedef typename internal::traits<Derived>::StorageKind StorageKind;
typedef typename internal::traits<Derived>::Index Index;
typedef typename internal::traits<Derived>::DenseMatrixType DenseMatrixType;
typedef typename internal::traits<Derived>::FullMatrixType DenseMatrixType;
typedef DenseMatrixType DenseType;
typedef Derived const& Nested;
@ -172,8 +172,8 @@ struct traits<TriangularView<MatrixType, _Mode> > : traits<MatrixType>
typedef typename nested<MatrixType>::type MatrixTypeNested;
typedef typename remove_reference<MatrixTypeNested>::type MatrixTypeNestedNonRef;
typedef typename remove_all<MatrixTypeNested>::type MatrixTypeNestedCleaned;
typedef typename MatrixType::PlainObject FullMatrixType;
typedef MatrixType ExpressionType;
typedef typename MatrixType::PlainObject DenseMatrixType;
enum {
Mode = _Mode,
Flags = (MatrixTypeNestedCleaned::Flags & (HereditaryBits | LvalueBit) & (~(PacketAccessBit | DirectAccessBit | LinearAccessBit))) | Mode
@ -185,22 +185,23 @@ struct traits<TriangularView<MatrixType, _Mode> > : traits<MatrixType>
};
}
#ifndef EIGEN_TEST_EVALUATORS
template<int Mode, bool LhsIsTriangular,
typename Lhs, bool LhsIsVector,
typename Rhs, bool RhsIsVector>
struct TriangularProduct;
#endif
template<typename _MatrixType, unsigned int _Mode, typename StorageKind> class TriangularViewImpl;
template<typename _MatrixType, unsigned int _Mode> class TriangularView
: public TriangularBase<TriangularView<_MatrixType, _Mode> >
: public TriangularViewImpl<_MatrixType, _Mode, typename internal::traits<_MatrixType>::StorageKind >
{
public:
typedef TriangularBase<TriangularView> Base;
typedef TriangularViewImpl<_MatrixType, _Mode, typename internal::traits<_MatrixType>::StorageKind > Base;
typedef typename internal::traits<TriangularView>::Scalar Scalar;
typedef _MatrixType MatrixType;
typedef typename internal::traits<TriangularView>::DenseMatrixType DenseMatrixType;
typedef DenseMatrixType PlainObject;
protected:
typedef typename internal::traits<TriangularView>::MatrixTypeNested MatrixTypeNested;
@ -210,8 +211,6 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
typedef typename internal::remove_all<typename MatrixType::ConjugateReturnType>::type MatrixConjugateReturnType;
public:
using Base::evalToLazy;
typedef typename internal::traits<TriangularView>::StorageKind StorageKind;
typedef typename internal::traits<TriangularView>::Index Index;
@ -228,110 +227,19 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
EIGEN_DEVICE_FUNC
inline TriangularView(const MatrixType& matrix) : m_matrix(matrix)
{}
using Base::operator=;
EIGEN_DEVICE_FUNC
inline Index rows() const { return m_matrix.rows(); }
EIGEN_DEVICE_FUNC
inline Index cols() const { return m_matrix.cols(); }
EIGEN_DEVICE_FUNC
inline Index outerStride() const { return m_matrix.outerStride(); }
EIGEN_DEVICE_FUNC
inline Index innerStride() const { return m_matrix.innerStride(); }
#ifdef EIGEN_TEST_EVALUATORS
/** \sa MatrixBase::operator+=() */
template<typename Other>
EIGEN_DEVICE_FUNC
TriangularView& operator+=(const DenseBase<Other>& other) {
internal::call_assignment_no_alias(*this, other.derived(), internal::add_assign_op<Scalar>());
return *this;
}
/** \sa MatrixBase::operator-=() */
template<typename Other>
EIGEN_DEVICE_FUNC
TriangularView& operator-=(const DenseBase<Other>& other) {
internal::call_assignment_no_alias(*this, other.derived(), internal::sub_assign_op<Scalar>());
return *this;
}
#else
/** \sa MatrixBase::operator+=() */
template<typename Other>
EIGEN_DEVICE_FUNC
TriangularView& operator+=(const DenseBase<Other>& other) { return *this = m_matrix + other.derived(); }
/** \sa MatrixBase::operator-=() */
template<typename Other>
EIGEN_DEVICE_FUNC
TriangularView& operator-=(const DenseBase<Other>& other) { return *this = m_matrix - other.derived(); }
#endif
/** \sa MatrixBase::operator*=() */
EIGEN_DEVICE_FUNC
TriangularView& operator*=(const typename internal::traits<MatrixType>::Scalar& other) { return *this = m_matrix * other; }
/** \sa MatrixBase::operator/=() */
EIGEN_DEVICE_FUNC
TriangularView& operator/=(const typename internal::traits<MatrixType>::Scalar& other) { return *this = m_matrix / other; }
/** \sa MatrixBase::fill() */
EIGEN_DEVICE_FUNC
void fill(const Scalar& value) { setConstant(value); }
/** \sa MatrixBase::setConstant() */
EIGEN_DEVICE_FUNC
TriangularView& setConstant(const Scalar& value)
{ return *this = MatrixType::Constant(rows(), cols(), value); }
/** \sa MatrixBase::setZero() */
EIGEN_DEVICE_FUNC
TriangularView& setZero() { return setConstant(Scalar(0)); }
/** \sa MatrixBase::setOnes() */
EIGEN_DEVICE_FUNC
TriangularView& setOnes() { return setConstant(Scalar(1)); }
/** \sa MatrixBase::coeff()
* \warning the coordinates must fit into the referenced triangular part
*/
EIGEN_DEVICE_FUNC
inline Scalar coeff(Index row, Index col) const
{
Base::check_coordinates_internal(row, col);
return m_matrix.coeff(row, col);
}
/** \sa MatrixBase::coeffRef()
* \warning the coordinates must fit into the referenced triangular part
*/
EIGEN_DEVICE_FUNC
inline Scalar& coeffRef(Index row, Index col)
{
Base::check_coordinates_internal(row, col);
return m_matrix.const_cast_derived().coeffRef(row, col);
}
EIGEN_DEVICE_FUNC
const MatrixTypeNestedCleaned& nestedExpression() const { return m_matrix; }
EIGEN_DEVICE_FUNC
MatrixTypeNestedCleaned& nestedExpression() { return *const_cast<MatrixTypeNestedCleaned*>(&m_matrix); }
/** Assigns a triangular matrix to a triangular part of a dense matrix */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
TriangularView& operator=(const TriangularBase<OtherDerived>& other);
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
TriangularView& operator=(const MatrixBase<OtherDerived>& other);
EIGEN_DEVICE_FUNC
TriangularView& operator=(const TriangularView& other)
{ return *this = other.nestedExpression(); }
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
void lazyAssign(const TriangularBase<OtherDerived>& other);
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
void lazyAssign(const MatrixBase<OtherDerived>& other);
/** \sa MatrixBase::conjugate() */
EIGEN_DEVICE_FUNC
inline TriangularView<MatrixConjugateReturnType,Mode> conjugate()
@ -360,78 +268,16 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
return m_matrix.transpose();
}
#ifdef EIGEN_TEST_EVALUATORS
/** Efficient triangular matrix times vector/matrix product */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
const Product<TriangularView,OtherDerived>
operator*(const MatrixBase<OtherDerived>& rhs) const
{
return Product<TriangularView,OtherDerived>(*this, rhs.derived());
}
/** Efficient vector/matrix times triangular matrix product */
template<typename OtherDerived> friend
EIGEN_DEVICE_FUNC
const Product<OtherDerived,TriangularView>
operator*(const MatrixBase<OtherDerived>& lhs, const TriangularView& rhs)
{
return Product<OtherDerived,TriangularView>(lhs.derived(),rhs);
}
#else // EIGEN_TEST_EVALUATORS
/** Efficient triangular matrix times vector/matrix product */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
TriangularProduct<Mode, true, MatrixType, false, OtherDerived, OtherDerived::ColsAtCompileTime==1>
operator*(const MatrixBase<OtherDerived>& rhs) const
{
return TriangularProduct
<Mode, true, MatrixType, false, OtherDerived, OtherDerived::ColsAtCompileTime==1>
(m_matrix, rhs.derived());
}
/** Efficient vector/matrix times triangular matrix product */
template<typename OtherDerived> friend
EIGEN_DEVICE_FUNC
TriangularProduct<Mode, false, OtherDerived, OtherDerived::RowsAtCompileTime==1, MatrixType, false>
operator*(const MatrixBase<OtherDerived>& lhs, const TriangularView& rhs)
{
return TriangularProduct
<Mode, false, OtherDerived, OtherDerived::RowsAtCompileTime==1, MatrixType, false>
(lhs.derived(),rhs.m_matrix);
}
#endif
template<int Side, typename Other>
EIGEN_DEVICE_FUNC
inline const internal::triangular_solve_retval<Side,TriangularView, Other>
solve(const MatrixBase<Other>& other) const;
template<int Side, typename OtherDerived>
EIGEN_DEVICE_FUNC
void solveInPlace(const MatrixBase<OtherDerived>& other) const;
#ifdef EIGEN_TEST_EVALUATORS
template<typename Other>
EIGEN_DEVICE_FUNC
inline const Solve<TriangularView, Other>
solve(const MatrixBase<Other>& other) const
{ return Solve<TriangularView, Other>(*this, other.derived()); }
#else // EIGEN_TEST_EVALUATORS
template<typename Other>
EIGEN_DEVICE_FUNC
inline const internal::triangular_solve_retval<OnTheLeft,TriangularView, Other>
solve(const MatrixBase<Other>& other) const
{ return solve<OnTheLeft>(other); }
using Base::solve;
#endif // EIGEN_TEST_EVALUATORS
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
void solveInPlace(const MatrixBase<OtherDerived>& other) const
{ return solveInPlace<OnTheLeft>(other); }
EIGEN_DEVICE_FUNC
const SelfAdjointView<MatrixTypeNestedNonRef,Mode> selfadjointView() const
{
@ -445,30 +291,6 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
return SelfAdjointView<MatrixTypeNestedNonRef,Mode>(m_matrix);
}
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
void swap(TriangularBase<OtherDerived> const & other)
{
#ifdef EIGEN_TEST_EVALUATORS
call_assignment(*this, other.const_cast_derived(), internal::swap_assign_op<Scalar>());
#else
TriangularView<SwapWrapper<MatrixType>,Mode>(const_cast<MatrixType&>(m_matrix)).lazyAssign(other.const_cast_derived().nestedExpression());
#endif
}
// TODO: this overload is ambiguous and it should be deprecated (Gael)
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
void swap(MatrixBase<OtherDerived> const & other)
{
#ifdef EIGEN_TEST_EVALUATORS
call_assignment(*this, other.const_cast_derived(), internal::swap_assign_op<Scalar>());
#else
SwapWrapper<MatrixType> swaper(const_cast<MatrixType&>(m_matrix));
TriangularView<SwapWrapper<MatrixType>,Mode>(swaper).lazyAssign(other.derived());
#endif
}
EIGEN_DEVICE_FUNC
Scalar determinant() const
{
@ -479,13 +301,227 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
else
return m_matrix.diagonal().prod();
}
protected:
MatrixTypeNested m_matrix;
};
template<typename _MatrixType, unsigned int _Mode> class TriangularViewImpl<_MatrixType,_Mode,Dense>
: public TriangularBase<TriangularView<_MatrixType, _Mode> >
{
public:
typedef TriangularView<_MatrixType, _Mode> TriangularViewType;
typedef TriangularBase<TriangularViewType> Base;
typedef typename internal::traits<TriangularViewType>::Scalar Scalar;
typedef _MatrixType MatrixType;
typedef typename MatrixType::PlainObject DenseMatrixType;
typedef DenseMatrixType PlainObject;
public:
using Base::evalToLazy;
using Base::derived;
typedef typename internal::traits<TriangularViewType>::StorageKind StorageKind;
typedef typename internal::traits<TriangularViewType>::Index Index;
enum {
Mode = _Mode,
Flags = internal::traits<TriangularViewType>::Flags
};
EIGEN_DEVICE_FUNC
inline Index outerStride() const { return derived().nestedExpression().outerStride(); }
EIGEN_DEVICE_FUNC
inline Index innerStride() const { return derived().nestedExpression().innerStride(); }
#ifdef EIGEN_TEST_EVALUATORS
/** \sa MatrixBase::operator+=() */
template<typename Other>
EIGEN_DEVICE_FUNC
TriangularViewType& operator+=(const DenseBase<Other>& other) {
internal::call_assignment_no_alias(derived(), other.derived(), internal::add_assign_op<Scalar>());
return derived();
}
/** \sa MatrixBase::operator-=() */
template<typename Other>
EIGEN_DEVICE_FUNC
TriangularViewType& operator-=(const DenseBase<Other>& other) {
internal::call_assignment_no_alias(derived(), other.derived(), internal::sub_assign_op<Scalar>());
return derived();
}
#else
/** \sa MatrixBase::operator+=() */
template<typename Other>
EIGEN_DEVICE_FUNC
TriangularViewType& operator+=(const DenseBase<Other>& other) { return *this = derived().nestedExpression() + other.derived(); }
/** \sa MatrixBase::operator-=() */
template<typename Other>
EIGEN_DEVICE_FUNC
TriangularViewType& operator-=(const DenseBase<Other>& other) { return *this = derived().nestedExpression() - other.derived(); }
#endif
/** \sa MatrixBase::operator*=() */
EIGEN_DEVICE_FUNC
TriangularViewType& operator*=(const typename internal::traits<MatrixType>::Scalar& other) { return *this = derived().nestedExpression() * other; }
/** \sa MatrixBase::operator/=() */
EIGEN_DEVICE_FUNC
TriangularViewType& operator/=(const typename internal::traits<MatrixType>::Scalar& other) { return *this = derived().nestedExpression() / other; }
/** \sa MatrixBase::fill() */
EIGEN_DEVICE_FUNC
void fill(const Scalar& value) { setConstant(value); }
/** \sa MatrixBase::setConstant() */
EIGEN_DEVICE_FUNC
TriangularViewType& setConstant(const Scalar& value)
{ return *this = MatrixType::Constant(derived().rows(), derived().cols(), value); }
/** \sa MatrixBase::setZero() */
EIGEN_DEVICE_FUNC
TriangularViewType& setZero() { return setConstant(Scalar(0)); }
/** \sa MatrixBase::setOnes() */
EIGEN_DEVICE_FUNC
TriangularViewType& setOnes() { return setConstant(Scalar(1)); }
/** \sa MatrixBase::coeff()
* \warning the coordinates must fit into the referenced triangular part
*/
EIGEN_DEVICE_FUNC
inline Scalar coeff(Index row, Index col) const
{
Base::check_coordinates_internal(row, col);
return derived().nestedExpression().coeff(row, col);
}
/** \sa MatrixBase::coeffRef()
* \warning the coordinates must fit into the referenced triangular part
*/
EIGEN_DEVICE_FUNC
inline Scalar& coeffRef(Index row, Index col)
{
Base::check_coordinates_internal(row, col);
return derived().nestedExpression().const_cast_derived().coeffRef(row, col);
}
/** Assigns a triangular matrix to a triangular part of a dense matrix */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
TriangularViewType& operator=(const TriangularBase<OtherDerived>& other);
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
TriangularViewType& operator=(const MatrixBase<OtherDerived>& other);
EIGEN_DEVICE_FUNC
TriangularViewType& operator=(const TriangularViewImpl& other)
{ return *this = other.nestedExpression(); }
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
void lazyAssign(const TriangularBase<OtherDerived>& other);
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
void lazyAssign(const MatrixBase<OtherDerived>& other);
#ifdef EIGEN_TEST_EVALUATORS
/** Efficient triangular matrix times vector/matrix product */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
const Product<TriangularViewType,OtherDerived>
operator*(const MatrixBase<OtherDerived>& rhs) const
{
return Product<TriangularViewType,OtherDerived>(derived(), rhs.derived());
}
/** Efficient vector/matrix times triangular matrix product */
template<typename OtherDerived> friend
EIGEN_DEVICE_FUNC
const Product<OtherDerived,TriangularViewType>
operator*(const MatrixBase<OtherDerived>& lhs, const TriangularViewImpl& rhs)
{
return Product<OtherDerived,TriangularViewType>(lhs.derived(),rhs.derived());
}
#else // EIGEN_TEST_EVALUATORS
/** Efficient triangular matrix times vector/matrix product */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
TriangularProduct<Mode, true, MatrixType, false, OtherDerived, OtherDerived::ColsAtCompileTime==1>
operator*(const MatrixBase<OtherDerived>& rhs) const
{
return TriangularProduct
<Mode, true, MatrixType, false, OtherDerived, OtherDerived::ColsAtCompileTime==1>
(derived().nestedExpression(), rhs.derived());
}
/** Efficient vector/matrix times triangular matrix product */
template<typename OtherDerived> friend
EIGEN_DEVICE_FUNC
TriangularProduct<Mode, false, OtherDerived, OtherDerived::RowsAtCompileTime==1, MatrixType, false>
operator*(const MatrixBase<OtherDerived>& lhs, const TriangularVieImplw& rhs)
{
return TriangularProduct
<Mode, false, OtherDerived, OtherDerived::RowsAtCompileTime==1, MatrixType, false>
(lhs.derived(),rhs.nestedExpression());
}
#endif
template<int Side, typename Other>
EIGEN_DEVICE_FUNC
inline const internal::triangular_solve_retval<Side,TriangularViewType, Other>
solve(const MatrixBase<Other>& other) const;
template<int Side, typename OtherDerived>
EIGEN_DEVICE_FUNC
void solveInPlace(const MatrixBase<OtherDerived>& other) const;
#ifndef EIGEN_TEST_EVALUATORS
template<typename Other>
EIGEN_DEVICE_FUNC
inline const internal::triangular_solve_retval<OnTheLeft,TriangularViewType, Other>
solve(const MatrixBase<Other>& other) const
{ return solve<OnTheLeft>(other); }
#endif // EIGEN_TEST_EVALUATORS
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
void solveInPlace(const MatrixBase<OtherDerived>& other) const
{ return solveInPlace<OnTheLeft>(other); }
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
void swap(TriangularBase<OtherDerived> const & other)
{
#ifdef EIGEN_TEST_EVALUATORS
call_assignment(derived(), other.const_cast_derived(), internal::swap_assign_op<Scalar>());
#else
TriangularView<SwapWrapper<MatrixType>,Mode>(const_cast<MatrixType&>(derived().nestedExpression())).lazyAssign(other.const_cast_derived().nestedExpression());
#endif
}
// TODO: this overload is ambiguous and it should be deprecated (Gael)
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
void swap(MatrixBase<OtherDerived> const & other)
{
#ifdef EIGEN_TEST_EVALUATORS
call_assignment(derived(), other.const_cast_derived(), internal::swap_assign_op<Scalar>());
#else
SwapWrapper<MatrixType> swaper(const_cast<MatrixType&>(derived().nestedExpression()));
TriangularView<SwapWrapper<MatrixType>,Mode>(swaper).lazyAssign(other.derived());
#endif
}
#ifndef EIGEN_TEST_EVALUATORS
// TODO simplify the following:
template<typename ProductDerived, typename Lhs, typename Rhs>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE TriangularView& operator=(const ProductBase<ProductDerived, Lhs,Rhs>& other)
EIGEN_STRONG_INLINE TriangularViewType& operator=(const ProductBase<ProductDerived, Lhs,Rhs>& other)
{
setZero();
return assignProduct(other,1);
@ -493,14 +529,14 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
template<typename ProductDerived, typename Lhs, typename Rhs>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE TriangularView& operator+=(const ProductBase<ProductDerived, Lhs,Rhs>& other)
EIGEN_STRONG_INLINE TriangularViewType& operator+=(const ProductBase<ProductDerived, Lhs,Rhs>& other)
{
return assignProduct(other,1);
}
template<typename ProductDerived, typename Lhs, typename Rhs>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE TriangularView& operator-=(const ProductBase<ProductDerived, Lhs,Rhs>& other)
EIGEN_STRONG_INLINE TriangularViewType& operator-=(const ProductBase<ProductDerived, Lhs,Rhs>& other)
{
return assignProduct(other,-1);
}
@ -508,7 +544,7 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
template<typename ProductDerived>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE TriangularView& operator=(const ScaledProduct<ProductDerived>& other)
EIGEN_STRONG_INLINE TriangularViewType& operator=(const ScaledProduct<ProductDerived>& other)
{
setZero();
return assignProduct(other,other.alpha());
@ -516,14 +552,14 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
template<typename ProductDerived>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE TriangularView& operator+=(const ScaledProduct<ProductDerived>& other)
EIGEN_STRONG_INLINE TriangularViewType& operator+=(const ScaledProduct<ProductDerived>& other)
{
return assignProduct(other,other.alpha());
}
template<typename ProductDerived>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE TriangularView& operator-=(const ScaledProduct<ProductDerived>& other)
EIGEN_STRONG_INLINE TriangularViewType& operator-=(const ScaledProduct<ProductDerived>& other)
{
return assignProduct(other,-other.alpha());
}
@ -542,16 +578,15 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
template<typename ProductType>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE TriangularView& _assignProduct(const ProductType& prod, const Scalar& alpha);
EIGEN_STRONG_INLINE TriangularViewType& _assignProduct(const ProductType& prod, const Scalar& alpha);
protected:
#else
protected:
template<typename ProductDerived, typename Lhs, typename Rhs>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE TriangularView& assignProduct(const ProductBase<ProductDerived, Lhs,Rhs>& prod, const Scalar& alpha);
EIGEN_STRONG_INLINE TriangularViewType& assignProduct(const ProductBase<ProductDerived, Lhs,Rhs>& prod, const Scalar& alpha);
#endif
MatrixTypeNested m_matrix;
};
/***************************************************************************
@ -736,18 +771,18 @@ struct triangular_assignment_selector<Derived1, Derived2, UnitLower, Dynamic, Cl
template<typename MatrixType, unsigned int Mode>
template<typename OtherDerived>
inline TriangularView<MatrixType, Mode>&
TriangularView<MatrixType, Mode>::operator=(const MatrixBase<OtherDerived>& other)
TriangularViewImpl<MatrixType, Mode, Dense>::operator=(const MatrixBase<OtherDerived>& other)
{
internal::call_assignment_no_alias(*this, other.derived(), internal::assign_op<Scalar>());
return *this;
internal::call_assignment_no_alias(derived(), other.derived(), internal::assign_op<Scalar>());
return derived();
}
// FIXME should we keep that possibility
template<typename MatrixType, unsigned int Mode>
template<typename OtherDerived>
void TriangularView<MatrixType, Mode>::lazyAssign(const MatrixBase<OtherDerived>& other)
void TriangularViewImpl<MatrixType, Mode, Dense>::lazyAssign(const MatrixBase<OtherDerived>& other)
{
internal::call_assignment(this->noalias(), other.template triangularView<Mode>());
internal::call_assignment(derived().noalias(), other.template triangularView<Mode>());
}
@ -755,19 +790,19 @@ void TriangularView<MatrixType, Mode>::lazyAssign(const MatrixBase<OtherDerived>
template<typename MatrixType, unsigned int Mode>
template<typename OtherDerived>
inline TriangularView<MatrixType, Mode>&
TriangularView<MatrixType, Mode>::operator=(const TriangularBase<OtherDerived>& other)
TriangularViewImpl<MatrixType, Mode, Dense>::operator=(const TriangularBase<OtherDerived>& other)
{
eigen_assert(Mode == int(OtherDerived::Mode));
internal::call_assignment(*this, other.derived());
return *this;
internal::call_assignment(derived(), other.derived());
return derived();
}
template<typename MatrixType, unsigned int Mode>
template<typename OtherDerived>
void TriangularView<MatrixType, Mode>::lazyAssign(const TriangularBase<OtherDerived>& other)
void TriangularViewImpl<MatrixType, Mode, Dense>::lazyAssign(const TriangularBase<OtherDerived>& other)
{
eigen_assert(Mode == int(OtherDerived::Mode));
internal::call_assignment(this->noalias(), other.derived());
internal::call_assignment(derived().noalias(), other.derived());
}
#else
@ -776,7 +811,7 @@ void TriangularView<MatrixType, Mode>::lazyAssign(const TriangularBase<OtherDeri
template<typename MatrixType, unsigned int Mode>
template<typename OtherDerived>
inline TriangularView<MatrixType, Mode>&
TriangularView<MatrixType, Mode>::operator=(const MatrixBase<OtherDerived>& other)
TriangularViewImpl<MatrixType, Mode, Dense>::operator=(const MatrixBase<OtherDerived>& other)
{
if(OtherDerived::Flags & EvalBeforeAssigningBit)
{
@ -786,13 +821,13 @@ TriangularView<MatrixType, Mode>::operator=(const MatrixBase<OtherDerived>& othe
}
else
lazyAssign(other.derived());
return *this;
return derived();
}
// FIXME should we keep that possibility
template<typename MatrixType, unsigned int Mode>
template<typename OtherDerived>
void TriangularView<MatrixType, Mode>::lazyAssign(const MatrixBase<OtherDerived>& other)
void TriangularViewImp<MatrixType, Mode, Dense>::lazyAssign(const MatrixBase<OtherDerived>& other)
{
enum {
unroll = MatrixType::SizeAtCompileTime != Dynamic
@ -813,7 +848,7 @@ void TriangularView<MatrixType, Mode>::lazyAssign(const MatrixBase<OtherDerived>
template<typename MatrixType, unsigned int Mode>
template<typename OtherDerived>
inline TriangularView<MatrixType, Mode>&
TriangularView<MatrixType, Mode>::operator=(const TriangularBase<OtherDerived>& other)
TriangularViewImpl<MatrixType, Mode, Dense>::operator=(const TriangularBase<OtherDerived>& other)
{
eigen_assert(Mode == int(OtherDerived::Mode));
if(internal::traits<OtherDerived>::Flags & EvalBeforeAssigningBit)
@ -824,12 +859,12 @@ TriangularView<MatrixType, Mode>::operator=(const TriangularBase<OtherDerived>&
}
else
lazyAssign(other.derived().nestedExpression());
return *this;
return derived();
}
template<typename MatrixType, unsigned int Mode>
template<typename OtherDerived>
void TriangularView<MatrixType, Mode>::lazyAssign(const TriangularBase<OtherDerived>& other)
void TriangularViewImpl<MatrixType, Mode, Dense>::lazyAssign(const TriangularBase<OtherDerived>& other)
{
enum {
unroll = MatrixType::SizeAtCompileTime != Dynamic
@ -1000,7 +1035,7 @@ template<typename MatrixType, unsigned int Mode>
struct evaluator_traits<TriangularView<MatrixType,Mode> >
{
typedef typename storage_kind_to_evaluator_kind<typename MatrixType::StorageKind>::Kind Kind;
typedef TriangularShape Shape;
typedef typename glue_shapes<typename evaluator_traits<MatrixType>::Shape, TriangularShape>::type Shape;
// 1 if assignment A = B assumes aliasing when B is of type T and thus B needs to be evaluated into a
// temporary; 0 if not.

View File

@ -264,22 +264,23 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false>
#ifdef EIGEN_TEST_EVALUATORS
template<typename MatrixType, unsigned int UpLo>
template<typename ProductType>
TriangularView<MatrixType,UpLo>& TriangularView<MatrixType,UpLo>::_assignProduct(const ProductType& prod, const Scalar& alpha)
TriangularView<MatrixType,UpLo>& TriangularViewImpl<MatrixType,UpLo,Dense>::_assignProduct(const ProductType& prod, const Scalar& alpha)
{
eigen_assert(m_matrix.rows() == prod.rows() && m_matrix.cols() == prod.cols());
eigen_assert(derived().nestedExpression().rows() == prod.rows() && derived().cols() == prod.cols());
general_product_to_triangular_selector<MatrixType, ProductType, UpLo, internal::traits<ProductType>::InnerSize==1>::run(m_matrix.const_cast_derived(), prod, alpha);
general_product_to_triangular_selector<MatrixType, ProductType, UpLo, internal::traits<ProductType>::InnerSize==1>::run(derived().nestedExpression().const_cast_derived(), prod, alpha);
return *this;
}
#else
template<typename MatrixType, unsigned int UpLo>
template<typename ProductDerived, typename _Lhs, typename _Rhs>
TriangularView<MatrixType,UpLo>& TriangularView<MatrixType,UpLo>::assignProduct(const ProductBase<ProductDerived, _Lhs,_Rhs>& prod, const Scalar& alpha)
TriangularView<MatrixType,UpLo>& TriangularViewImpl<MatrixType,UpLo,Dense>::assignProduct(const ProductBase<ProductDerived, _Lhs,_Rhs>& prod, const Scalar& alpha)
{
eigen_assert(m_matrix.rows() == prod.rows() && m_matrix.cols() == prod.cols());
eigen_assert(derived().rows() == prod.rows() && derived().cols() == prod.cols());
general_product_to_triangular_selector<MatrixType, ProductDerived, UpLo, (_Lhs::ColsAtCompileTime==1) || (_Rhs::RowsAtCompileTime==1)>::run(m_matrix.const_cast_derived(), prod.derived(), alpha);
general_product_to_triangular_selector<MatrixType, ProductDerived, UpLo, (_Lhs::ColsAtCompileTime==1) || (_Rhs::RowsAtCompileTime==1)>
::run(derived().nestedExpression().const_cast_derived(), prod.derived(), alpha);
return *this;
}

View File

@ -368,10 +368,12 @@ EIGEN_DONT_INLINE void product_triangular_matrix_matrix<Scalar,Index,Mode,false,
/***************************************************************************
* Wrapper to product_triangular_matrix_matrix
***************************************************************************/
#ifndef EIGEN_TEST_EVALUATORS
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> >
: traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>, Lhs, Rhs> >
{};
#endif
} // end namespace internal

View File

@ -157,6 +157,8 @@ EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,Con
* Wrapper to product_triangular_vector
***************************************************************************/
#ifndef EIGEN_TEST_EVALUATORS
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true> >
: traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>, Lhs, Rhs> >
@ -166,7 +168,7 @@ template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
: traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
{};
#endif
template<int Mode,int StorageOrder>
struct trmv_selector;

View File

@ -450,13 +450,13 @@ struct MatrixXpr {};
struct ArrayXpr {};
// An evaluator must define its shape. By default, it can be one of the following:
struct DenseShape { static std::string debugName() { return "DenseShape"; } };
struct DiagonalShape { static std::string debugName() { return "DiagonalShape"; } };
struct BandShape { static std::string debugName() { return "BandShape"; } };
struct TriangularShape { static std::string debugName() { return "TriangularShape"; } };
struct SelfAdjointShape { static std::string debugName() { return "SelfAdjointShape"; } };
struct PermutationShape { static std::string debugName() { return "PermutationShape"; } };
struct SparseShape { static std::string debugName() { return "SparseShape"; } };
struct DenseShape { static std::string debugName() { return "DenseShape"; } };
struct DiagonalShape { static std::string debugName() { return "DiagonalShape"; } };
struct BandShape { static std::string debugName() { return "BandShape"; } };
struct TriangularShape { static std::string debugName() { return "TriangularShape"; } };
struct SelfAdjointShape { static std::string debugName() { return "SelfAdjointShape"; } };
struct PermutationShape { static std::string debugName() { return "PermutationShape"; } };
struct SparseShape { static std::string debugName() { return "SparseShape"; } };
} // end namespace Eigen

View File

@ -638,6 +638,9 @@ template<typename T> struct is_diagonal<DiagonalWrapper<T> >
template<typename T, int S> struct is_diagonal<DiagonalMatrix<T,S> >
{ enum { ret = true }; };
template<typename S1, typename S2> struct glue_shapes;
template<> struct glue_shapes<DenseShape,TriangularShape> { typedef TriangularShape type; };
} // end namespace internal
// we require Lhs and Rhs to have the same scalar type. Currently there is no example of a binary functor

View File

@ -253,8 +253,8 @@ template<typename _MatrixType, int _UpLo> struct traits<SimplicialLLT<_MatrixTyp
typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::Index Index;
typedef SparseMatrix<Scalar, ColMajor, Index> CholMatrixType;
typedef SparseTriangularView<CholMatrixType, Eigen::Lower> MatrixL;
typedef SparseTriangularView<typename CholMatrixType::AdjointReturnType, Eigen::Upper> MatrixU;
typedef TriangularView<CholMatrixType, Eigen::Lower> MatrixL;
typedef TriangularView<typename CholMatrixType::AdjointReturnType, Eigen::Upper> MatrixU;
static inline MatrixL getL(const MatrixType& m) { return m; }
static inline MatrixU getU(const MatrixType& m) { return m.adjoint(); }
};
@ -266,8 +266,8 @@ template<typename _MatrixType,int _UpLo> struct traits<SimplicialLDLT<_MatrixTyp
typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::Index Index;
typedef SparseMatrix<Scalar, ColMajor, Index> CholMatrixType;
typedef SparseTriangularView<CholMatrixType, Eigen::UnitLower> MatrixL;
typedef SparseTriangularView<typename CholMatrixType::AdjointReturnType, Eigen::UnitUpper> MatrixU;
typedef TriangularView<CholMatrixType, Eigen::UnitLower> MatrixL;
typedef TriangularView<typename CholMatrixType::AdjointReturnType, Eigen::UnitUpper> MatrixU;
static inline MatrixL getL(const MatrixType& m) { return m; }
static inline MatrixU getU(const MatrixType& m) { return m.adjoint(); }
};

View File

@ -176,6 +176,34 @@ class MappedSparseMatrix<Scalar,_Flags,_Index>::ReverseInnerIterator
const Index m_end;
};
#ifdef EIGEN_ENABLE_EVALUATORS
namespace internal {
template<typename _Scalar, int _Options, typename _Index>
struct evaluator<MappedSparseMatrix<_Scalar,_Options,_Index> >
: evaluator_base<MappedSparseMatrix<_Scalar,_Options,_Index> >
{
typedef MappedSparseMatrix<_Scalar,_Options,_Index> MappedSparseMatrixType;
typedef typename MappedSparseMatrixType::InnerIterator InnerIterator;
typedef typename MappedSparseMatrixType::ReverseInnerIterator ReverseInnerIterator;
enum {
CoeffReadCost = NumTraits<_Scalar>::ReadCost,
Flags = MappedSparseMatrixType::Flags
};
evaluator() : m_matrix(0) {}
evaluator(const MappedSparseMatrixType &mat) : m_matrix(&mat) {}
operator MappedSparseMatrixType&() { return m_matrix->const_cast_derived(); }
operator const MappedSparseMatrixType&() const { return *m_matrix; }
const MappedSparseMatrixType *m_matrix;
};
}
#endif
} // end namespace Eigen
#endif // EIGEN_MAPPED_SPARSEMATRIX_H

View File

@ -333,7 +333,7 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
Derived& operator*=(const SparseMatrixBase<OtherDerived>& other);
template<int Mode>
inline const SparseTriangularView<Derived, Mode> triangularView() const;
inline const TriangularView<Derived, Mode> triangularView() const;
template<unsigned int UpLo> inline const SparseSelfAdjointView<Derived, UpLo> selfadjointView() const;
template<unsigned int UpLo> inline SparseSelfAdjointView<Derived, UpLo> selfadjointView();

View File

@ -392,8 +392,6 @@ inline void sparse_selfadjoint_time_dense_product(const SparseLhsType& lhs, cons
res.row(j) += i.value() * rhs.row(j);
}
}
struct SparseSelfAdjointShape { static std::string debugName() { return "SparseSelfAdjointShape"; } };
// TODO currently a selfadjoint expression has the form SelfAdjointView<.,.>
// in the future selfadjoint-ness should be defined by the expression traits

View File

@ -15,15 +15,15 @@ namespace Eigen {
namespace internal {
template<typename MatrixType, int Mode>
struct traits<SparseTriangularView<MatrixType,Mode> >
: public traits<MatrixType>
{};
// template<typename MatrixType, int Mode>
// struct traits<SparseTriangularView<MatrixType,Mode> >
// : public traits<MatrixType>
// {};
} // namespace internal
template<typename MatrixType, int Mode> class SparseTriangularView
: public SparseMatrixBase<SparseTriangularView<MatrixType,Mode> >
template<typename MatrixType, unsigned int Mode> class TriangularViewImpl<MatrixType,Mode,Sparse>
: public SparseMatrixBase<TriangularView<MatrixType,Mode> >
{
enum { SkipFirst = ((Mode&Lower) && !(MatrixType::Flags&RowMajorBit))
|| ((Mode&Upper) && (MatrixType::Flags&RowMajorBit)),
@ -31,45 +31,51 @@ template<typename MatrixType, int Mode> class SparseTriangularView
SkipDiag = (Mode&ZeroDiag) ? 1 : 0,
HasUnitDiag = (Mode&UnitDiag) ? 1 : 0
};
typedef TriangularView<MatrixType,Mode> TriangularViewType;
protected:
// dummy solve function to make TriangularView happy.
void solve() const;
public:
EIGEN_SPARSE_PUBLIC_INTERFACE(SparseTriangularView)
EIGEN_SPARSE_PUBLIC_INTERFACE(TriangularViewType)
class InnerIterator;
class ReverseInnerIterator;
inline Index rows() const { return m_matrix.rows(); }
inline Index cols() const { return m_matrix.cols(); }
typedef typename MatrixType::Nested MatrixTypeNested;
typedef typename internal::remove_reference<MatrixTypeNested>::type MatrixTypeNestedNonRef;
typedef typename internal::remove_all<MatrixTypeNested>::type MatrixTypeNestedCleaned;
inline SparseTriangularView(const MatrixType& matrix) : m_matrix(matrix) {}
/** \internal */
inline const MatrixTypeNestedCleaned& nestedExpression() const { return m_matrix; }
#ifndef EIGEN_TEST_EVALUATORS
template<typename OtherDerived>
typename internal::plain_matrix_type_column_major<OtherDerived>::type
solve(const MatrixBase<OtherDerived>& other) const;
#else // EIGEN_TEST_EVALUATORS
template<typename RhsType, typename DstType>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE void _solve_impl(const RhsType &rhs, DstType &dst) const {
if(!(internal::is_same<RhsType,DstType>::value && internal::extract_data(dst) == internal::extract_data(rhs)))
dst = rhs;
this->template solveInPlace(dst);
}
#endif // EIGEN_TEST_EVALUATORS
template<typename OtherDerived> void solveInPlace(MatrixBase<OtherDerived>& other) const;
template<typename OtherDerived> void solveInPlace(SparseMatrixBase<OtherDerived>& other) const;
protected:
MatrixTypeNested m_matrix;
};
template<typename MatrixType, int Mode>
class SparseTriangularView<MatrixType,Mode>::InnerIterator : public MatrixTypeNestedCleaned::InnerIterator
template<typename MatrixType, unsigned int Mode>
class TriangularViewImpl<MatrixType,Mode,Sparse>::InnerIterator : public MatrixTypeNestedCleaned::InnerIterator
{
typedef typename MatrixTypeNestedCleaned::InnerIterator Base;
typedef typename SparseTriangularView::Index Index;
typedef typename TriangularViewType::Index Index;
public:
EIGEN_STRONG_INLINE InnerIterator(const SparseTriangularView& view, Index outer)
EIGEN_STRONG_INLINE InnerIterator(const TriangularViewImpl& view, Index outer)
: Base(view.nestedExpression(), outer), m_returnOne(false)
{
if(SkipFirst)
@ -132,14 +138,14 @@ class SparseTriangularView<MatrixType,Mode>::InnerIterator : public MatrixTypeNe
bool m_returnOne;
};
template<typename MatrixType, int Mode>
class SparseTriangularView<MatrixType,Mode>::ReverseInnerIterator : public MatrixTypeNestedCleaned::ReverseInnerIterator
template<typename MatrixType, unsigned int Mode>
class TriangularViewImpl<MatrixType,Mode,Sparse>::ReverseInnerIterator : public MatrixTypeNestedCleaned::ReverseInnerIterator
{
typedef typename MatrixTypeNestedCleaned::ReverseInnerIterator Base;
typedef typename SparseTriangularView::Index Index;
typedef typename TriangularViewImpl::Index Index;
public:
EIGEN_STRONG_INLINE ReverseInnerIterator(const SparseTriangularView& view, Index outer)
EIGEN_STRONG_INLINE ReverseInnerIterator(const TriangularViewType& view, Index outer)
: Base(view.nestedExpression(), outer)
{
eigen_assert((!HasUnitDiag) && "ReverseInnerIterator does not support yet triangular views with a unit diagonal");
@ -168,7 +174,7 @@ class SparseTriangularView<MatrixType,Mode>::ReverseInnerIterator : public Matri
template<typename Derived>
template<int Mode>
inline const SparseTriangularView<Derived, Mode>
inline const TriangularView<Derived, Mode>
SparseMatrixBase<Derived>::triangularView() const
{
return derived();

View File

@ -94,7 +94,6 @@ template<typename _Scalar, int _Flags = 0, typename _Index = int> class Dynamic
template<typename _Scalar, int _Flags = 0, typename _Index = int> class SparseVector;
template<typename _Scalar, int _Flags = 0, typename _Index = int> class MappedSparseMatrix;
template<typename MatrixType, int Mode> class SparseTriangularView;
template<typename MatrixType, unsigned int UpLo> class SparseSelfAdjointView;
template<typename Lhs, typename Rhs> class SparseDiagonalProduct;
template<typename MatrixType> class SparseView;
@ -163,6 +162,12 @@ struct generic_xpr_base<Derived, MatrixXpr, Sparse>
typedef SparseMatrixBase<Derived> type;
};
struct SparseTriangularShape { static std::string debugName() { return "SparseTriangularShape"; } };
struct SparseSelfAdjointShape { static std::string debugName() { return "SparseSelfAdjointShape"; } };
template<> struct glue_shapes<SparseShape,SelfAdjointShape> { typedef SparseSelfAdjointShape type; };
template<> struct glue_shapes<SparseShape,TriangularShape > { typedef SparseTriangularShape type; };
} // end namespace internal
/** \ingroup SparseCore_Module

View File

@ -23,6 +23,7 @@ template<typename Lhs, typename Rhs, int Mode,
int StorageOrder = int(traits<Lhs>::Flags) & RowMajorBit>
struct sparse_solve_triangular_selector;
#ifndef EIGEN_TEST_EVALUATORS
// forward substitution, row-major
template<typename Lhs, typename Rhs, int Mode>
struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,RowMajor>
@ -163,13 +164,166 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,ColMajor>
}
};
#else // EIGEN_TEST_EVALUATORS
// forward substitution, row-major
template<typename Lhs, typename Rhs, int Mode>
struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,RowMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef typename Lhs::Index Index;
typedef typename evaluator<Lhs>::type LhsEval;
typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
static void run(const Lhs& lhs, Rhs& other)
{
LhsEval lhsEval(lhs);
for(Index col=0 ; col<other.cols() ; ++col)
{
for(Index i=0; i<lhs.rows(); ++i)
{
Scalar tmp = other.coeff(i,col);
Scalar lastVal(0);
Index lastIndex = 0;
for(LhsIterator it(lhsEval, i); it; ++it)
{
lastVal = it.value();
lastIndex = it.index();
if(lastIndex==i)
break;
tmp -= lastVal * other.coeff(lastIndex,col);
}
if (Mode & UnitDiag)
other.coeffRef(i,col) = tmp;
else
{
eigen_assert(lastIndex==i);
other.coeffRef(i,col) = tmp/lastVal;
}
}
}
}
};
// backward substitution, row-major
template<typename Lhs, typename Rhs, int Mode>
struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,RowMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef typename Lhs::Index Index;
typedef typename evaluator<Lhs>::type LhsEval;
typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
static void run(const Lhs& lhs, Rhs& other)
{
LhsEval lhsEval(lhs);
for(Index col=0 ; col<other.cols() ; ++col)
{
for(Index i=lhs.rows()-1 ; i>=0 ; --i)
{
Scalar tmp = other.coeff(i,col);
Scalar l_ii = 0;
LhsIterator it(lhsEval, i);
while(it && it.index()<i)
++it;
if(!(Mode & UnitDiag))
{
eigen_assert(it && it.index()==i);
l_ii = it.value();
++it;
}
else if (it && it.index() == i)
++it;
for(; it; ++it)
{
tmp -= it.value() * other.coeff(it.index(),col);
}
if (Mode & UnitDiag) other.coeffRef(i,col) = tmp;
else other.coeffRef(i,col) = tmp/l_ii;
}
}
}
};
// forward substitution, col-major
template<typename Lhs, typename Rhs, int Mode>
struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,ColMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef typename Lhs::Index Index;
typedef typename evaluator<Lhs>::type LhsEval;
typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
static void run(const Lhs& lhs, Rhs& other)
{
LhsEval lhsEval(lhs);
for(Index col=0 ; col<other.cols() ; ++col)
{
for(Index i=0; i<lhs.cols(); ++i)
{
Scalar& tmp = other.coeffRef(i,col);
if (tmp!=Scalar(0)) // optimization when other is actually sparse
{
LhsIterator it(lhsEval, i);
while(it && it.index()<i)
++it;
if(!(Mode & UnitDiag))
{
eigen_assert(it && it.index()==i);
tmp /= it.value();
}
if (it && it.index()==i)
++it;
for(; it; ++it)
other.coeffRef(it.index(), col) -= tmp * it.value();
}
}
}
}
};
// backward substitution, col-major
template<typename Lhs, typename Rhs, int Mode>
struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,ColMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef typename Lhs::Index Index;
typedef typename evaluator<Lhs>::type LhsEval;
typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
static void run(const Lhs& lhs, Rhs& other)
{
LhsEval lhsEval(lhs);
for(Index col=0 ; col<other.cols() ; ++col)
{
for(Index i=lhs.cols()-1; i>=0; --i)
{
Scalar& tmp = other.coeffRef(i,col);
if (tmp!=Scalar(0)) // optimization when other is actually sparse
{
if(!(Mode & UnitDiag))
{
// TODO replace this by a binary search. make sure the binary search is safe for partially sorted elements
LhsIterator it(lhsEval, i);
while(it && it.index()!=i)
++it;
eigen_assert(it && it.index()==i);
other.coeffRef(i,col) /= it.value();
}
LhsIterator it(lhsEval, i);
for(; it && it.index()<i; ++it)
other.coeffRef(it.index(), col) -= tmp * it.value();
}
}
}
}
};
#endif // EIGEN_TEST_EVALUATORS
} // end namespace internal
template<typename ExpressionType,int Mode>
template<typename ExpressionType,unsigned int Mode>
template<typename OtherDerived>
void SparseTriangularView<ExpressionType,Mode>::solveInPlace(MatrixBase<OtherDerived>& other) const
void TriangularViewImpl<ExpressionType,Mode,Sparse>::solveInPlace(MatrixBase<OtherDerived>& other) const
{
eigen_assert(m_matrix.cols() == m_matrix.rows() && m_matrix.cols() == other.rows());
eigen_assert(derived().cols() == derived().rows() && derived().cols() == other.rows());
eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
@ -178,21 +332,23 @@ void SparseTriangularView<ExpressionType,Mode>::solveInPlace(MatrixBase<OtherDer
typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
OtherCopy otherCopy(other.derived());
internal::sparse_solve_triangular_selector<ExpressionType, typename internal::remove_reference<OtherCopy>::type, Mode>::run(m_matrix, otherCopy);
internal::sparse_solve_triangular_selector<ExpressionType, typename internal::remove_reference<OtherCopy>::type, Mode>::run(derived().nestedExpression(), otherCopy);
if (copy)
other = otherCopy;
}
template<typename ExpressionType,int Mode>
#ifndef EIGEN_TEST_EVALUATORS
template<typename ExpressionType,unsigned int Mode>
template<typename OtherDerived>
typename internal::plain_matrix_type_column_major<OtherDerived>::type
SparseTriangularView<ExpressionType,Mode>::solve(const MatrixBase<OtherDerived>& other) const
TriangularViewImpl<ExpressionType,Mode,Sparse>::solve(const MatrixBase<OtherDerived>& other) const
{
typename internal::plain_matrix_type_column_major<OtherDerived>::type res(other);
solveInPlace(res);
return res;
}
#endif
// pure sparse path
@ -290,11 +446,11 @@ struct sparse_solve_triangular_sparse_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
} // end namespace internal
template<typename ExpressionType,int Mode>
template<typename ExpressionType,unsigned int Mode>
template<typename OtherDerived>
void SparseTriangularView<ExpressionType,Mode>::solveInPlace(SparseMatrixBase<OtherDerived>& other) const
void TriangularViewImpl<ExpressionType,Mode,Sparse>::solveInPlace(SparseMatrixBase<OtherDerived>& other) const
{
eigen_assert(m_matrix.cols() == m_matrix.rows() && m_matrix.cols() == other.rows());
eigen_assert(derived().cols() == derived().rows() && derived().cols() == other.rows());
eigen_assert( (!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
// enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
@ -303,7 +459,7 @@ void SparseTriangularView<ExpressionType,Mode>::solveInPlace(SparseMatrixBase<Ot
// typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
// OtherCopy otherCopy(other.derived());
internal::sparse_solve_triangular_sparse_selector<ExpressionType, OtherDerived, Mode>::run(m_matrix, other.derived());
internal::sparse_solve_triangular_sparse_selector<ExpressionType, OtherDerived, Mode>::run(derived().nestedExpression(), other.derived());
// if (copy)
// other = otherCopy;