Fix degenerate cases in syrk and trsm

This commit is contained in:
Gael Guennebaud 2015-11-30 22:20:31 +01:00
parent e7a1c48185
commit 1d906d883d
2 changed files with 13 additions and 3 deletions

View File

@ -203,8 +203,6 @@ void evaluateProductBlockingSizesHeuristic(Index& k, Index& m, Index& n, Index n
const Index actual_l2 = 1572864; // == 1.5 MB
#endif
// Here, nc is chosen such that a block of kc x nc of the rhs fit within half of L2.
// The second half is implicitly reserved to access the result and lhs coefficients.
// When k<max_kc, then nc can arbitrarily growth. In practice, it seems to be fruitful

View File

@ -6,7 +6,7 @@
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#include <iostream>
#include "common.h"
int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
@ -133,6 +133,9 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m,
if(info)
return xerbla_(SCALAR_SUFFIX_UP"TRSM ",&info,6);
if(*m==0 || *n==0)
return 0;
int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4);
if(SIDE(*side)==LEFT)
@ -358,6 +361,9 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta;
}
if(*n==0 || *k==0)
return 0;
#if ISCOMPLEX
// FIXME add support for symmetric complex matrix
if(UPLO(*uplo)==UP)
@ -392,6 +398,8 @@ int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
// std::cerr << "in syr2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << *ldb << " " << beta << " " << *ldc << "\n";
int info = 0;
if(UPLO(*uplo)==INVALID) info = 1;
else if(OP(*op)==INVALID) info = 2;
@ -506,6 +514,8 @@ int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *pa
// c = alpha*conj(a')*a + beta*c for op = 'C'or'c'
int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc)
{
// std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n";
typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, const Scalar&);
static functype func[8];
@ -577,6 +587,8 @@ int EIGEN_BLAS_FUNC(her2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
RealScalar beta = *pbeta;
// std::cerr << "in her2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << *ldb << " " << beta << " " << *ldc << "\n";
int info = 0;
if(UPLO(*uplo)==INVALID) info = 1;
else if((OP(*op)==INVALID) || (OP(*op)==TR)) info = 2;