From 0e30c4ae3f35523302f1449f67a2be714e30beb8 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Fri, 5 Nov 2010 14:14:50 +0100 Subject: [PATCH] blas level2: gemv and trsv are green --- .../Core/products/TriangularSolverVector.h | 6 +- blas/level2_impl.h | 130 +++++++++++------- 2 files changed, 86 insertions(+), 50 deletions(-) diff --git a/Eigen/src/Core/products/TriangularSolverVector.h b/Eigen/src/Core/products/TriangularSolverVector.h index 25e739178..db1e7f029 100644 --- a/Eigen/src/Core/products/TriangularSolverVector.h +++ b/Eigen/src/Core/products/TriangularSolverVector.h @@ -30,7 +30,7 @@ namespace internal { template struct triangular_solve_vector { - static void run(int size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs) + static void run(Index size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs) { triangular_solve_vector, 0, OuterStride<> > LhsMap; const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride)); @@ -100,7 +100,7 @@ struct triangular_solve_vector, 0, OuterStride<> > LhsMap; const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride)); diff --git a/blas/level2_impl.h b/blas/level2_impl.h index 2749cf5b3..55851ddb3 100644 --- a/blas/level2_impl.h +++ b/blas/level2_impl.h @@ -24,8 +24,39 @@ #include "common.h" +#define MAKE_ACTUAL_VECTOR(X,INCX,N,COND) \ + Scalar* actual_##X = X; \ + if(COND) { \ + actual_##X = new Scalar[N]; \ + if((INCX)<0) vector(actual_##X,(N)) = vector(X,(N),-(INCX)).reverse(); \ + else vector(actual_##X,(N)) = vector(X,(N), (INCX)); \ + } + +#define RELEASE_ACTUAL_VECTOR(X,INCX,N,COND) \ + if(COND) { \ + if((INCX)<0) vector(X,(N),-(INCX)).reverse() = vector(actual_##X,(N)); \ + else vector(X,(N), (INCX)) = vector(actual_##X,(N)); \ + delete[] actual_##X; \ + } + int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *incb, RealScalar *pbeta, RealScalar *pc, int *incc) { + typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int , Scalar *, int, Scalar); + static functype func[4]; + + static bool init = false; + if(!init) + { + for(int k=0; k<4; ++k) + func[k] = 0; + + func[NOTR] = (internal::general_matrix_vector_product::run); + func[TR ] = (internal::general_matrix_vector_product::run); + func[ADJ ] = (internal::general_matrix_vector_product::run); + + init = true; + } + Scalar* a = reinterpret_cast(pa); Scalar* b = reinterpret_cast(pb); Scalar* c = reinterpret_cast(pc); @@ -34,9 +65,7 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca // check arguments int info = 0; - if( OP(*opa)!=NOTR - && OP(*opa)!=TR - && OP(*opa)!=ADJ) info = 1; + if(OP(*opa)==INVALID) info = 1; else if(*m<0) info = 2; else if(*n<0) info = 3; else if(*lda::run); -// func[TR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector::run); -// func[ADJ | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector::run); -// -// func[NOTR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector::run); -// func[TR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector::run); -// func[ADJ | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector::run); -// -// func[NOTR | (UP << 3) | (UNIT << 3)] = (internal::triangular_solve_vector::run); -// func[TR | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector::run); -// func[ADJ | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector::run); -// -// func[NOTR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector::run); -// func[TR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector::run); -// func[ADJ | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector::run); + func[NOTR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector::run); + func[TR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector::run); + func[ADJ | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector::run); + + func[NOTR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector::run); + func[TR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector::run); + func[ADJ | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector::run); + + func[NOTR | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector::run); + func[TR | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector::run); + func[ADJ | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector::run); + + func[NOTR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector::run); + func[TR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector::run); + func[ADJ | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector::run); init = true; } @@ -106,11 +130,23 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar Scalar* a = reinterpret_cast(pa); Scalar* b = reinterpret_cast(pb); - int code = OP(*opa) | (UPLO(*uplo) << 2) | (DIAG(*diag) << 3); - if(code>=16 || func[code]==0) - return 0; + int info = 0; + if(UPLO(*uplo)==INVALID) info = 1; + else if(OP(*opa)==INVALID) info = 2; + else if(DIAG(*diag)==INVALID) info = 3; + else if(*n<0) info = 4; + else if(*lda