From 4c0fa6ce0f81ce67dd6723528ddf72f66ae92ba2 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Thu, 31 Jan 2019 14:24:08 -0800 Subject: [PATCH] Speed up Eigen matrix*vector and vector*matrix multiplication. This change speeds up Eigen matrix * vector and vector * matrix multiplication for dynamic matrices when it is known at runtime that one of the factors is a vector. The benchmarks below test c.noalias()= n_by_n_matrix * n_by_1_matrix; c.noalias()= 1_by_n_matrix * n_by_n_matrix; respectively. Benchmark measurements: SSE: Run on *** (72 X 2992 MHz CPUs); 2019-01-28T17:51:44.452697457-08:00 CPU: Intel Skylake Xeon with HyperThreading (36 cores) dL1:32KB dL2:1024KB dL3:24MB Benchmark Base (ns) New (ns) Improvement ------------------------------------------------------------------ BM_MatVec/64 1096 312 +71.5% BM_MatVec/128 4581 1464 +68.0% BM_MatVec/256 18534 5710 +69.2% BM_MatVec/512 118083 24162 +79.5% BM_MatVec/1k 704106 173346 +75.4% BM_MatVec/2k 3080828 742728 +75.9% BM_MatVec/4k 25421512 4530117 +82.2% BM_VecMat/32 352 130 +63.1% BM_VecMat/64 1213 425 +65.0% BM_VecMat/128 4640 1564 +66.3% BM_VecMat/256 17902 5884 +67.1% BM_VecMat/512 70466 24000 +65.9% BM_VecMat/1k 340150 161263 +52.6% BM_VecMat/2k 1420590 645576 +54.6% BM_VecMat/4k 8083859 4364327 +46.0% AVX2: Run on *** (72 X 2993 MHz CPUs); 2019-01-28T17:45:11.508545307-08:00 CPU: Intel Skylake Xeon with HyperThreading (36 cores) dL1:32KB dL2:1024KB dL3:24MB Benchmark Base (ns) New (ns) Improvement ------------------------------------------------------------------ BM_MatVec/64 619 120 +80.6% BM_MatVec/128 9693 752 +92.2% BM_MatVec/256 38356 2773 +92.8% BM_MatVec/512 69006 12803 +81.4% BM_MatVec/1k 443810 160378 +63.9% BM_MatVec/2k 2633553 646594 +75.4% BM_MatVec/4k 16211095 4327148 +73.3% BM_VecMat/64 925 227 +75.5% BM_VecMat/128 3438 830 +75.9% BM_VecMat/256 13427 2936 +78.1% BM_VecMat/512 53944 12473 +76.9% BM_VecMat/1k 302264 157076 +48.0% BM_VecMat/2k 1396811 675778 +51.6% BM_VecMat/4k 8962246 4459010 +50.2% AVX512: Run on *** (72 X 2993 MHz CPUs); 2019-01-28T17:35:17.239329863-08:00 CPU: Intel Skylake Xeon with HyperThreading (36 cores) dL1:32KB dL2:1024KB dL3:24MB Benchmark Base (ns) New (ns) Improvement ------------------------------------------------------------------ BM_MatVec/64 401 111 +72.3% BM_MatVec/128 1846 513 +72.2% BM_MatVec/256 36739 1927 +94.8% BM_MatVec/512 54490 9227 +83.1% BM_MatVec/1k 487374 161457 +66.9% BM_MatVec/2k 2016270 643824 +68.1% BM_MatVec/4k 13204300 4077412 +69.1% BM_VecMat/32 324 106 +67.3% BM_VecMat/64 1034 246 +76.2% BM_VecMat/128 3576 802 +77.6% BM_VecMat/256 13411 2561 +80.9% BM_VecMat/512 58686 10037 +82.9% BM_VecMat/1k 320862 163750 +49.0% BM_VecMat/2k 1406719 651397 +53.7% BM_VecMat/4k 7785179 4124677 +47.0% Currently watchingStop watching --- Eigen/src/Core/products/GeneralMatrixMatrix.h | 158 ++++++++++++++---- 1 file changed, 129 insertions(+), 29 deletions(-) diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index f49abcad5..4bcccd326 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -404,13 +404,13 @@ class gemm_blocking_space -struct generic_product_impl - : generic_product_impl_base > -{ +template 1 || Dest::RowsAtCompileTime > 1), + bool MultipleColsAtCompileTime = + (Rhs::ColsAtCompileTime > 1 || Dest::ColsAtCompileTime > 1)> +struct gemm_selector { typedef typename Product::Scalar Scalar; - typedef typename Lhs::Scalar LhsScalar; - typedef typename Rhs::Scalar RhsScalar; typedef internal::blas_traits LhsBlasTraits; typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; @@ -420,10 +420,130 @@ struct generic_product_impl typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; typedef typename internal::remove_all::type ActualRhsTypeCleaned; + static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) + { + if (a_rhs.cols() != 1 && a_lhs.rows() != 1) { + gemm_selector::run(dst, a_lhs, a_rhs, alpha); + } else if (a_rhs.cols() == 1) { + // matrix * vector. + internal::gemv_dense_selector::HasUsableDirectAccess) + >::run(a_lhs, a_rhs.col(0), dst, alpha); + } else { + // vector * matrix. + internal::gemv_dense_selector::HasUsableDirectAccess) + >::run(a_lhs.row(0), a_rhs, dst, alpha); + } + } +}; + +template +struct gemm_selector { + typedef typename Product::Scalar Scalar; + + typedef internal::blas_traits LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; + typedef typename internal::remove_all::type ActualLhsTypeCleaned; + + static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) + { + if (a_rhs.cols() != 1 && a_lhs.rows() != 1) { + gemm_selector::run(dst, a_lhs, a_rhs, alpha); + } else { + // matrix * vector. + internal::gemv_dense_selector::HasUsableDirectAccess) + >::run(a_lhs, a_rhs.col(0), dst, alpha); + } + } +}; + +template +struct gemm_selector { + typedef typename Product::Scalar Scalar; + + typedef internal::blas_traits RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; + typedef typename internal::remove_all::type ActualRhsTypeCleaned; + + static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) + { + if (a_rhs.cols() != 1 && a_lhs.rows() != 1) { + gemm_selector::run(dst, a_lhs, a_rhs, alpha); + } else { + // vector * matrix. + internal::gemv_dense_selector::HasUsableDirectAccess) + >::run(a_lhs.row(0), a_rhs, dst, alpha); + } + } +}; + +template +struct gemm_selector { + typedef typename Product::Scalar Scalar; + typedef typename Lhs::Scalar LhsScalar; + typedef typename Rhs::Scalar RhsScalar; + + typedef internal::blas_traits LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; + typedef + typename internal::remove_all::type ActualLhsTypeCleaned; + + typedef internal::blas_traits RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; + typedef + typename internal::remove_all::type ActualRhsTypeCleaned; + enum { - MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs::MaxColsAtCompileTime,Rhs::MaxRowsAtCompileTime) + MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED( + Lhs::MaxColsAtCompileTime, Rhs::MaxRowsAtCompileTime) }; + static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, + const Scalar& alpha) { + Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) * + RhsBlasTraits::extractScalarFactor(a_rhs); + typename internal::add_const_on_value_type::type lhs = + LhsBlasTraits::extract(a_lhs); + typename internal::add_const_on_value_type::type rhs = + RhsBlasTraits::extract(a_rhs); + typedef internal::gemm_blocking_space< + (Dest::Flags & RowMajorBit) ? RowMajor : ColMajor, LhsScalar, RhsScalar, + Dest::MaxRowsAtCompileTime, Dest::MaxColsAtCompileTime, + MaxDepthAtCompileTime> + BlockingType; + + typedef internal::gemm_functor< + Scalar, Index, + internal::general_matrix_matrix_product< + Index, LhsScalar, + (ActualLhsTypeCleaned::Flags & RowMajorBit) ? RowMajor : ColMajor, + bool(LhsBlasTraits::NeedToConjugate), RhsScalar, + (ActualRhsTypeCleaned::Flags & RowMajorBit) ? RowMajor : ColMajor, + bool(RhsBlasTraits::NeedToConjugate), + (Dest::Flags & RowMajorBit) ? RowMajor : ColMajor>, + ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> + GemmFunctor; + + BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true); + internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime > 32 || + Dest::MaxRowsAtCompileTime == Dynamic)>( + GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), + a_rhs.cols(), a_lhs.cols(), Dest::Flags & RowMajorBit); + } +}; + +template +struct generic_product_impl + : generic_product_impl_base > +{ + typedef typename Product::Scalar Scalar; typedef generic_product_impl lazyproduct; template @@ -450,7 +570,7 @@ struct generic_product_impl if((rhs.rows()+dst.rows()+dst.cols())0) lazyproduct::eval_dynamic(dst, lhs, rhs, internal::add_assign_op()); else - scaleAndAddTo(dst,lhs, rhs, Scalar(1)); + scaleAndAddTo(dst, lhs, rhs, Scalar(1)); } template @@ -469,27 +589,7 @@ struct generic_product_impl if(a_lhs.cols()==0 || a_lhs.rows()==0 || a_rhs.cols()==0) return; - typename internal::add_const_on_value_type::type lhs = LhsBlasTraits::extract(a_lhs); - typename internal::add_const_on_value_type::type rhs = RhsBlasTraits::extract(a_rhs); - - Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) - * RhsBlasTraits::extractScalarFactor(a_rhs); - - typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar, - Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType; - - typedef internal::gemm_functor< - Scalar, Index, - internal::general_matrix_matrix_product< - Index, - LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate), - RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate), - (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>, - ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor; - - BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true); - internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)> - (GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), a_rhs.cols(), a_lhs.cols(), Dest::Flags&RowMajorBit); + gemm_selector::run(dst, a_lhs, a_rhs, alpha); } };