mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-02-23 18:20:47 +08:00
blas level2: gemv and trsv are green
This commit is contained in:
parent
3fdea699b8
commit
0e30c4ae3f
@ -30,7 +30,7 @@ namespace internal {
|
||||
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder>
|
||||
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheRight, Mode, Conjugate, StorageOrder>
|
||||
{
|
||||
static void run(int size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
|
||||
static void run(Index size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
|
||||
{
|
||||
triangular_solve_vector<LhsScalar,RhsScalar,Index,OnTheLeft,
|
||||
((Mode&Upper)==Upper ? Lower : Upper) | (Mode&UnitDiag),
|
||||
@ -46,7 +46,7 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Con
|
||||
enum {
|
||||
IsLower = ((Mode&Lower)==Lower)
|
||||
};
|
||||
static void run(int size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
|
||||
static void run(Index size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
|
||||
{
|
||||
typedef Map<Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
|
||||
const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));
|
||||
@ -100,7 +100,7 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Con
|
||||
enum {
|
||||
IsLower = ((Mode&Lower)==Lower)
|
||||
};
|
||||
static void run(int size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
|
||||
static void run(Index size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
|
||||
{
|
||||
typedef Map<Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
|
||||
const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));
|
||||
|
@ -24,8 +24,39 @@
|
||||
|
||||
#include "common.h"
|
||||
|
||||
// MAKE_ACTUAL_VECTOR(X,INCX,N,COND)
// Declares `Scalar* actual_X` in the enclosing scope, initially aliasing X.
// When COND is true, actual_X is instead pointed at a freshly new[]-allocated
// buffer of N Scalars that is filled from X using increment INCX; for a
// negative INCX the source is read with increment -INCX and reversed.
// The buffer is only freed by a matching RELEASE_ACTUAL_VECTOR invoked with the
// same arguments, so any return between the two macros skips the delete[].
// Because it expands to a declaration followed by an if-statement, it must be
// used where a declaration is allowed, and only once per X per scope.
// NOTE(review): `vector(ptr,n[,inc])` and `Scalar` are defined outside this
// view (presumably in common.h) -- confirm their exact semantics there.
#define MAKE_ACTUAL_VECTOR(X,INCX,N,COND) \
|
||||
Scalar* actual_##X = X; \
|
||||
if(COND) { \
|
||||
actual_##X = new Scalar[N]; \
|
||||
if((INCX)<0) vector(actual_##X,(N)) = vector(X,(N),-(INCX)).reverse(); \
|
||||
else vector(actual_##X,(N)) = vector(X,(N), (INCX)); \
|
||||
}
|
||||
|
||||
// RELEASE_ACTUAL_VECTOR(X,INCX,N,COND)
// Counterpart of MAKE_ACTUAL_VECTOR; expects `actual_X` to already exist in
// scope. When COND is true, the N elements of actual_X are written back into X
// using increment INCX (for a negative INCX, into the reversed view with
// increment -INCX), and then the actual_X buffer is delete[]d.
// When COND is false nothing happens: actual_X is neither copied nor freed.
// The write-back is unconditional on COND, so it also runs for vectors the
// caller only used as input.
// NOTE(review): `vector(ptr,n[,inc])` is defined outside this view
// (presumably in common.h) -- confirm its exact semantics there.
#define RELEASE_ACTUAL_VECTOR(X,INCX,N,COND) \
|
||||
if(COND) { \
|
||||
if((INCX)<0) vector(X,(N),-(INCX)).reverse() = vector(actual_##X,(N)); \
|
||||
else vector(X,(N), (INCX)) = vector(actual_##X,(N)); \
|
||||
delete[] actual_##X; \
|
||||
}
|
||||
|
||||
int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *incb, RealScalar *pbeta, RealScalar *pc, int *incc)
|
||||
{
|
||||
typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int , Scalar *, int, Scalar);
|
||||
static functype func[4];
|
||||
|
||||
static bool init = false;
|
||||
if(!init)
|
||||
{
|
||||
for(int k=0; k<4; ++k)
|
||||
func[k] = 0;
|
||||
|
||||
func[NOTR] = (internal::general_matrix_vector_product<int,Scalar,ColMajor,false,Scalar,false>::run);
|
||||
func[TR ] = (internal::general_matrix_vector_product<int,Scalar,RowMajor,false,Scalar,false>::run);
|
||||
func[ADJ ] = (internal::general_matrix_vector_product<int,Scalar,RowMajor,Conj, Scalar,false>::run);
|
||||
|
||||
init = true;
|
||||
}
|
||||
|
||||
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
||||
Scalar* b = reinterpret_cast<Scalar*>(pb);
|
||||
Scalar* c = reinterpret_cast<Scalar*>(pc);
|
||||
@ -34,9 +65,7 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca
|
||||
|
||||
// check arguments
|
||||
int info = 0;
|
||||
if( OP(*opa)!=NOTR
|
||||
&& OP(*opa)!=TR
|
||||
&& OP(*opa)!=ADJ) info = 1;
|
||||
if(OP(*opa)==INVALID) info = 1;
|
||||
else if(*m<0) info = 2;
|
||||
else if(*n<0) info = 3;
|
||||
else if(*lda<std::max(1,*m)) info = 6;
|
||||
@ -44,39 +73,34 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca
|
||||
else if(*incc==0) info = 11;
|
||||
if(info)
|
||||
return xerbla_(SCALAR_SUFFIX_UP"GEMV ",&info,6);
|
||||
// return xerbla_("SGEMV ",&info,sizeof("SGEMV "));
|
||||
|
||||
if(*m==0 || *n==0)
|
||||
return 0;
|
||||
|
||||
int actual_m = *m;
|
||||
int actual_n = *n;
|
||||
if(OP(*opa)!=NOTR)
|
||||
std::swap(actual_m,actual_n);
|
||||
|
||||
MAKE_ACTUAL_VECTOR(b,*incb,actual_n,*incb!=1)
|
||||
MAKE_ACTUAL_VECTOR(c,*incc,actual_m,*incc!=1)
|
||||
|
||||
if(beta!=Scalar(1))
|
||||
vector(c, *m, *incc) *= beta;
|
||||
vector(actual_c, actual_m, 1) *= beta;
|
||||
|
||||
if(OP(*opa)==NOTR)
|
||||
if(*incc==1)
|
||||
vector(c,*m) += alpha * matrix(a,*m,*n,*lda) * vector(b,*n,*incb);
|
||||
else
|
||||
vector(c,*m,*incc) += alpha * matrix(a,*m,*n,*lda) * vector(b,*n,*incb);
|
||||
else if(OP(*opa)==TR)
|
||||
if(*incb==1)
|
||||
vector(c,*m,*incc) += alpha * matrix(a,*n,*m,*lda).transpose() * vector(b,*n);
|
||||
else
|
||||
vector(c,*m,*incc) += alpha * matrix(a,*n,*m,*lda).transpose() * vector(b,*n,*incb);
|
||||
else if(OP(*opa)==TR)
|
||||
if(*incb==1)
|
||||
vector(c,*m,*incc) += alpha * matrix(a,*n,*m,*lda).adjoint() * vector(b,*n);
|
||||
else
|
||||
vector(c,*m,*incc) += alpha * matrix(a,*n,*m,*lda).adjoint() * vector(b,*n,*incb);
|
||||
else
|
||||
return 0;
|
||||
int code = OP(*opa);
|
||||
func[code](actual_m, actual_n, a, *lda, actual_b, 1, actual_c, 1, alpha);
|
||||
|
||||
RELEASE_ACTUAL_VECTOR(b,*incb,actual_n,*incb!=1)
|
||||
RELEASE_ACTUAL_VECTOR(c,*incc,actual_m,*incc!=1)
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar *pa, int *lda, RealScalar *pb, int *incb)
|
||||
{
|
||||
return 0;
|
||||
|
||||
typedef void (*functype)(int, const Scalar *, int, Scalar *, int);
|
||||
functype func[16];
|
||||
typedef void (*functype)(int, const Scalar *, int, Scalar *);
|
||||
static functype func[16];
|
||||
|
||||
static bool init = false;
|
||||
if(!init)
|
||||
@ -84,21 +108,21 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar
|
||||
for(int k=0; k<16; ++k)
|
||||
func[k] = 0;
|
||||
|
||||
// func[NOTR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|0, false,ColMajor,ColMajor>::run);
|
||||
// func[TR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|0, false,RowMajor,ColMajor>::run);
|
||||
// func[ADJ | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|0, Conj, RowMajor,ColMajor>::run);
|
||||
//
|
||||
// func[NOTR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|0, false,ColMajor,ColMajor>::run);
|
||||
// func[TR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|0, false,RowMajor,ColMajor>::run);
|
||||
// func[ADJ | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|0, Conj, RowMajor,ColMajor>::run);
|
||||
//
|
||||
// func[NOTR | (UP << 3) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|UnitDiagBit,false,ColMajor,ColMajor>::run);
|
||||
// func[TR | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|UnitDiagBit,false,RowMajor,ColMajor>::run);
|
||||
// func[ADJ | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|UnitDiagBit,Conj, RowMajor,ColMajor>::run);
|
||||
//
|
||||
// func[NOTR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|UnitDiagBit,false,ColMajor,ColMajor>::run);
|
||||
// func[TR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|UnitDiagBit,false,RowMajor,ColMajor>::run);
|
||||
// func[ADJ | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|UnitDiagBit,Conj, RowMajor,ColMajor>::run);
|
||||
func[NOTR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|0, false,ColMajor>::run);
|
||||
func[TR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|0, false,RowMajor>::run);
|
||||
func[ADJ | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|0, Conj, RowMajor>::run);
|
||||
|
||||
func[NOTR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|0, false,ColMajor>::run);
|
||||
func[TR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|0, false,RowMajor>::run);
|
||||
func[ADJ | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|0, Conj, RowMajor>::run);
|
||||
|
||||
func[NOTR | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|UnitDiag,false,ColMajor>::run);
|
||||
func[TR | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|UnitDiag,false,RowMajor>::run);
|
||||
func[ADJ | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|UnitDiag,Conj, RowMajor>::run);
|
||||
|
||||
func[NOTR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|UnitDiag,false,ColMajor>::run);
|
||||
func[TR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|UnitDiag,false,RowMajor>::run);
|
||||
func[ADJ | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|UnitDiag,Conj, RowMajor>::run);
|
||||
|
||||
init = true;
|
||||
}
|
||||
@ -106,11 +130,23 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar
|
||||
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
||||
Scalar* b = reinterpret_cast<Scalar*>(pb);
|
||||
|
||||
int code = OP(*opa) | (UPLO(*uplo) << 2) | (DIAG(*diag) << 3);
|
||||
if(code>=16 || func[code]==0)
|
||||
return 0;
|
||||
int info = 0;
|
||||
if(UPLO(*uplo)==INVALID) info = 1;
|
||||
else if(OP(*opa)==INVALID) info = 2;
|
||||
else if(DIAG(*diag)==INVALID) info = 3;
|
||||
else if(*n<0) info = 4;
|
||||
else if(*lda<std::max(1,*n)) info = 6;
|
||||
else if(*incb==0) info = 8;
|
||||
if(info)
|
||||
return xerbla_(SCALAR_SUFFIX_UP"TRSV ",&info,6);
|
||||
|
||||
MAKE_ACTUAL_VECTOR(b,*incb,*n,*incb!=1)
|
||||
|
||||
int code = OP(*opa) | (UPLO(*uplo) << 2) | (DIAG(*diag) << 3);
|
||||
func[code](*n, a, *lda, actual_b);
|
||||
|
||||
RELEASE_ACTUAL_VECTOR(b,*incb,*n,*incb!=1)
|
||||
|
||||
func[code](*n, a, *lda, b, *incb);
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user