mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-11-27 06:30:28 +08:00
fix openmp version for scalar types different than float
This commit is contained in:
parent
d13b877014
commit
62ac021606
@ -40,7 +40,7 @@ struct ei_general_matrix_matrix_product<Scalar,LhsStorageOrder,ConjugateLhs,RhsS
|
||||
const Scalar* rhs, int rhsStride,
|
||||
Scalar* res, int resStride,
|
||||
Scalar alpha,
|
||||
GemmParallelInfo* info = 0)
|
||||
GemmParallelInfo<Scalar>* info = 0)
|
||||
{
|
||||
// transpose the product such that the result is column major
|
||||
ei_general_matrix_matrix_product<Scalar,
|
||||
@ -66,7 +66,7 @@ static void run(int rows, int cols, int depth,
|
||||
const Scalar* _rhs, int rhsStride,
|
||||
Scalar* res, int resStride,
|
||||
Scalar alpha,
|
||||
GemmParallelInfo* info = 0)
|
||||
GemmParallelInfo<Scalar>* info = 0)
|
||||
{
|
||||
ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride);
|
||||
ei_const_blas_data_mapper<Scalar, RhsStorageOrder> rhs(_rhs,rhsStride);
|
||||
@ -218,11 +218,13 @@ struct ei_traits<GeneralProduct<Lhs,Rhs,GemmProduct> >
|
||||
template<typename Scalar, typename Gemm, typename Lhs, typename Rhs, typename Dest>
|
||||
struct ei_gemm_functor
|
||||
{
|
||||
typedef typename Rhs::Scalar BlockBScalar;
|
||||
|
||||
ei_gemm_functor(const Lhs& lhs, const Rhs& rhs, Dest& dest, Scalar actualAlpha)
|
||||
: m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha)
|
||||
{}
|
||||
|
||||
void operator() (int row, int rows, int col=0, int cols=-1, GemmParallelInfo* info=0) const
|
||||
void operator() (int row, int rows, int col=0, int cols=-1, GemmParallelInfo<BlockBScalar>* info=0) const
|
||||
{
|
||||
if(cols==-1)
|
||||
cols = m_rhs.cols();
|
||||
@ -234,6 +236,12 @@ struct ei_gemm_functor
|
||||
info);
|
||||
}
|
||||
|
||||
|
||||
int sharedBlockBSize() const
|
||||
{
|
||||
return std::min<int>(ei_product_blocking_traits<Scalar>::Max_kc,m_rhs.rows()) * m_rhs.cols();
|
||||
}
|
||||
|
||||
protected:
|
||||
const Lhs& m_lhs;
|
||||
const Rhs& m_rhs;
|
||||
@ -275,7 +283,7 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
|
||||
_ActualRhsType,
|
||||
Dest> GemmFunctor;
|
||||
|
||||
ei_parallelize_gemm<Dest::MaxRowsAtCompileTime>32>(GemmFunctor(lhs, rhs, dst, actualAlpha), this->rows(), this->cols());
|
||||
ei_parallelize_gemm<(Dest::MaxRowsAtCompileTime>32)>(GemmFunctor(lhs, rhs, dst, actualAlpha), this->rows(), this->cols());
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -25,16 +25,16 @@
|
||||
#ifndef EIGEN_PARALLELIZER_H
|
||||
#define EIGEN_PARALLELIZER_H
|
||||
|
||||
struct GemmParallelInfo
|
||||
template<typename BlockBScalar> struct GemmParallelInfo
|
||||
{
|
||||
GemmParallelInfo() : sync(-1), users(0) {}
|
||||
GemmParallelInfo() : sync(-1), users(0), rhs_start(0), rhs_length(0), blockB(0) {}
|
||||
|
||||
int volatile sync;
|
||||
int volatile users;
|
||||
|
||||
int rhs_start;
|
||||
int rhs_length;
|
||||
float* blockB;
|
||||
BlockBScalar* blockB;
|
||||
};
|
||||
|
||||
template<bool Condition,typename Functor>
|
||||
@ -51,9 +51,10 @@ void ei_parallelize_gemm(const Functor& func, int rows, int cols)
|
||||
int blockCols = (cols / threads) & ~0x3;
|
||||
int blockRows = (rows / threads) & ~0x7;
|
||||
|
||||
float* sharedBlockB = new float[2048*2048*4];
|
||||
typedef typename Functor::BlockBScalar BlockBScalar;
|
||||
BlockBScalar* sharedBlockB = new BlockBScalar[func.sharedBlockBSize()];
|
||||
|
||||
GemmParallelInfo* info = new GemmParallelInfo[threads];
|
||||
GemmParallelInfo<BlockBScalar>* info = new GemmParallelInfo<BlockBScalar>[threads];
|
||||
|
||||
#pragma omp parallel for schedule(static,1)
|
||||
for(int i=0; i<threads; ++i)
|
||||
|
Loading…
Reference in New Issue
Block a user