mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-06 14:14:46 +08:00
By-pass ProductBase for triangular and selfadjoint products and get rid of ProductBase
This commit is contained in:
parent
d67548f345
commit
c98881e130
@ -688,7 +688,6 @@ template<> struct gemv_selector<OnTheRight,RowMajor,true>
|
||||
typedef typename Lhs::Scalar LhsScalar;
|
||||
typedef typename Rhs::Scalar RhsScalar;
|
||||
typedef typename Dest::Scalar ResScalar;
|
||||
typedef typename Dest::RealScalar RealScalar;
|
||||
|
||||
typedef internal::blas_traits<Lhs> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
|
@ -11,6 +11,8 @@
|
||||
#define EIGEN_PRODUCTBASE_H
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
|
||||
/** \class ProductBase
|
||||
* \ingroup Core_Module
|
||||
@ -174,8 +176,6 @@ class ProductBase : public MatrixBase<Derived>
|
||||
mutable PlainObject m_result;
|
||||
};
|
||||
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
|
||||
// here we need to overload the nested rule for products
|
||||
// such that the nested type is a const reference to a plain matrix
|
||||
namespace internal {
|
||||
|
@ -532,6 +532,10 @@ struct etor_product_packet_impl<ColMajor, Dynamic, Lhs, Rhs, Packet, LoadMode>
|
||||
/***************************************************************************
|
||||
* Triangular products
|
||||
***************************************************************************/
|
||||
// Forward declaration of the kernel dispatcher used for products that involve
// a triangular operand. Mode carries the triangular view's mode flags,
// LhsIsTriangular tells which of the two operands is the triangular one, and
// the two *IsVector flags select the matrix-vector or matrix-matrix
// specialization (callers pass IsVectorAtCompileTime of the dense operand).
template<int Mode, bool LhsIsTriangular,
|
||||
typename Lhs, bool LhsIsVector,
|
||||
typename Rhs, bool RhsIsVector>
|
||||
struct triangular_product_impl;
|
||||
|
||||
template<typename Lhs, typename Rhs, int ProductTag>
|
||||
struct generic_product_impl<Lhs,Rhs,TriangularShape,DenseShape,ProductTag>
|
||||
@ -542,8 +546,8 @@ struct generic_product_impl<Lhs,Rhs,TriangularShape,DenseShape,ProductTag>
|
||||
template<typename Dest>
|
||||
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
|
||||
{
|
||||
// TODO bypass TriangularProduct class
|
||||
TriangularProduct<Lhs::Mode,true,typename Lhs::MatrixType,false,Rhs, Rhs::IsVectorAtCompileTime>(lhs.nestedExpression(),rhs).scaleAndAddTo(dst, alpha);
|
||||
triangular_product_impl<Lhs::Mode,true,typename Lhs::MatrixType,false,Rhs, Rhs::IsVectorAtCompileTime>
|
||||
::run(dst, lhs.nestedExpression(), rhs, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
@ -576,8 +580,7 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,TriangularShape,ProductTag>
|
||||
template<typename Dest>
|
||||
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
|
||||
{
|
||||
// TODO bypass TriangularProduct class
|
||||
TriangularProduct<Rhs::Mode,false,Lhs,Lhs::IsVectorAtCompileTime, typename Rhs::MatrixType, false>(lhs,rhs.nestedExpression()).scaleAndAddTo(dst, alpha);
|
||||
triangular_product_impl<Rhs::Mode,false,Lhs,Lhs::IsVectorAtCompileTime, typename Rhs::MatrixType, false>::run(dst, lhs, rhs.nestedExpression(), alpha);
|
||||
}
|
||||
};
|
||||
|
||||
@ -605,6 +608,9 @@ protected:
|
||||
/***************************************************************************
|
||||
* SelfAdjoint products
|
||||
***************************************************************************/
|
||||
// Forward declaration of the kernel dispatcher used for products that involve
// a self-adjoint operand. LhsMode / RhsMode carry the mode flags of each
// operand (callers pass 0 for the plain dense side), and the two *IsVector
// flags select the matrix-vector or matrix-matrix specialization.
template <typename Lhs, int LhsMode, bool LhsIsVector,
|
||||
typename Rhs, int RhsMode, bool RhsIsVector>
|
||||
struct selfadjoint_product_impl;
|
||||
|
||||
template<typename Lhs, typename Rhs, int ProductTag>
|
||||
struct generic_product_impl<Lhs,Rhs,SelfAdjointShape,DenseShape,ProductTag>
|
||||
@ -615,8 +621,7 @@ struct generic_product_impl<Lhs,Rhs,SelfAdjointShape,DenseShape,ProductTag>
|
||||
template<typename Dest>
|
||||
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
|
||||
{
|
||||
// TODO bypass SelfadjointProductMatrix class
|
||||
SelfadjointProductMatrix<typename Lhs::MatrixType,Lhs::Mode,false,Rhs,0,Rhs::IsVectorAtCompileTime>(lhs.nestedExpression(),rhs).scaleAndAddTo(dst, alpha);
|
||||
selfadjoint_product_impl<typename Lhs::MatrixType,Lhs::Mode,false,Rhs,0,Rhs::IsVectorAtCompileTime>::run(dst, lhs.nestedExpression(), rhs, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
@ -649,8 +654,7 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,SelfAdjointShape,ProductTag>
|
||||
template<typename Dest>
|
||||
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
|
||||
{
|
||||
// TODO bypass SelfadjointProductMatrix class
|
||||
SelfadjointProductMatrix<Lhs,0,Lhs::IsVectorAtCompileTime,typename Rhs::MatrixType,Rhs::Mode,false>(lhs,rhs.nestedExpression()).scaleAndAddTo(dst, alpha);
|
||||
selfadjoint_product_impl<Lhs,0,Lhs::IsVectorAtCompileTime,typename Rhs::MatrixType,Rhs::Mode,false>::run(dst, lhs, rhs.nestedExpression(), alpha);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -45,9 +45,11 @@ struct traits<SelfAdjointView<MatrixType, UpLo> > : traits<MatrixType>
|
||||
};
|
||||
}
|
||||
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
// Forward declaration of the SelfadjointProductMatrix product expression.
// NOTE(review): the surrounding #ifndef EIGEN_TEST_EVALUATORS guard suggests
// this type is only needed on the non-evaluator code path — confirm.
template <typename Lhs, int LhsMode, bool LhsIsVector,
|
||||
typename Rhs, int RhsMode, bool RhsIsVector>
|
||||
struct SelfadjointProductMatrix;
|
||||
#endif
|
||||
|
||||
// FIXME could also be called SelfAdjointWrapper to be consistent with DiagonalWrapper ??
|
||||
template<typename _MatrixType, unsigned int UpLo> class SelfAdjointView
|
||||
|
@ -381,6 +381,7 @@ EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,f
|
||||
* Wrapper to product_selfadjoint_matrix
|
||||
***************************************************************************/
|
||||
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
namespace internal {
|
||||
template<typename Lhs, int LhsMode, typename Rhs, int RhsMode>
|
||||
struct traits<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false> >
|
||||
@ -430,6 +431,57 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false>
|
||||
);
|
||||
}
|
||||
};
|
||||
#else // EIGEN_TEST_EVALUATORS
|
||||
namespace internal {
|
||||
|
||||
// Matrix * matrix specialization (both IsVector flags are false): unwraps the
// two operands through blas_traits and forwards to
// internal::product_selfadjoint_matrix.
template<typename Lhs, int LhsMode, typename Rhs, int RhsMode>
|
||||
struct selfadjoint_product_impl<Lhs,LhsMode,false,Rhs,RhsMode,false>
|
||||
{
|
||||
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
|
||||
typedef typename Product<Lhs,Rhs>::Index Index;
|
||||
|
||||
typedef internal::blas_traits<Lhs> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
typedef internal::blas_traits<Rhs> RhsBlasTraits;
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||
|
||||
// Compile-time decoding of the mode flags of each operand: which triangular
// half is selected, and whether the SelfAdjoint bit is set at all.
enum {
|
||||
LhsIsUpper = (LhsMode&(Upper|Lower))==Upper,
|
||||
LhsIsSelfAdjoint = (LhsMode&SelfAdjoint)==SelfAdjoint,
|
||||
RhsIsUpper = (RhsMode&(Upper|Lower))==Upper,
|
||||
RhsIsSelfAdjoint = (RhsMode&SelfAdjoint)==SelfAdjoint
|
||||
};
|
||||
|
||||
// Runs the kernel on dst with the scale factor alpha.
// NOTE(review): callers reach this from scaleAndAddTo, so presumably the kernel
// performs dst += alpha * a_lhs * a_rhs — the accumulation itself is not
// visible in this block.
template<typename Dest>
|
||||
static void run(Dest &dst, const Lhs &a_lhs, const Rhs &a_rhs, const Scalar& alpha)
|
||||
{
|
||||
eigen_assert(dst.rows()==a_lhs.rows() && dst.cols()==a_rhs.cols());
|
||||
|
||||
// Extract the directly accessible operands exposed by blas_traits ...
typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
|
||||
typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
|
||||
|
||||
// ... and fold the scalar factors of both operands into alpha.
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
|
||||
* RhsBlasTraits::extractScalarFactor(a_rhs);
|
||||
|
||||
// For each operand the storage order handed to the kernel is RowMajor when
// exactly one of {upper half selected, row-major storage} holds, and the
// conjugation flag is IsComplex && (IsUpper XOR NeedToConjugate).
internal::product_selfadjoint_matrix<Scalar, Index,
|
||||
EIGEN_LOGICAL_XOR(LhsIsUpper,internal::traits<Lhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, LhsIsSelfAdjoint,
|
||||
NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsIsUpper,bool(LhsBlasTraits::NeedToConjugate)),
|
||||
EIGEN_LOGICAL_XOR(RhsIsUpper,internal::traits<Rhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, RhsIsSelfAdjoint,
|
||||
NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsIsUpper,bool(RhsBlasTraits::NeedToConjugate)),
|
||||
internal::traits<Dest>::Flags&RowMajorBit ? RowMajor : ColMajor>
|
||||
::run(
|
||||
lhs.rows(), rhs.cols(), // sizes
|
||||
&lhs.coeffRef(0,0), lhs.outerStride(), // lhs info
|
||||
&rhs.coeffRef(0,0), rhs.outerStride(), // rhs info
|
||||
&dst.coeffRef(0,0), dst.outerStride(), // result info
|
||||
actualAlpha // alpha
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
#endif
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
|
@ -168,6 +168,7 @@ EIGEN_DONT_INLINE void selfadjoint_matrix_vector_product<Scalar,Index,StorageOrd
|
||||
* Wrapper to selfadjoint_matrix_vector_product
|
||||
***************************************************************************/
|
||||
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
namespace internal {
|
||||
template<typename Lhs, int LhsMode, typename Rhs>
|
||||
struct traits<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true> >
|
||||
@ -276,6 +277,109 @@ struct SelfadjointProductMatrix<Lhs,0,true,Rhs,RhsMode,false>
|
||||
}
|
||||
};
|
||||
|
||||
#else // EIGEN_TEST_EVALUATORS
|
||||
|
||||
namespace internal {
|
||||
|
||||
// Matrix * vector specialization (lhs carries the mode, rhs is a vector with
// mode 0): unwraps the operands through blas_traits, makes sure the kernel
// sees unit-stride rhs and destination buffers, and forwards to
// internal::selfadjoint_matrix_vector_product.
template<typename Lhs, int LhsMode, typename Rhs>
|
||||
struct selfadjoint_product_impl<Lhs,LhsMode,false,Rhs,0,true>
|
||||
{
|
||||
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
|
||||
typedef typename Product<Lhs,Rhs>::Index Index;
|
||||
|
||||
typedef internal::blas_traits<Lhs> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
|
||||
|
||||
typedef internal::blas_traits<Rhs> RhsBlasTraits;
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||
typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
|
||||
|
||||
// Only the Upper/Lower bits of the mode are forwarded to the kernel.
enum { LhsUpLo = LhsMode&(Upper|Lower) };
|
||||
|
||||
// Runs the kernel on dest with the scale factor alpha.
// NOTE(review): dest's current content is copied into the temporary before the
// kernel runs, which suggests the kernel accumulates (dest += alpha*lhs*rhs)
// — confirm against the kernel.
template<typename Dest>
|
||||
static void run(Dest& dest, const Lhs &a_lhs, const Rhs &a_rhs, const Scalar& alpha)
|
||||
{
|
||||
typedef typename Dest::Scalar ResScalar;
|
||||
typedef typename Rhs::Scalar RhsScalar;
|
||||
// Aligned dynamic column-vector view used to copy to/from the temporary
// destination buffer.
typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
|
||||
|
||||
eigen_assert(dest.rows()==a_lhs.rows() && dest.cols()==a_rhs.cols());
|
||||
|
||||
typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
|
||||
typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
|
||||
|
||||
// Fold the scalar factors of both operands into alpha.
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
|
||||
* RhsBlasTraits::extractScalarFactor(a_rhs);
|
||||
|
||||
// The kernel is handed raw pointers (rhs with increment 1), so dest / rhs are
// used in place only when their inner stride is 1 at compile time.
enum {
|
||||
EvalToDest = (Dest::InnerStrideAtCompileTime==1),
|
||||
UseRhs = (ActualRhsTypeCleaned::InnerStrideAtCompileTime==1)
|
||||
};
|
||||
|
||||
// Fallback storage, enabled only when the operand cannot be used directly.
internal::gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,!EvalToDest> static_dest;
|
||||
internal::gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!UseRhs> static_rhs;
|
||||
|
||||
// NOTE(review): presumably the macro declares the pointer and allocates a
// temporary when the buffer passed as last argument is null — confirm against
// the macro definition.
ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
|
||||
EvalToDest ? dest.data() : static_dest.data());
|
||||
|
||||
ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,rhs.size(),
|
||||
UseRhs ? const_cast<RhsScalar*>(rhs.data()) : static_rhs.data());
|
||||
|
||||
// Seed the temporary destination with the current content of dest.
if(!EvalToDest)
|
||||
{
|
||||
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||
// NOTE(review): `size` is not used by the visible code; presumably the plugin
// macro reads it — confirm.
int size = dest.size();
|
||||
EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||
#endif
|
||||
MappedDest(actualDestPtr, dest.size()) = dest;
|
||||
}
|
||||
|
||||
// Copy a non-unit-stride rhs into the contiguous temporary.
if(!UseRhs)
|
||||
{
|
||||
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||
int size = rhs.size();
|
||||
EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||
#endif
|
||||
Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, rhs.size()) = rhs;
|
||||
}
|
||||
|
||||
|
||||
internal::selfadjoint_matrix_vector_product<Scalar, Index, (internal::traits<ActualLhsTypeCleaned>::Flags&RowMajorBit) ? RowMajor : ColMajor,
|
||||
int(LhsUpLo), bool(LhsBlasTraits::NeedToConjugate), bool(RhsBlasTraits::NeedToConjugate)>::run
|
||||
(
|
||||
lhs.rows(), // size
|
||||
&lhs.coeffRef(0,0), lhs.outerStride(), // lhs info
|
||||
actualRhsPtr, 1, // rhs info
|
||||
actualDestPtr, // result info
|
||||
actualAlpha // scale factor
|
||||
);
|
||||
|
||||
// Write the result back when the kernel worked on the temporary.
if(!EvalToDest)
|
||||
dest = MappedDest(actualDestPtr, dest.size());
|
||||
}
|
||||
};
|
||||
|
||||
// Vector * matrix specialization (lhs is a vector with mode 0, rhs carries the
// mode): implemented by transposing the whole product and delegating to the
// matrix * vector specialization.
template<typename Lhs, typename Rhs, int RhsMode>
|
||||
struct selfadjoint_product_impl<Lhs,0,true,Rhs,RhsMode,false>
|
||||
{
|
||||
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
|
||||
// Only the Upper/Lower bits of the rhs mode are used.
enum { RhsUpLo = RhsMode&(Upper|Lower) };
|
||||
|
||||
template<typename Dest>
|
||||
static void run(Dest& dest, const Lhs &a_lhs, const Rhs &a_rhs, const Scalar& alpha)
|
||||
{
|
||||
// let's simply transpose the product
|
||||
// dest^T is driven by rhs^T * lhs^T; transposing swaps the referenced
// triangular half, hence Upper <-> Lower in the delegated mode.
Transpose<Dest> destT(dest);
|
||||
selfadjoint_product_impl<Transpose<const Rhs>, int(RhsUpLo)==Upper ? Lower : Upper, false,
|
||||
Transpose<const Lhs>, 0, true>::run(destT, a_rhs.transpose(), a_lhs.transpose(), alpha);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
#endif // EIGEN_TEST_EVALUATORS
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_SELFADJOINT_MATRIX_VECTOR_H
|
||||
|
@ -372,7 +372,6 @@ EIGEN_DONT_INLINE void product_triangular_matrix_matrix<Scalar,Index,Mode,false,
|
||||
/***************************************************************************
|
||||
* Wrapper to product_triangular_matrix_matrix
|
||||
***************************************************************************/
|
||||
|
||||
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
|
||||
struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> >
|
||||
: traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>, Lhs, Rhs> >
|
||||
@ -380,6 +379,7 @@ struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> >
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
|
||||
struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
|
||||
: public ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>, Lhs, Rhs >
|
||||
@ -421,6 +421,57 @@ struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
|
||||
);
|
||||
}
|
||||
};
|
||||
#else // EIGEN_TEST_EVALUATORS
|
||||
namespace internal {
|
||||
// Matrix * matrix specialization (both IsVector flags are false): unwraps the
// operands through blas_traits, computes the effective product sizes and
// forwards to internal::product_triangular_matrix_matrix.
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
|
||||
struct triangular_product_impl<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
|
||||
{
|
||||
// Runs the kernel on dst with the scale factor alpha.
// NOTE(review): callers reach this from scaleAndAddTo, so presumably the kernel
// performs dst += alpha * a_lhs * a_rhs — confirm against the kernel.
template<typename Dest> static void run(Dest& dst, const Lhs &a_lhs, const Rhs &a_rhs, const typename Dest::Scalar& alpha)
|
||||
{
|
||||
typedef typename Dest::Index Index;
|
||||
typedef typename Dest::Scalar Scalar;
|
||||
|
||||
typedef internal::blas_traits<Lhs> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
|
||||
typedef internal::blas_traits<Rhs> RhsBlasTraits;
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||
typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
|
||||
|
||||
typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
|
||||
typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
|
||||
|
||||
// Fold the scalar factors of both operands into alpha.
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
|
||||
* RhsBlasTraits::extractScalarFactor(a_rhs);
|
||||
|
||||
// NOTE(review): the meaning of the trailing template argument 4 is not
// visible here — confirm against gemm_blocking_space.
typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar,
|
||||
Lhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxColsAtCompileTime,4> BlockingType;
|
||||
|
||||
enum { IsLower = (Mode&Lower) == Lower };
|
||||
// Effective sizes of the product: depending on which operand is triangular
// and on IsLower, a dimension is either taken as is or clamped to
// min(rows, cols) of the triangular operand.
Index stripedRows = ((!LhsIsTriangular) || (IsLower)) ? lhs.rows() : (std::min)(lhs.rows(),lhs.cols());
|
||||
Index stripedCols = ((LhsIsTriangular) || (!IsLower)) ? rhs.cols() : (std::min)(rhs.cols(),rhs.rows());
|
||||
Index stripedDepth = LhsIsTriangular ? ((!IsLower) ? lhs.cols() : (std::min)(lhs.cols(),lhs.rows()))
|
||||
: ((IsLower) ? rhs.rows() : (std::min)(rhs.rows(),rhs.cols()));
|
||||
|
||||
BlockingType blocking(stripedRows, stripedCols, stripedDepth);
|
||||
|
||||
// The kernel is parameterized by the storage order and conjugation flag of
// each unwrapped operand, and by the storage order of the destination.
internal::product_triangular_matrix_matrix<Scalar, Index,
|
||||
Mode, LhsIsTriangular,
|
||||
(internal::traits<ActualLhsTypeCleaned>::Flags&RowMajorBit) ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
|
||||
(internal::traits<ActualRhsTypeCleaned>::Flags&RowMajorBit) ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
|
||||
(internal::traits<Dest >::Flags&RowMajorBit) ? RowMajor : ColMajor>
|
||||
::run(
|
||||
stripedRows, stripedCols, stripedDepth, // sizes
|
||||
&lhs.coeffRef(0,0), lhs.outerStride(), // lhs info
|
||||
&rhs.coeffRef(0,0), rhs.outerStride(), // rhs info
|
||||
&dst.coeffRef(0,0), dst.outerStride(), // result info
|
||||
actualAlpha, blocking
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
#endif // EIGEN_TEST_EVALUATORS
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
|
@ -168,11 +168,12 @@ struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
|
||||
{};
|
||||
|
||||
|
||||
template<int StorageOrder>
|
||||
template<int Mode,int StorageOrder>
|
||||
struct trmv_selector;
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
template<int Mode, typename Lhs, typename Rhs>
|
||||
struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
|
||||
: public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs >
|
||||
@ -185,7 +186,7 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
|
||||
{
|
||||
eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
|
||||
|
||||
internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha);
|
||||
internal::trmv_selector<Mode,(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(m_lhs, m_rhs, dst, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
@ -201,39 +202,71 @@ struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
|
||||
{
|
||||
eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
|
||||
|
||||
typedef TriangularProduct<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
|
||||
Transpose<Dest> dstT(dst);
|
||||
internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run(
|
||||
TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
|
||||
internal::trmv_selector<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),
|
||||
(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>
|
||||
::run(m_rhs.transpose(),m_lhs.transpose(), dstT, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
#else // EIGEN_TEST_EVALUATORS
|
||||
namespace internal {
|
||||
|
||||
// Triangular matrix * vector specialization (lhs is triangular, rhs is a
// vector): checks the destination size and dispatches to trmv_selector
// according to the storage order of Lhs.
template<int Mode, typename Lhs, typename Rhs>
|
||||
struct triangular_product_impl<Mode,true,Lhs,false,Rhs,true>
|
||||
{
|
||||
template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha)
|
||||
{
|
||||
eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());
|
||||
|
||||
// RowMajor selector when Lhs has the RowMajorBit flag set, ColMajor otherwise.
internal::trmv_selector<Mode,(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(lhs, rhs, dst, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
// Vector * triangular matrix specialization (lhs is a vector, rhs is
// triangular): implemented by transposing the product and dispatching to
// trmv_selector with the transposed operands.
template<int Mode, typename Lhs, typename Rhs>
|
||||
struct triangular_product_impl<Mode,false,Lhs,true,Rhs,false>
|
||||
{
|
||||
template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha)
|
||||
{
|
||||
eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());
|
||||
|
||||
// dst^T is driven by rhs^T * lhs^T. Transposing swaps Lower <-> Upper (the
// UnitDiag / ZeroDiag bits are kept) and flips the storage order that is
// passed to the selector.
Transpose<Dest> dstT(dst);
|
||||
internal::trmv_selector<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),
|
||||
(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>
|
||||
::run(rhs.transpose(),lhs.transpose(), dstT, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
#endif // EIGEN_TEST_EVALUATORS
|
||||
|
||||
namespace internal {
|
||||
|
||||
// TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
|
||||
|
||||
template<> struct trmv_selector<ColMajor>
|
||||
template<int Mode> struct trmv_selector<Mode,ColMajor>
|
||||
{
|
||||
template<int Mode, typename Lhs, typename Rhs, typename Dest>
|
||||
static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, const typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar& alpha)
|
||||
template<typename Lhs, typename Rhs, typename Dest>
|
||||
static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
|
||||
{
|
||||
typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
|
||||
typedef typename ProductType::Index Index;
|
||||
typedef typename ProductType::LhsScalar LhsScalar;
|
||||
typedef typename ProductType::RhsScalar RhsScalar;
|
||||
typedef typename ProductType::Scalar ResScalar;
|
||||
typedef typename ProductType::RealScalar RealScalar;
|
||||
typedef typename ProductType::ActualLhsType ActualLhsType;
|
||||
typedef typename ProductType::ActualRhsType ActualRhsType;
|
||||
typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
|
||||
typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
|
||||
typedef typename Dest::Index Index;
|
||||
typedef typename Lhs::Scalar LhsScalar;
|
||||
typedef typename Rhs::Scalar RhsScalar;
|
||||
typedef typename Dest::Scalar ResScalar;
|
||||
typedef typename Dest::RealScalar RealScalar;
|
||||
|
||||
typedef internal::blas_traits<Lhs> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
typedef internal::blas_traits<Rhs> RhsBlasTraits;
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||
|
||||
typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
|
||||
|
||||
typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
|
||||
typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
|
||||
typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
|
||||
typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
|
||||
|
||||
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
|
||||
* RhsBlasTraits::extractScalarFactor(prod.rhs());
|
||||
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
|
||||
* RhsBlasTraits::extractScalarFactor(rhs);
|
||||
|
||||
enum {
|
||||
// FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
|
||||
@ -288,33 +321,33 @@ template<> struct trmv_selector<ColMajor>
|
||||
}
|
||||
};
|
||||
|
||||
template<> struct trmv_selector<RowMajor>
|
||||
template<int Mode> struct trmv_selector<Mode,RowMajor>
|
||||
{
|
||||
template<int Mode, typename Lhs, typename Rhs, typename Dest>
|
||||
static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, const typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar& alpha)
|
||||
template<typename Lhs, typename Rhs, typename Dest>
|
||||
static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
|
||||
{
|
||||
typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
|
||||
typedef typename ProductType::LhsScalar LhsScalar;
|
||||
typedef typename ProductType::RhsScalar RhsScalar;
|
||||
typedef typename ProductType::Scalar ResScalar;
|
||||
typedef typename ProductType::Index Index;
|
||||
typedef typename ProductType::ActualLhsType ActualLhsType;
|
||||
typedef typename ProductType::ActualRhsType ActualRhsType;
|
||||
typedef typename ProductType::_ActualRhsType _ActualRhsType;
|
||||
typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
|
||||
typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
|
||||
typedef typename Dest::Index Index;
|
||||
typedef typename Lhs::Scalar LhsScalar;
|
||||
typedef typename Rhs::Scalar RhsScalar;
|
||||
typedef typename Dest::Scalar ResScalar;
|
||||
|
||||
typedef internal::blas_traits<Lhs> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
typedef internal::blas_traits<Rhs> RhsBlasTraits;
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||
typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
|
||||
|
||||
typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
|
||||
typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
|
||||
typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
|
||||
typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
|
||||
|
||||
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
|
||||
* RhsBlasTraits::extractScalarFactor(prod.rhs());
|
||||
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
|
||||
* RhsBlasTraits::extractScalarFactor(rhs);
|
||||
|
||||
enum {
|
||||
DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
|
||||
DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1
|
||||
};
|
||||
|
||||
gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
|
||||
gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
|
||||
|
||||
ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
|
||||
DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
|
||||
@ -325,7 +358,7 @@ template<> struct trmv_selector<RowMajor>
|
||||
int size = actualRhs.size();
|
||||
EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||
#endif
|
||||
Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
|
||||
Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
|
||||
}
|
||||
|
||||
internal::triangular_matrix_vector_product
|
||||
|
Loading…
Reference in New Issue
Block a user