Re-enable A^T*A action in BTL

This commit is contained in:
Gael Guennebaud 2016-12-02 11:32:03 +01:00
parent 8c24723a09
commit d2718d662c
8 changed files with 34 additions and 32 deletions

View File

@ -6,7 +6,7 @@
#include "action_atv_product.hh"
#include "action_matrix_matrix_product.hh"
// #include "action_ata_product.hh"
#include "action_ata_product.hh"
#include "action_aat_product.hh"
#include "action_trisolve.hh"

View File

@ -46,9 +46,9 @@ public :
BLAS_FUNC(gemm)(&notrans,&notrans,&N,&N,&N,&fone,A,&N,B,&N,&fzero,X,&N);
}
// static inline void ata_product(gene_matrix & A, gene_matrix & X, int N){
// ssyrk_(&lower,&trans,&N,&N,&fone,A,&N,&fzero,X,&N);
// }
static inline void ata_product(gene_matrix & A, gene_matrix & X, int N){
BLAS_FUNC(syrk)(&lower,&trans,&N,&N,&fone,A,&N,&fzero,X,&N);
}
static inline void aat_product(gene_matrix & A, gene_matrix & X, int N){
BLAS_FUNC(syrk)(&lower,&notrans,&N,&N,&fone,A,&N,&fzero,X,&N);

View File

@ -48,7 +48,7 @@ int main()
bench<Action_rot<blas_interface<REAL_TYPE> > >(MIN_AXPY,MAX_AXPY,NB_POINT);
bench<Action_matrix_matrix_product<blas_interface<REAL_TYPE> > >(MIN_MM,MAX_MM,NB_POINT);
// bench<Action_ata_product<blas_interface<REAL_TYPE> > >(MIN_MM,MAX_MM,NB_POINT);
bench<Action_ata_product<blas_interface<REAL_TYPE> > >(MIN_MM,MAX_MM,NB_POINT);
bench<Action_aat_product<blas_interface<REAL_TYPE> > >(MIN_MM,MAX_MM,NB_POINT);
bench<Action_trisolve<blas_interface<REAL_TYPE> > >(MIN_MM,MAX_MM,NB_POINT);

View File

@ -78,18 +78,18 @@ public :
cible[i][j]=source[i][j];
}
// static inline void ata_product(const gene_matrix & A, gene_matrix & X, int N)
// {
// real somme;
// for (int j=0;j<N;j++){
// for (int i=0;i<N;i++){
// somme=0.0;
// for (int k=0;k<N;k++)
// somme += A[i][k]*A[j][k];
// X[j][i]=somme;
// }
// }
// }
static inline void ata_product(const gene_matrix & A, gene_matrix & X, int N)
{
real somme;
for (int j=0;j<N;j++){
for (int i=0;i<N;i++){
somme=0.0;
for (int k=0;k<N;k++)
somme += A[i][k]*A[j][k];
X[j][i]=somme;
}
}
}
static inline void aat_product(const gene_matrix & A, gene_matrix & X, int N)
{

View File

@ -80,35 +80,35 @@ public :
}
}
static inline void matrix_matrix_product(const gene_matrix & A, const gene_matrix & B, gene_matrix & X, int N){
static EIGEN_DONT_INLINE void matrix_matrix_product(const gene_matrix & A, const gene_matrix & B, gene_matrix & X, int N){
X = (A*B);
}
static inline void transposed_matrix_matrix_product(const gene_matrix & A, const gene_matrix & B, gene_matrix & X, int N){
static EIGEN_DONT_INLINE void transposed_matrix_matrix_product(const gene_matrix & A, const gene_matrix & B, gene_matrix & X, int N){
X = (trans(A)*trans(B));
}
static inline void ata_product(const gene_matrix & A, gene_matrix & X, int N){
static EIGEN_DONT_INLINE void ata_product(const gene_matrix & A, gene_matrix & X, int N){
X = (trans(A)*A);
}
static inline void aat_product(const gene_matrix & A, gene_matrix & X, int N){
static EIGEN_DONT_INLINE void aat_product(const gene_matrix & A, gene_matrix & X, int N){
X = (A*trans(A));
}
static inline void matrix_vector_product(gene_matrix & A, gene_vector & B, gene_vector & X, int N){
static EIGEN_DONT_INLINE void matrix_vector_product(gene_matrix & A, gene_vector & B, gene_vector & X, int N){
X = (A*B);
}
static inline void atv_product(gene_matrix & A, gene_vector & B, gene_vector & X, int N){
static EIGEN_DONT_INLINE void atv_product(gene_matrix & A, gene_vector & B, gene_vector & X, int N){
X = (trans(A)*B);
}
static inline void axpy(const real coef, const gene_vector & X, gene_vector & Y, int N){
static EIGEN_DONT_INLINE void axpy(const real coef, const gene_vector & X, gene_vector & Y, int N){
Y += coef * X;
}
static inline void axpby(real a, const gene_vector & X, real b, gene_vector & Y, int N){
static EIGEN_DONT_INLINE void axpby(real a, const gene_vector & X, real b, gene_vector & Y, int N){
Y = a*X + b*Y;
}

View File

@ -30,9 +30,9 @@ int main()
bench<Action_matrix_vector_product<blaze_interface<REAL_TYPE> > >(MIN_MV,MAX_MV,NB_POINT);
bench<Action_atv_product<blaze_interface<REAL_TYPE> > >(MIN_MV,MAX_MV,NB_POINT);
// bench<Action_matrix_matrix_product<blaze_interface<REAL_TYPE> > >(MIN_MM,MAX_MM,NB_POINT);
// bench<Action_ata_product<blaze_interface<REAL_TYPE> > >(MIN_MM,MAX_MM,NB_POINT);
// bench<Action_aat_product<blaze_interface<REAL_TYPE> > >(MIN_MM,MAX_MM,NB_POINT);
bench<Action_matrix_matrix_product<blaze_interface<REAL_TYPE> > >(MIN_MM,MAX_MM,NB_POINT);
bench<Action_ata_product<blaze_interface<REAL_TYPE> > >(MIN_MM,MAX_MM,NB_POINT);
bench<Action_aat_product<blaze_interface<REAL_TYPE> > >(MIN_MM,MAX_MM,NB_POINT);
return 0;
}

View File

@ -92,9 +92,11 @@ public :
X.noalias() = A.transpose()*B.transpose();
}
// static inline void ata_product(const gene_matrix & A, gene_matrix & X, int /*N*/){
static inline void ata_product(const gene_matrix & A, gene_matrix & X, int /*N*/){
//X.noalias() = A.transpose()*A;
// }
X.template triangularView<Lower>().setZero();
X.template selfadjointView<Lower>().rankUpdate(A.transpose());
}
static inline void aat_product(const gene_matrix & A, gene_matrix & X, int /*N*/){
X.template triangularView<Lower>().setZero();

View File

@ -25,7 +25,7 @@ BTL_MAIN;
int main()
{
bench<Action_matrix_matrix_product<eigen3_interface<REAL_TYPE> > >(MIN_MM,MAX_MM,NB_POINT);
// bench<Action_ata_product<eigen3_interface<REAL_TYPE> > >(MIN_MM,MAX_MM,NB_POINT);
bench<Action_ata_product<eigen3_interface<REAL_TYPE> > >(MIN_MM,MAX_MM,NB_POINT);
bench<Action_aat_product<eigen3_interface<REAL_TYPE> > >(MIN_MM,MAX_MM,NB_POINT);
bench<Action_trmm<eigen3_interface<REAL_TYPE> > >(MIN_MM,MAX_MM,NB_POINT);