diff --git a/Eigen/src/Core/GeneralProduct.h b/Eigen/src/Core/GeneralProduct.h index bf7ef54b5..6906aa75d 100644 --- a/Eigen/src/Core/GeneralProduct.h +++ b/Eigen/src/Core/GeneralProduct.h @@ -228,8 +228,7 @@ template<> struct gemv_dense_selector ActualLhsType actualLhs = LhsBlasTraits::extract(lhs); ActualRhsType actualRhs = RhsBlasTraits::extract(rhs); - ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs) - * RhsBlasTraits::extractScalarFactor(rhs); + ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs); // make sure Dest is a compile-time vector type (bug 1166) typedef typename conditional::type ActualDest; @@ -320,8 +319,7 @@ template<> struct gemv_dense_selector typename add_const::type actualLhs = LhsBlasTraits::extract(lhs); typename add_const::type actualRhs = RhsBlasTraits::extract(rhs); - ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs) - * RhsBlasTraits::extractScalarFactor(rhs); + ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs); enum { // FIXME find a way to allow an inner stride on the result if packet_traits::size==1 diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index 02b58438c..079189a10 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -441,8 +441,8 @@ struct generic_product_impl }; // FIXME: in c++11 this should be auto, and extractScalarFactor should also return auto // this is important for real*complex_mat - Scalar actualAlpha = blas_traits::extractScalarFactor(lhs) - * blas_traits::extractScalarFactor(rhs); + Scalar actualAlpha = combine_scalar_factors(lhs, rhs); + eval_dynamic_impl(dst, blas_traits::extract(lhs).template conjugateIf(), blas_traits::extract(rhs).template conjugateIf(), diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index 0d55bdf9e..caa65fccc 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -489,8 +489,7 @@ struct generic_product_impl typename internal::add_const_on_value_type::type lhs = LhsBlasTraits::extract(a_lhs); typename internal::add_const_on_value_type::type rhs = RhsBlasTraits::extract(a_rhs); - Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) - * RhsBlasTraits::extractScalarFactor(a_rhs); + Scalar actualAlpha = combine_scalar_factors(alpha, a_lhs, a_rhs); typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar, Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType; diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h index a90e57446..c5161022c 100755 --- a/Eigen/src/Core/util/BlasUtil.h +++ b/Eigen/src/Core/util/BlasUtil.h @@ -618,6 +618,47 @@ template const typename T::Scalar* extract_data(const T& m) return extract_data_selector::run(m); } +/** + * \c combine_scalar_factors extracts and multiplies factors from GEMM and GEMV products. + * There is a specialization for booleans + */ +template +struct combine_scalar_factors_impl +{ + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static ResScalar run(const Lhs& lhs, const Rhs& rhs) + { + return blas_traits::extractScalarFactor(lhs) * blas_traits::extractScalarFactor(rhs); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static ResScalar run(const ResScalar& alpha, const Lhs& lhs, const Rhs& rhs) + { + return alpha * blas_traits::extractScalarFactor(lhs) * blas_traits::extractScalarFactor(rhs); + } +}; +template +struct combine_scalar_factors_impl +{ + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const Lhs& lhs, const Rhs& rhs) + { + return blas_traits::extractScalarFactor(lhs) && blas_traits::extractScalarFactor(rhs); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const bool& alpha, const Lhs& lhs, const Rhs& rhs) + { + return alpha && blas_traits::extractScalarFactor(lhs) && blas_traits::extractScalarFactor(rhs); + } +}; + +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const ResScalar& alpha, const Lhs& lhs, const Rhs& rhs) +{ + return combine_scalar_factors_impl::run(alpha, lhs, rhs); +} +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const Lhs& lhs, const Rhs& rhs) +{ + return combine_scalar_factors_impl::run(lhs, rhs); +} + + } // end namespace internal } // end namespace Eigen diff --git a/test/product_small.cpp b/test/product_small.cpp index 93876dba4..1d6df6e58 100644 --- a/test/product_small.cpp +++ b/test/product_small.cpp @@ -56,18 +56,17 @@ test_lazy_single(int rows, int cols, int depth) VERIFY_IS_APPROX(C+=A.lazyProduct(B), ref_prod(D,A,B)); } -template -void test_dynamic_exact() +void test_dynamic_bool() { int rows = internal::random(1,64); int cols = internal::random(1,64); int depth = internal::random(1,65); - typedef Matrix MatrixX; + typedef Matrix MatrixX; MatrixX A(rows,depth); A.setRandom(); MatrixX B(depth,cols); B.setRandom(); - MatrixX C(rows,cols); C.setRandom(); - MatrixX D(C); + MatrixX C(rows,cols); C.setRandom(); + MatrixX D(C); for(Index i=0;i() ); CALL_SUBTEST_6( bug_1311<5>() ); - CALL_SUBTEST_9( test_dynamic_exact() ); + CALL_SUBTEST_9( test_dynamic_bool() ); } CALL_SUBTEST_6( product_small_regressions<0>() );