Make MatrixFunction use nested_eval instead of nested

This commit is contained in:
Gael Guennebaud 2014-09-18 17:31:17 +02:00
parent 060e835ee9
commit 62bce6e5e6
5 changed files with 25 additions and 20 deletions

View File

@ -392,14 +392,15 @@ template<typename Derived> struct MatrixExponentialReturnValue
template <typename ResultType>
inline void evalTo(ResultType& result) const
{
internal::matrix_exp_compute(m_src, result);
const typename internal::nested_eval<Derived, 10>::type tmp(m_src);
internal::matrix_exp_compute(tmp, result);
}
Index rows() const { return m_src.rows(); }
Index cols() const { return m_src.cols(); }
protected:
const typename internal::nested<Derived, 10>::type m_src;
const typename internal::nested<Derived>::type m_src;
};
namespace internal {

View File

@ -485,7 +485,7 @@ template<typename Derived> class MatrixFunctionReturnValue
typedef typename internal::stem_function<Scalar>::type StemFunction;
protected:
typedef typename internal::nested<Derived, 10>::type DerivedNested;
typedef typename internal::nested<Derived>::type DerivedNested;
public:
@ -503,18 +503,19 @@ template<typename Derived> class MatrixFunctionReturnValue
template <typename ResultType>
inline void evalTo(ResultType& result) const
{
typedef typename internal::remove_all<DerivedNested>::type DerivedNestedClean;
typedef internal::traits<DerivedNestedClean> Traits;
typedef typename internal::nested_eval<Derived, 10>::type NestedEvalType;
typedef typename internal::remove_all<NestedEvalType>::type NestedEvalTypeClean;
typedef internal::traits<NestedEvalTypeClean> Traits;
static const int RowsAtCompileTime = Traits::RowsAtCompileTime;
static const int ColsAtCompileTime = Traits::ColsAtCompileTime;
static const int Options = DerivedNestedClean::Options;
static const int Options = NestedEvalTypeClean::Options;
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar;
typedef Matrix<ComplexScalar, Dynamic, Dynamic, Options, RowsAtCompileTime, ColsAtCompileTime> DynMatrixType;
typedef internal::MatrixFunctionAtomic<DynMatrixType> AtomicType;
AtomicType atomic(m_f);
internal::matrix_function_compute<DerivedNestedClean>::run(m_A, atomic, result);
internal::matrix_function_compute<NestedEvalTypeClean>::run(m_A, atomic, result);
}
Index rows() const { return m_A.rows(); }

View File

@ -310,7 +310,7 @@ public:
typedef typename Derived::Index Index;
protected:
typedef typename internal::nested<Derived, 10>::type DerivedNested;
typedef typename internal::nested<Derived>::type DerivedNested;
public:
@ -327,17 +327,18 @@ public:
template <typename ResultType>
inline void evalTo(ResultType& result) const
{
typedef typename internal::remove_all<DerivedNested>::type DerivedNestedClean;
typedef internal::traits<DerivedNestedClean> Traits;
typedef typename internal::nested_eval<Derived, 10>::type DerivedEvalType;
typedef typename internal::remove_all<DerivedEvalType>::type DerivedEvalTypeClean;
typedef internal::traits<DerivedEvalTypeClean> Traits;
static const int RowsAtCompileTime = Traits::RowsAtCompileTime;
static const int ColsAtCompileTime = Traits::ColsAtCompileTime;
static const int Options = DerivedNestedClean::Options;
static const int Options = DerivedEvalTypeClean::Options;
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar;
typedef Matrix<ComplexScalar, Dynamic, Dynamic, Options, RowsAtCompileTime, ColsAtCompileTime> DynMatrixType;
typedef internal::MatrixLogarithmAtomic<DynMatrixType> AtomicType;
AtomicType atomic;
internal::matrix_function_compute<DerivedNestedClean>::run(m_A, atomic, result);
internal::matrix_function_compute<DerivedEvalTypeClean>::run(m_A, atomic, result);
}
Index rows() const { return m_A.rows(); }

View File

@ -320,7 +320,7 @@ template<typename Derived> class MatrixSquareRootReturnValue
{
protected:
typedef typename Derived::Index Index;
typedef typename internal::nested<Derived, 10>::type DerivedNested;
typedef typename internal::nested<Derived>::type DerivedNested;
public:
/** \brief Constructor.
@ -338,8 +338,10 @@ template<typename Derived> class MatrixSquareRootReturnValue
template <typename ResultType>
inline void evalTo(ResultType& result) const
{
typedef typename internal::remove_all<DerivedNested>::type DerivedNestedClean;
internal::matrix_sqrt_compute<DerivedNestedClean>::run(m_src, result);
typedef typename internal::nested_eval<Derived, 10>::type DerivedEvalType;
typedef typename internal::remove_all<DerivedEvalType>::type DerivedEvalTypeClean;
DerivedEvalType tmp(m_src);
internal::matrix_sqrt_compute<DerivedEvalTypeClean>::run(tmp, result);
}
Index rows() const { return m_src.rows(); }

View File

@ -36,11 +36,11 @@ if (NOT CMAKE_CXX_COMPILER MATCHES "clang\\+\\+$")
ei_add_test(BVH)
endif()
# TODO ei_add_test(matrix_exponential)
# TODO ei_add_test(matrix_function)
# TODO ei_add_test(matrix_power)
# TODO ei_add_test(matrix_square_root)
# TODO ei_add_test(alignedvector3)
ei_add_test(matrix_exponential)
ei_add_test(matrix_function)
ei_add_test(matrix_power)
ei_add_test(matrix_square_root)
ei_add_test(alignedvector3)
ei_add_test(FFT)