mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-15 07:10:37 +08:00
Merged in rmlarsen/eigen (pull request PR-578)
Speed up Eigen matrix*vector and vector*matrix multiplication. Approved-by: Eugene Zhulenev <ezhulenev@google.com>
This commit is contained in:
commit
e7b481ea74
@ -404,13 +404,13 @@ class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, M
|
||||
|
||||
namespace internal {
|
||||
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
|
||||
: generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> >
|
||||
{
|
||||
template <typename Lhs, typename Rhs, typename Dest,
|
||||
bool MultipleRowsAtCompileTime =
|
||||
(Lhs::RowsAtCompileTime > 1 || Dest::RowsAtCompileTime > 1),
|
||||
bool MultipleColsAtCompileTime =
|
||||
(Rhs::ColsAtCompileTime > 1 || Dest::ColsAtCompileTime > 1)>
|
||||
struct gemm_selector {
|
||||
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
|
||||
typedef typename Lhs::Scalar LhsScalar;
|
||||
typedef typename Rhs::Scalar RhsScalar;
|
||||
|
||||
typedef internal::blas_traits<Lhs> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
@ -420,10 +420,130 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||
typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
|
||||
|
||||
// Runtime dispatch on the operand shapes:
//   * rhs has exactly one column            -> matrix * vector (OnTheRight gemv),
//                                              using column 0 of the rhs;
//   * otherwise, lhs has exactly one row    -> vector * matrix (OnTheLeft gemv),
//                                              using row 0 of the lhs;
//   * neither                               -> forwards to the <true, true>
//                                              gemm_selector specialization.
// dst, the operands and alpha are handed unchanged to the selected kernel.
static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha)
{
  if (a_rhs.cols() == 1) {
    // matrix * vector.
    typedef internal::gemv_dense_selector<
        OnTheRight,
        (int(ActualLhsTypeCleaned::Flags)&RowMajorBit) ? RowMajor : ColMajor,
        bool(internal::blas_traits<ActualLhsTypeCleaned>::HasUsableDirectAccess)
      > MatrixTimesVector;
    MatrixTimesVector::run(a_lhs, a_rhs.col(0), dst, alpha);
    return;
  }

  if (a_lhs.rows() == 1) {
    // vector * matrix.
    typedef internal::gemv_dense_selector<
        OnTheLeft,
        (int(ActualRhsTypeCleaned::Flags)&RowMajorBit) ? RowMajor : ColMajor,
        bool(internal::blas_traits<ActualRhsTypeCleaned>::HasUsableDirectAccess)
      > VectorTimesMatrix;
    VectorTimesMatrix::run(a_lhs.row(0), a_rhs, dst, alpha);
    return;
  }

  // Both operands have more than one row/column in the relevant dimension.
  gemm_selector<Lhs, Rhs, Dest, true, true>::run(dst, a_lhs, a_rhs, alpha);
}
|
||||
};
|
||||
|
||||
template <typename Lhs, typename Rhs, typename Dest>
|
||||
struct gemm_selector<Lhs, Rhs, Dest, true, false> {
|
||||
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
|
||||
|
||||
typedef internal::blas_traits<Lhs> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
|
||||
|
||||
static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha)
|
||||
{
|
||||
if (a_rhs.cols() != 1 && a_lhs.rows() != 1) {
|
||||
gemm_selector<Lhs, Rhs, Dest, true, true>::run(dst, a_lhs, a_rhs, alpha);
|
||||
} else {
|
||||
// matrix * vector.
|
||||
internal::gemv_dense_selector<OnTheRight,
|
||||
(int(ActualLhsTypeCleaned::Flags)&RowMajorBit) ? RowMajor : ColMajor,
|
||||
bool(internal::blas_traits<ActualLhsTypeCleaned>::HasUsableDirectAccess)
|
||||
>::run(a_lhs, a_rhs.col(0), dst, alpha);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Lhs, typename Rhs, typename Dest>
|
||||
struct gemm_selector<Lhs, Rhs, Dest, false, true> {
|
||||
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
|
||||
|
||||
typedef internal::blas_traits<Rhs> RhsBlasTraits;
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||
typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
|
||||
|
||||
static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha)
|
||||
{
|
||||
if (a_rhs.cols() != 1 && a_lhs.rows() != 1) {
|
||||
gemm_selector<Lhs, Rhs, Dest, true, true>::run(dst, a_lhs, a_rhs, alpha);
|
||||
} else {
|
||||
// vector * matrix.
|
||||
internal::gemv_dense_selector<OnTheLeft,
|
||||
(int(ActualRhsTypeCleaned::Flags)&RowMajorBit) ? RowMajor : ColMajor,
|
||||
bool(internal::blas_traits<ActualRhsTypeCleaned>::HasUsableDirectAccess)
|
||||
>::run(a_lhs.row(0), a_rhs, dst, alpha);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Lhs, typename Rhs, typename Dest>
|
||||
struct gemm_selector<Lhs, Rhs, Dest, true, true> {
|
||||
typedef typename Product<Lhs, Rhs>::Scalar Scalar;
|
||||
typedef typename Lhs::Scalar LhsScalar;
|
||||
typedef typename Rhs::Scalar RhsScalar;
|
||||
|
||||
typedef internal::blas_traits<Lhs> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
typedef
|
||||
typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
|
||||
|
||||
typedef internal::blas_traits<Rhs> RhsBlasTraits;
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||
typedef
|
||||
typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
|
||||
|
||||
enum {
|
||||
MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs::MaxColsAtCompileTime,Rhs::MaxRowsAtCompileTime)
|
||||
MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(
|
||||
Lhs::MaxColsAtCompileTime, Rhs::MaxRowsAtCompileTime)
|
||||
};
|
||||
|
||||
static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs,
|
||||
const Scalar& alpha) {
|
||||
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) *
|
||||
RhsBlasTraits::extractScalarFactor(a_rhs);
|
||||
typename internal::add_const_on_value_type<ActualLhsType>::type lhs =
|
||||
LhsBlasTraits::extract(a_lhs);
|
||||
typename internal::add_const_on_value_type<ActualRhsType>::type rhs =
|
||||
RhsBlasTraits::extract(a_rhs);
|
||||
typedef internal::gemm_blocking_space<
|
||||
(Dest::Flags & RowMajorBit) ? RowMajor : ColMajor, LhsScalar, RhsScalar,
|
||||
Dest::MaxRowsAtCompileTime, Dest::MaxColsAtCompileTime,
|
||||
MaxDepthAtCompileTime>
|
||||
BlockingType;
|
||||
|
||||
typedef internal::gemm_functor<
|
||||
Scalar, Index,
|
||||
internal::general_matrix_matrix_product<
|
||||
Index, LhsScalar,
|
||||
(ActualLhsTypeCleaned::Flags & RowMajorBit) ? RowMajor : ColMajor,
|
||||
bool(LhsBlasTraits::NeedToConjugate), RhsScalar,
|
||||
(ActualRhsTypeCleaned::Flags & RowMajorBit) ? RowMajor : ColMajor,
|
||||
bool(RhsBlasTraits::NeedToConjugate),
|
||||
(Dest::Flags & RowMajorBit) ? RowMajor : ColMajor>,
|
||||
ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType>
|
||||
GemmFunctor;
|
||||
|
||||
BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true);
|
||||
internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime > 32 ||
|
||||
Dest::MaxRowsAtCompileTime == Dynamic)>(
|
||||
GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(),
|
||||
a_rhs.cols(), a_lhs.cols(), Dest::Flags & RowMajorBit);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
|
||||
: generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> >
|
||||
{
|
||||
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
|
||||
typedef generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> lazyproduct;
|
||||
|
||||
template<typename Dst>
|
||||
@ -450,7 +570,7 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
|
||||
if((rhs.rows()+dst.rows()+dst.cols())<EIGEN_GEMM_TO_COEFFBASED_THRESHOLD && rhs.rows()>0)
|
||||
lazyproduct::eval_dynamic(dst, lhs, rhs, internal::add_assign_op<typename Dst::Scalar,Scalar>());
|
||||
else
|
||||
scaleAndAddTo(dst,lhs, rhs, Scalar(1));
|
||||
scaleAndAddTo(dst, lhs, rhs, Scalar(1));
|
||||
}
|
||||
|
||||
template<typename Dst>
|
||||
@ -469,27 +589,7 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
|
||||
if(a_lhs.cols()==0 || a_lhs.rows()==0 || a_rhs.cols()==0)
|
||||
return;
|
||||
|
||||
typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
|
||||
typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
|
||||
|
||||
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
|
||||
* RhsBlasTraits::extractScalarFactor(a_rhs);
|
||||
|
||||
typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar,
|
||||
Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType;
|
||||
|
||||
typedef internal::gemm_functor<
|
||||
Scalar, Index,
|
||||
internal::general_matrix_matrix_product<
|
||||
Index,
|
||||
LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
|
||||
RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
|
||||
(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
|
||||
ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor;
|
||||
|
||||
BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true);
|
||||
internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>
|
||||
(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), a_rhs.cols(), a_lhs.cols(), Dest::Flags&RowMajorBit);
|
||||
gemm_selector<Lhs, Rhs, Dest>::run(dst, a_lhs, a_rhs, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user