blas interface: fix compilation, fix GEMM, SYMM, TRMM, and TRSM,

i,e., they all pass the blas test suite. More to come
This commit is contained in:
Gael Guennebaud 2010-03-01 19:06:07 +01:00
parent a1ac56a7c7
commit a7b9250ad0
4 changed files with 159 additions and 113 deletions

View File

@ -4,7 +4,8 @@ add_custom_target(blas)
set(EigenBlas_SRCS single.cpp double.cpp complex_single.cpp complex_double.cpp)
add_library(eigen_blas SHARED ${EigenBlas_SRCS})
add_library(eigen_blas ${EigenBlas_SRCS})
# add_library(eigen_blas SHARED ${EigenBlas_SRCS})
add_dependencies(blas eigen_blas)
install(TARGETS eigen_blas

View File

@ -25,6 +25,8 @@
#ifndef EIGEN_BLAS_COMMON_H
#define EIGEN_BLAS_COMMON_H
#include <iostream>
#ifndef SCALAR
#error the token SCALAR must be defined to compile this file
#endif
@ -34,13 +36,12 @@ extern "C"
{
#endif
#include <blas.h>
#include "../bench/btl/libs/C_BLAS/blas.h"
#ifdef __cplusplus
}
#endif
#define NOTR 0
#define TR 1
#define ADJ 2
@ -75,27 +76,6 @@ extern "C"
#include <Eigen/Jacobi>
using namespace Eigen;
template<typename T>
Block<Map<Matrix<T,Dynamic,Dynamic> >, Dynamic, Dynamic>
matrix(T* data, int rows, int cols, int stride)
{
return Map<Matrix<T,Dynamic,Dynamic> >(data, stride, cols).block(0,0,rows,cols);
}
template<typename T>
Block<Map<Matrix<T,Dynamic,Dynamic,RowMajor> >, Dynamic, 1>
vector(T* data, int size, int incr)
{
return Map<Matrix<T,Dynamic,Dynamic,RowMajor> >(data, size, incr).col(0);
}
template<typename T>
Map<Matrix<T,Dynamic,1> >
vector(T* data, int size)
{
return Map<Matrix<T,Dynamic,1> >(data, size);
}
typedef SCALAR Scalar;
typedef NumTraits<Scalar>::Real RealScalar;
typedef std::complex<RealScalar> Complex;
@ -106,10 +86,29 @@ enum
Conj = IsComplex
};
typedef Block<Map<Matrix<Scalar,Dynamic,Dynamic> >, Dynamic, Dynamic> MatrixType;
typedef Block<Map<Matrix<Scalar,Dynamic,Dynamic, RowMajor> >, Dynamic, 1> StridedVectorType;
typedef Map<Matrix<Scalar,Dynamic,Dynamic>, 0, OuterStride<Dynamic> > MatrixType;
typedef Map<Matrix<Scalar,Dynamic,1>, 0, InnerStride<Dynamic> > StridedVectorType;
typedef Map<Matrix<Scalar,Dynamic,1> > CompactVectorType;
template<typename T>
Map<Matrix<T,Dynamic,Dynamic>, 0, OuterStride<Dynamic> >
matrix(T* data, int rows, int cols, int stride)
{
return Map<Matrix<T,Dynamic,Dynamic>, 0, OuterStride<Dynamic> >(data, rows, cols, OuterStride<Dynamic>(stride));
}
template<typename T>
Map<Matrix<T,Dynamic,1>, 0, InnerStride<Dynamic> > vector(T* data, int size, int incr)
{
return Map<Matrix<T,Dynamic,1>, 0, InnerStride<Dynamic> >(data, size, InnerStride<Dynamic>(incr));
}
template<typename T>
Map<Matrix<T,Dynamic,1> > vector(T* data, int size)
{
return Map<Matrix<T,Dynamic,1> >(data, size);
}
#define EIGEN_BLAS_FUNC(X) EIGEN_CAT(SCALAR_SUFFIX,X##_)
#endif // EIGEN_BLAS_COMMON_H

View File

@ -45,9 +45,9 @@ RealScalar EIGEN_BLAS_FUNC(asum)(int *n, RealScalar *px, int *incx)
int size = IsComplex ? 2* *n : *n;
if(*incx==1)
return vector(px,size).cwise().abs().sum();
return vector(px,size).cwiseAbs().sum();
else
return vector(px,size,*incx).cwise().abs().sum();
return vector(px,size,*incx).cwiseAbs().sum();
return 1;
}
@ -71,9 +71,9 @@ Scalar EIGEN_BLAS_FUNC(dot)(int *n, RealScalar *px, int *incx, RealScalar *py, i
Scalar* y = reinterpret_cast<Scalar*>(py);
if(*incx==1 && *incy==1)
return (vector(x,*n).cwise()*vector(y,*n)).sum();
return (vector(x,*n).cwiseProduct(vector(y,*n))).sum();
return (vector(x,*n,*incx).cwise()*vector(y,*n,*incy)).sum();
return (vector(x,*n,*incx).cwiseProduct(vector(y,*n,*incy))).sum();
}
/*
@ -114,9 +114,9 @@ Scalar EIGEN_BLAS_FUNC(dotu)(int *n, RealScalar *px, int *incx, RealScalar *py,
Scalar* y = reinterpret_cast<Scalar*>(py);
if(*incx==1 && *incy==1)
return (vector(x,*n).cwise()*vector(y,*n)).sum();
return (vector(x,*n).cwiseProduct(vector(y,*n))).sum();
return (vector(x,*n,*incx).cwise()*vector(y,*n,*incy)).sum();
return (vector(x,*n,*incx).cwiseProduct(vector(y,*n,*incy))).sum();
}
#endif // ISCOMPLEX
@ -215,9 +215,9 @@ RealScalar EIGEN_BLAS_FUNC(casum)(int *n, RealScalar *px, int *incx)
Complex* x = reinterpret_cast<Complex*>(px);
if(*incx==1)
return vector(x,*n).cwise().abs().sum();
return vector(x,*n).cwiseAbs().sum();
else
return vector(x,*n,*incx).cwise().abs().sum();
return vector(x,*n,*incx).cwiseAbs().sum();
return 1;
}

View File

@ -26,8 +26,9 @@
int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
{
// std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n";
typedef void (*functype)(int, int, int, const Scalar *, int, const Scalar *, int, Scalar *, int, Scalar);
functype func[12];
static functype func[12];
static bool init = false;
if(!init)
@ -52,21 +53,29 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
if(beta!=Scalar(1))
matrix(c, *m, *n, *ldc) *= beta;
int code = OP(*opa) | (OP(*opb) << 2);
if(code>=12 || func[code]==0)
if(code>=12 || func[code]==0 || (*m<0) || (*n<0) || (*k<0))
{
int info = 1;
xerbla_("GEMM", &info, 4);
return 0;
}
if(beta!=Scalar(1))
if(beta==Scalar(0))
matrix(c, *m, *n, *ldc).setZero();
else
matrix(c, *m, *n, *ldc) *= beta;
func[code](*m, *n, *k, a, *lda, b, *ldb, c, *ldc, alpha);
return 1;
return 0;
}
int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb)
{
// std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " " << *palpha << " " << *lda << " " << *ldb<< "\n";
typedef void (*functype)(int, int, const Scalar *, int, Scalar *, int);
functype func[32];
static functype func[32];
static bool init = false;
if(!init)
@ -74,38 +83,38 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m,
for(int k=0; k<32; ++k)
func[k] = 0;
func[NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, UpperTriangular|0, false,ColMajor,ColMajor>::run);
func[TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, UpperTriangular|0, false,RowMajor,ColMajor>::run);
func[ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, UpperTriangular|0, Conj, RowMajor,ColMajor>::run);
func[NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, Upper|0, false,ColMajor,ColMajor>::run);
func[TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, Lower|0, false,RowMajor,ColMajor>::run);
func[ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, Lower|0, Conj, RowMajor,ColMajor>::run);
func[NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,UpperTriangular|0, false,ColMajor,ColMajor>::run);
func[TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,UpperTriangular|0, false,RowMajor,ColMajor>::run);
func[ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,UpperTriangular|0, Conj, RowMajor,ColMajor>::run);
func[NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,Upper|0, false,ColMajor,ColMajor>::run);
func[TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,Lower|0, false,RowMajor,ColMajor>::run);
func[ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,Lower|0, Conj, RowMajor,ColMajor>::run);
func[NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, LowerTriangular|0, false,ColMajor,ColMajor>::run);
func[TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, LowerTriangular|0, false,RowMajor,ColMajor>::run);
func[ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, LowerTriangular|0, Conj, RowMajor,ColMajor>::run);
func[NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, Lower|0, false,ColMajor,ColMajor>::run);
func[TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, Upper|0, false,RowMajor,ColMajor>::run);
func[ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, Upper|0, Conj, RowMajor,ColMajor>::run);
func[NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,LowerTriangular|0, false,ColMajor,ColMajor>::run);
func[TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,LowerTriangular|0, false,RowMajor,ColMajor>::run);
func[ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,LowerTriangular|0, Conj, RowMajor,ColMajor>::run);
func[NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,Lower|0, false,ColMajor,ColMajor>::run);
func[TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,Upper|0, false,RowMajor,ColMajor>::run);
func[ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,Upper|0, Conj, RowMajor,ColMajor>::run);
func[NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, UpperTriangular|UnitDiagBit,false,ColMajor,ColMajor>::run);
func[TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, UpperTriangular|UnitDiagBit,false,RowMajor,ColMajor>::run);
func[ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, UpperTriangular|UnitDiagBit,Conj, RowMajor,ColMajor>::run);
func[NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, Upper|UnitDiag,false,ColMajor,ColMajor>::run);
func[TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, Lower|UnitDiag,false,RowMajor,ColMajor>::run);
func[ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, Lower|UnitDiag,Conj, RowMajor,ColMajor>::run);
func[NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,UpperTriangular|UnitDiagBit,false,ColMajor,ColMajor>::run);
func[TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,UpperTriangular|UnitDiagBit,false,RowMajor,ColMajor>::run);
func[ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,UpperTriangular|UnitDiagBit,Conj, RowMajor,ColMajor>::run);
func[NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,Upper|UnitDiag,false,ColMajor,ColMajor>::run);
func[TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,Lower|UnitDiag,false,RowMajor,ColMajor>::run);
func[ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,Lower|UnitDiag,Conj, RowMajor,ColMajor>::run);
func[NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, LowerTriangular|UnitDiagBit,false,ColMajor,ColMajor>::run);
func[TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, LowerTriangular|UnitDiagBit,false,RowMajor,ColMajor>::run);
func[ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, LowerTriangular|UnitDiagBit,Conj, RowMajor,ColMajor>::run);
func[NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, Lower|UnitDiag,false,ColMajor,ColMajor>::run);
func[TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, Upper|UnitDiag,false,RowMajor,ColMajor>::run);
func[ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheLeft, Upper|UnitDiag,Conj, RowMajor,ColMajor>::run);
func[NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,LowerTriangular|UnitDiagBit,false,ColMajor,ColMajor>::run);
func[TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,LowerTriangular|UnitDiagBit,false,RowMajor,ColMajor>::run);
func[ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,LowerTriangular|UnitDiagBit,Conj, RowMajor,ColMajor>::run);
func[NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,Lower|UnitDiag,false,ColMajor,ColMajor>::run);
func[TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,Upper|UnitDiag,false,RowMajor,ColMajor>::run);
func[ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix<Scalar,OnTheRight,Upper|UnitDiag,Conj, RowMajor,ColMajor>::run);
init = true;
}
@ -114,14 +123,23 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m,
Scalar* b = reinterpret_cast<Scalar*>(pb);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
// TODO handle alpha
int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4);
if(code>=32 || func[code]==0)
if(code>=32 || func[code]==0 || *m<0 || *n <0)
{
int info=1;
xerbla_("TRSM",&info,4);
return 0;
}
func[code](*m, *n, a, *lda, b, *ldb);
return 1;
if(SIDE(*side)==LEFT)
func[code](*m, *n, a, *lda, b, *ldb);
else
func[code](*n, *m, a, *lda, b, *ldb);
if(alpha!=Scalar(1))
matrix(b,*m,*n,*ldb) *= alpha;
return 0;
}
@ -129,46 +147,46 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m,
// b = alpha*b*op(a) for side = 'R'or'r'
int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb)
{
// std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " " << *lda << " " << *ldb << " " << *palpha << "\n";
typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int, Scalar *, int, Scalar);
functype func[32];
static functype func[32];
static bool init = false;
if(!init)
{
for(int k=0; k<32; ++k)
func[k] = 0;
func[NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,UpperTriangular|0, true, ColMajor,false,ColMajor,false,ColMajor>::run);
func[TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,UpperTriangular|0, true, RowMajor,false,ColMajor,false,ColMajor>::run);
func[ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,UpperTriangular|0, true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
func[NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Upper|0, true, ColMajor,false,ColMajor,false,ColMajor>::run);
func[TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Lower|0, true, RowMajor,false,ColMajor,false,ColMajor>::run);
func[ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Lower|0, true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
func[NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,UpperTriangular|0, false,ColMajor,false,ColMajor,false,ColMajor>::run);
func[TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,UpperTriangular|0, false,ColMajor,false,RowMajor,false,ColMajor>::run);
func[ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,UpperTriangular|0, false,ColMajor,false,RowMajor,Conj, ColMajor>::run);
func[NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Upper|0, false,ColMajor,false,ColMajor,false,ColMajor>::run);
func[TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Lower|0, false,ColMajor,false,RowMajor,false,ColMajor>::run);
func[ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Lower|0, false,ColMajor,false,RowMajor,Conj, ColMajor>::run);
func[NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,LowerTriangular|0, true, ColMajor,false,ColMajor,false,ColMajor>::run);
func[TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,LowerTriangular|0, true, RowMajor,false,ColMajor,false,ColMajor>::run);
func[ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,LowerTriangular|0, true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
func[NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Lower|0, true, ColMajor,false,ColMajor,false,ColMajor>::run);
func[TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Upper|0, true, RowMajor,false,ColMajor,false,ColMajor>::run);
func[ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Upper|0, true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
func[NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,LowerTriangular|0, false,ColMajor,false,ColMajor,false,ColMajor>::run);
func[TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,LowerTriangular|0, false,ColMajor,false,RowMajor,false,ColMajor>::run);
func[ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,LowerTriangular|0, false,ColMajor,false,RowMajor,Conj, ColMajor>::run);
func[NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Lower|0, false,ColMajor,false,ColMajor,false,ColMajor>::run);
func[TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Upper|0, false,ColMajor,false,RowMajor,false,ColMajor>::run);
func[ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Upper|0, false,ColMajor,false,RowMajor,Conj, ColMajor>::run);
func[NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,UpperTriangular|UnitDiagBit,true, ColMajor,false,ColMajor,false,ColMajor>::run);
func[TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,UpperTriangular|UnitDiagBit,true, RowMajor,false,ColMajor,false,ColMajor>::run);
func[ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,UpperTriangular|UnitDiagBit,true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
func[NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Upper|UnitDiag,true, ColMajor,false,ColMajor,false,ColMajor>::run);
func[TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Lower|UnitDiag,true, RowMajor,false,ColMajor,false,ColMajor>::run);
func[ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Lower|UnitDiag,true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
func[NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,UpperTriangular|UnitDiagBit,false,ColMajor,false,ColMajor,false,ColMajor>::run);
func[TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,UpperTriangular|UnitDiagBit,false,ColMajor,false,RowMajor,false,ColMajor>::run);
func[ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,UpperTriangular|UnitDiagBit,false,ColMajor,false,RowMajor,Conj, ColMajor>::run);
func[NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Upper|UnitDiag,false,ColMajor,false,ColMajor,false,ColMajor>::run);
func[TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Lower|UnitDiag,false,ColMajor,false,RowMajor,false,ColMajor>::run);
func[ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Lower|UnitDiag,false,ColMajor,false,RowMajor,Conj, ColMajor>::run);
func[NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,LowerTriangular|UnitDiagBit,true, ColMajor,false,ColMajor,false,ColMajor>::run);
func[TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,LowerTriangular|UnitDiagBit,true, RowMajor,false,ColMajor,false,ColMajor>::run);
func[ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,LowerTriangular|UnitDiagBit,true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
func[NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Lower|UnitDiag,true, ColMajor,false,ColMajor,false,ColMajor>::run);
func[TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Upper|UnitDiag,true, RowMajor,false,ColMajor,false,ColMajor>::run);
func[ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Upper|UnitDiag,true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
func[NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,LowerTriangular|UnitDiagBit,false,ColMajor,false,ColMajor,false,ColMajor>::run);
func[TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,LowerTriangular|UnitDiagBit,false,ColMajor,false,RowMajor,false,ColMajor>::run);
func[ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,LowerTriangular|UnitDiagBit,false,ColMajor,false,RowMajor,Conj, ColMajor>::run);
func[NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Lower|UnitDiag,false,ColMajor,false,ColMajor,false,ColMajor>::run);
func[TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Upper|UnitDiag,false,ColMajor,false,RowMajor,false,ColMajor>::run);
func[ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix<Scalar,Upper|UnitDiag,false,ColMajor,false,RowMajor,Conj, ColMajor>::run);
init = true;
}
@ -178,10 +196,21 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m,
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4);
if(code>=32 || func[code]==0)
if(code>=32 || func[code]==0 || *m<0 || *n <0)
{
int info=1;
xerbla_("TRMM",&info,4);
return 0;
}
func[code](*m, *n, a, *lda, b, *ldb, b, *ldb, alpha);
// FIXME find a way to avoid this copy
Matrix<Scalar,Dynamic,Dynamic> tmp = matrix(b,*m,*n,*ldb);
matrix(b,*m,*n,*ldb).setZero();
if(SIDE(*side)==LEFT)
func[code](*m, *n, a, *lda, tmp.data(), tmp.outerStride(), b, *ldb, alpha);
else
func[code](*n, *m, tmp.data(), tmp.outerStride(), a, *lda, b, *ldb, alpha);
return 1;
}
@ -189,14 +218,26 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m,
// c = alpha*b*a + beta*c for side = 'R'or'r
int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
{
// std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << " "
// << pa << " " << pb << " " << pc << "\n";
Scalar* a = reinterpret_cast<Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb);
Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
if(*m<0 || *n<0)
{
int info=1;
xerbla_("SYMM",&info,4);
return 0;
}
if(beta!=Scalar(1))
matrix(c, *m, *n, *ldc) *= beta;
if(beta==Scalar(0))
matrix(c, *m, *n, *ldc).setZero();
else
matrix(c, *m, *n, *ldc) *= beta;
if(SIDE(*side)==LEFT)
if(UPLO(*uplo)==UP)
@ -215,15 +256,16 @@ int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *pa
else
return 0;
return 1;
return 0;
}
// c = alpha*a*a' + beta*c for op = 'N'or'n'
// c = alpha*a'*a + beta*c for op = 'T'or't','C'or'c'
int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc)
{
// std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << "\n";
typedef void (*functype)(int, int, const Scalar *, int, Scalar *, int, Scalar);
functype func[8];
static functype func[8];
static bool init = false;
if(!init)
@ -231,13 +273,13 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
for(int k=0; k<8; ++k)
func[k] = 0;
func[NOTR | (UP << 2)] = (ei_selfadjoint_product<Scalar,ColMajor,ColMajor,true, UpperTriangular>::run);
func[TR | (UP << 2)] = (ei_selfadjoint_product<Scalar,RowMajor,ColMajor,false,UpperTriangular>::run);
func[ADJ | (UP << 2)] = (ei_selfadjoint_product<Scalar,RowMajor,ColMajor,false,UpperTriangular>::run);
func[NOTR | (UP << 2)] = (ei_selfadjoint_product<Scalar,ColMajor,ColMajor,true, Upper>::run);
func[TR | (UP << 2)] = (ei_selfadjoint_product<Scalar,RowMajor,ColMajor,false,Upper>::run);
func[ADJ | (UP << 2)] = (ei_selfadjoint_product<Scalar,RowMajor,ColMajor,false,Upper>::run);
func[NOTR | (LO << 2)] = (ei_selfadjoint_product<Scalar,ColMajor,ColMajor,true, LowerTriangular>::run);
func[TR | (LO << 2)] = (ei_selfadjoint_product<Scalar,RowMajor,ColMajor,false,LowerTriangular>::run);
func[ADJ | (LO << 2)] = (ei_selfadjoint_product<Scalar,RowMajor,ColMajor,false,LowerTriangular>::run);
func[NOTR | (LO << 2)] = (ei_selfadjoint_product<Scalar,ColMajor,ColMajor,true, Lower>::run);
func[TR | (LO << 2)] = (ei_selfadjoint_product<Scalar,RowMajor,ColMajor,false,Lower>::run);
func[ADJ | (LO << 2)] = (ei_selfadjoint_product<Scalar,RowMajor,ColMajor,false,Lower>::run);
init = true;
}
@ -248,8 +290,12 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
int code = OP(*op) | (UPLO(*uplo) << 2);
if(code>=8 || func[code]==0)
if(code>=8 || func[code]==0 || *n<0 || *k<0)
{
int info=1;
xerbla_("SYRK",&info,4);
return 0;
}
if(beta!=Scalar(1))
matrix(c, *n, *n, *ldc) *= beta;
@ -314,7 +360,7 @@ int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *pa
int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc)
{
typedef void (*functype)(int, int, const Scalar *, int, Scalar *, int, Scalar);
functype func[8];
static functype func[8];
static bool init = false;
if(!init)
@ -322,11 +368,11 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
for(int k=0; k<8; ++k)
func[k] = 0;
func[NOTR | (UP << 2)] = (ei_selfadjoint_product<Scalar,ColMajor,ColMajor,true, UpperTriangular>::run);
func[ADJ | (UP << 2)] = (ei_selfadjoint_product<Scalar,RowMajor,ColMajor,false,UpperTriangular>::run);
func[NOTR | (UP << 2)] = (ei_selfadjoint_product<Scalar,ColMajor,ColMajor,true, Upper>::run);
func[ADJ | (UP << 2)] = (ei_selfadjoint_product<Scalar,RowMajor,ColMajor,false,Upper>::run);
func[NOTR | (LO << 2)] = (ei_selfadjoint_product<Scalar,ColMajor,ColMajor,true, LowerTriangular>::run);
func[ADJ | (LO << 2)] = (ei_selfadjoint_product<Scalar,RowMajor,ColMajor,false,LowerTriangular>::run);
func[NOTR | (LO << 2)] = (ei_selfadjoint_product<Scalar,ColMajor,ColMajor,true, Lower>::run);
func[ADJ | (LO << 2)] = (ei_selfadjoint_product<Scalar,RowMajor,ColMajor,false,Lower>::run);
init = true;
}