blas level2: gemv and trsv are green

This commit is contained in:
Gael Guennebaud 2010-11-05 14:14:50 +01:00
parent 3fdea699b8
commit 0e30c4ae3f
2 changed files with 86 additions and 50 deletions

View File

@ -30,7 +30,7 @@ namespace internal {
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder>
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheRight, Mode, Conjugate, StorageOrder>
{
static void run(int size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
static void run(Index size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
{
triangular_solve_vector<LhsScalar,RhsScalar,Index,OnTheLeft,
((Mode&Upper)==Upper ? Lower : Upper) | (Mode&UnitDiag),
@ -46,7 +46,7 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Con
enum {
IsLower = ((Mode&Lower)==Lower)
};
static void run(int size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
static void run(Index size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
{
typedef Map<Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));
@ -100,7 +100,7 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Con
enum {
IsLower = ((Mode&Lower)==Lower)
};
static void run(int size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
static void run(Index size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
{
typedef Map<Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));

View File

@ -24,8 +24,39 @@
#include "common.h"
#define MAKE_ACTUAL_VECTOR(X,INCX,N,COND) \
Scalar* actual_##X = X; \
if(COND) { \
actual_##X = new Scalar[N]; \
if((INCX)<0) vector(actual_##X,(N)) = vector(X,(N),-(INCX)).reverse(); \
else vector(actual_##X,(N)) = vector(X,(N), (INCX)); \
}
#define RELEASE_ACTUAL_VECTOR(X,INCX,N,COND) \
if(COND) { \
if((INCX)<0) vector(X,(N),-(INCX)).reverse() = vector(actual_##X,(N)); \
else vector(X,(N), (INCX)) = vector(actual_##X,(N)); \
delete[] actual_##X; \
}
int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *incb, RealScalar *pbeta, RealScalar *pc, int *incc)
{
typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int , Scalar *, int, Scalar);
static functype func[4];
static bool init = false;
if(!init)
{
for(int k=0; k<4; ++k)
func[k] = 0;
func[NOTR] = (internal::general_matrix_vector_product<int,Scalar,ColMajor,false,Scalar,false>::run);
func[TR ] = (internal::general_matrix_vector_product<int,Scalar,RowMajor,false,Scalar,false>::run);
func[ADJ ] = (internal::general_matrix_vector_product<int,Scalar,RowMajor,Conj, Scalar,false>::run);
init = true;
}
Scalar* a = reinterpret_cast<Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb);
Scalar* c = reinterpret_cast<Scalar*>(pc);
@ -34,9 +65,7 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca
// check arguments
int info = 0;
if( OP(*opa)!=NOTR
&& OP(*opa)!=TR
&& OP(*opa)!=ADJ) info = 1;
if(OP(*opa)==INVALID) info = 1;
else if(*m<0) info = 2;
else if(*n<0) info = 3;
else if(*lda<std::max(1,*m)) info = 6;
@ -44,39 +73,34 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca
else if(*incc==0) info = 11;
if(info)
return xerbla_(SCALAR_SUFFIX_UP"GEMV ",&info,6);
// return xerbla_("SGEMV ",&info,sizeof("SGEMV "));
if(beta!=Scalar(1))
vector(c, *m, *incc) *= beta;
if(OP(*opa)==NOTR)
if(*incc==1)
vector(c,*m) += alpha * matrix(a,*m,*n,*lda) * vector(b,*n,*incb);
else
vector(c,*m,*incc) += alpha * matrix(a,*m,*n,*lda) * vector(b,*n,*incb);
else if(OP(*opa)==TR)
if(*incb==1)
vector(c,*m,*incc) += alpha * matrix(a,*n,*m,*lda).transpose() * vector(b,*n);
else
vector(c,*m,*incc) += alpha * matrix(a,*n,*m,*lda).transpose() * vector(b,*n,*incb);
else if(OP(*opa)==TR)
if(*incb==1)
vector(c,*m,*incc) += alpha * matrix(a,*n,*m,*lda).adjoint() * vector(b,*n);
else
vector(c,*m,*incc) += alpha * matrix(a,*n,*m,*lda).adjoint() * vector(b,*n,*incb);
else
if(*m==0 || *n==0)
return 0;
int actual_m = *m;
int actual_n = *n;
if(OP(*opa)!=NOTR)
std::swap(actual_m,actual_n);
MAKE_ACTUAL_VECTOR(b,*incb,actual_n,*incb!=1)
MAKE_ACTUAL_VECTOR(c,*incc,actual_m,*incc!=1)
if(beta!=Scalar(1))
vector(actual_c, actual_m, 1) *= beta;
int code = OP(*opa);
func[code](actual_m, actual_n, a, *lda, actual_b, 1, actual_c, 1, alpha);
RELEASE_ACTUAL_VECTOR(b,*incb,actual_n,*incb!=1)
RELEASE_ACTUAL_VECTOR(c,*incc,actual_m,*incc!=1)
return 1;
}
int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar *pa, int *lda, RealScalar *pb, int *incb)
{
return 0;
typedef void (*functype)(int, const Scalar *, int, Scalar *, int);
functype func[16];
typedef void (*functype)(int, const Scalar *, int, Scalar *);
static functype func[16];
static bool init = false;
if(!init)
@ -84,21 +108,21 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar
for(int k=0; k<16; ++k)
func[k] = 0;
// func[NOTR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|0, false,ColMajor,ColMajor>::run);
// func[TR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|0, false,RowMajor,ColMajor>::run);
// func[ADJ | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|0, Conj, RowMajor,ColMajor>::run);
//
// func[NOTR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|0, false,ColMajor,ColMajor>::run);
// func[TR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|0, false,RowMajor,ColMajor>::run);
// func[ADJ | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|0, Conj, RowMajor,ColMajor>::run);
//
// func[NOTR | (UP << 3) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|UnitDiagBit,false,ColMajor,ColMajor>::run);
// func[TR | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|UnitDiagBit,false,RowMajor,ColMajor>::run);
// func[ADJ | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|UnitDiagBit,Conj, RowMajor,ColMajor>::run);
//
// func[NOTR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|UnitDiagBit,false,ColMajor,ColMajor>::run);
// func[TR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|UnitDiagBit,false,RowMajor,ColMajor>::run);
// func[ADJ | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|UnitDiagBit,Conj, RowMajor,ColMajor>::run);
func[NOTR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|0, false,ColMajor>::run);
func[TR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|0, false,RowMajor>::run);
func[ADJ | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|0, Conj, RowMajor>::run);
func[NOTR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|0, false,ColMajor>::run);
func[TR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|0, false,RowMajor>::run);
func[ADJ | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|0, Conj, RowMajor>::run);
func[NOTR | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|UnitDiag,false,ColMajor>::run);
func[TR | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|UnitDiag,false,RowMajor>::run);
func[ADJ | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|UnitDiag,Conj, RowMajor>::run);
func[NOTR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|UnitDiag,false,ColMajor>::run);
func[TR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|UnitDiag,false,RowMajor>::run);
func[ADJ | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|UnitDiag,Conj, RowMajor>::run);
init = true;
}
@ -106,11 +130,23 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar
Scalar* a = reinterpret_cast<Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb);
int code = OP(*opa) | (UPLO(*uplo) << 2) | (DIAG(*diag) << 3);
if(code>=16 || func[code]==0)
return 0;
int info = 0;
if(UPLO(*uplo)==INVALID) info = 1;
else if(OP(*opa)==INVALID) info = 2;
else if(DIAG(*diag)==INVALID) info = 3;
else if(*n<0) info = 4;
else if(*lda<std::max(1,*n)) info = 6;
else if(*incb==0) info = 8;
if(info)
return xerbla_(SCALAR_SUFFIX_UP"TRSV ",&info,6);
func[code](*n, a, *lda, b, *incb);
MAKE_ACTUAL_VECTOR(b,*incb,*n,*incb!=1)
int code = OP(*opa) | (UPLO(*uplo) << 2) | (DIAG(*diag) << 3);
func[code](*n, a, *lda, actual_b);
RELEASE_ACTUAL_VECTOR(b,*incb,*n,*incb!=1)
return 0;
}