fix 168 : now TriangularView::solve returns by value making TriangularView::solveInPlace less important.

Also fix the very outdated documentation of this function.
This commit is contained in:
Gael Guennebaud 2011-02-01 17:21:20 +01:00
parent 59af20b390
commit 8915d5bd22
3 changed files with 65 additions and 30 deletions

View File

@ -173,8 +173,6 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,CompleteUnrolling,1> {
***************************************************************************/
/** "in-place" version of TriangularView::solve() where the result is written in \a other
*
*
*
* \warning The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here.
* This function will const_cast it, so constness isn't honored here.
@ -205,43 +203,68 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<OtherDerived
/** \returns the product of the inverse of \c *this with \a other, \a *this being triangular.
*
* This function computes the inverse-matrix matrix product inverse(\c *this) * \a other if
* \a Side==OnTheLeft (the default), or the right-inverse-multiply \a other * inverse(\c *this) if
* \a Side==OnTheRight.
*
*
* This function computes the inverse-matrix matrix product inverse(\c *this) * \a other.
* The matrix \c *this must be triangular and invertible (i.e., all the coefficients of the
* diagonal must be non zero). It works as a forward (resp. backward) substitution if \c *this
* is an upper (resp. lower) triangular matrix.
*
* It is required that \c *this be marked as either an upper or a lower triangular matrix, which
* can be done by marked(), and that is automatically the case with expressions such as those returned
* by extract().
*
* Example: \include MatrixBase_marked.cpp
* Output: \verbinclude MatrixBase_marked.out
*
* This function is essentially a wrapper to the faster solveTriangularInPlace() function creating
* a temporary copy of \a other, calling solveTriangularInPlace() on the copy and returning it.
* Therefore, if \a other is not needed anymore, it is quite faster to call solveTriangularInPlace()
* instead of solveTriangular().
* This function returns an expression of the inverse-multiply and can works in-place if it is assigned
* to the same matrix or vector \a other.
*
* For users coming from BLAS, this function (and more specifically solveTriangularInPlace()) offer
* For users coming from BLAS, this function (and more specifically solveInPlace()) offer
* all the operations supported by the \c *TRSV and \c *TRSM BLAS routines.
*
* \b Tips: to perform a \em "right-inverse-multiply" you can simply transpose the operation, e.g.:
* \code
* M * T^1 <=> T.transpose().solveInPlace(M.transpose());
* \endcode
*
* \sa TriangularView::solveInPlace()
*/
template<typename Derived, unsigned int Mode>
template<int Side, typename RhsDerived>
typename internal::plain_matrix_type_column_major<RhsDerived>::type
TriangularView<Derived,Mode>::solve(const MatrixBase<RhsDerived>& rhs) const
template<int Side, typename Other>
const internal::triangular_solve_retval<Side,TriangularView<Derived,Mode>,Other>
TriangularView<Derived,Mode>::solve(const MatrixBase<Other>& other) const
{
typename internal::plain_matrix_type_column_major<RhsDerived>::type res(rhs);
solveInPlace<Side>(res);
return res;
return internal::triangular_solve_retval<Side,TriangularView,Other>(*this, other.derived());
}
namespace internal {
template<int Side, typename TriangularType, typename Rhs>
struct traits<triangular_solve_retval<Side, TriangularType, Rhs> >
{
typedef typename internal::plain_matrix_type_column_major<Rhs>::type ReturnType;
};
template<int Side, typename TriangularType, typename Rhs> struct triangular_solve_retval
: public ReturnByValue<triangular_solve_retval<Side, TriangularType, Rhs> >
{
typedef typename remove_all<typename Rhs::Nested>::type RhsNestedCleaned;
typedef ReturnByValue<triangular_solve_retval> Base;
typedef typename Base::Index Index;
triangular_solve_retval(const TriangularType& tri, const Rhs& rhs)
: m_triangularMatrix(tri), m_rhs(rhs)
{}
inline Index rows() const { return m_rhs.rows(); }
inline Index cols() const { return m_rhs.cols(); }
template<typename Dest> inline void evalTo(Dest& dst) const
{
if(!(is_same<RhsNestedCleaned,Dest>::value && extract_data(dst) == extract_data(m_rhs)))
dst = m_rhs;
m_triangularMatrix.template solveInPlace<Side>(dst);
}
protected:
const TriangularType& m_triangularMatrix;
const typename Rhs::Nested m_rhs;
};
} // namespace internal
#endif // EIGEN_SOLVETRIANGULAR_H

View File

@ -26,6 +26,12 @@
#ifndef EIGEN_TRIANGULARMATRIX_H
#define EIGEN_TRIANGULARMATRIX_H
namespace internal {
template<int Side, typename TriangularType, typename Rhs> struct triangular_solve_retval;
}
/** \internal
*
* \class TriangularBase
@ -332,16 +338,16 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
}
#endif // EIGEN2_SUPPORT
template<int Side, typename OtherDerived>
typename internal::plain_matrix_type_column_major<OtherDerived>::type
solve(const MatrixBase<OtherDerived>& other) const;
template<int Side, typename Other>
inline const internal::triangular_solve_retval<Side,TriangularView, Other>
solve(const MatrixBase<Other>& other) const;
template<int Side, typename OtherDerived>
void solveInPlace(const MatrixBase<OtherDerived>& other) const;
template<typename OtherDerived>
typename internal::plain_matrix_type_column_major<OtherDerived>::type
solve(const MatrixBase<OtherDerived>& other) const
template<typename Other>
inline const internal::triangular_solve_retval<OnTheLeft,TriangularView, Other>
solve(const MatrixBase<Other>& other) const
{ return solve<OnTheLeft>(other); }
template<typename OtherDerived>

View File

@ -28,12 +28,18 @@
(XB).setRandom(); ref = (XB); \
(TRI).solveInPlace(XB); \
VERIFY_IS_APPROX((TRI).toDenseMatrix() * (XB), ref); \
(XB).setRandom(); ref = (XB); \
(XB) = (TRI).solve(XB); \
VERIFY_IS_APPROX((TRI).toDenseMatrix() * (XB), ref); \
}
#define VERIFY_TRSM_ONTHERIGHT(TRI,XB) { \
(XB).setRandom(); ref = (XB); \
(TRI).transpose().template solveInPlace<OnTheRight>(XB.transpose()); \
VERIFY_IS_APPROX((XB).transpose() * (TRI).transpose().toDenseMatrix(), ref.transpose()); \
(XB).setRandom(); ref = (XB); \
(XB).transpose() = (TRI).transpose().template solve<OnTheRight>(XB.transpose()); \
VERIFY_IS_APPROX((XB).transpose() * (TRI).transpose().toDenseMatrix(), ref.transpose()); \
}
template<typename Scalar,int Size, int Cols> void trsolve(int size=Size,int cols=Cols)