Create the ability to disable the specialized gemm_pack_rhs in Eigen (only PPC) for TensorFlow

This commit is contained in:
Chip Kerchner 2021-06-30 23:05:04 +00:00 committed by Rasmus Munk Larsen
parent 60400334a9
commit 91e99ec1e0

View File

@ -11,6 +11,10 @@
#ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_ALTIVEC_H
#ifndef EIGEN_ALTIVEC_USE_CUSTOM_PACK
#define EIGEN_ALTIVEC_USE_CUSTOM_PACK 1
#endif
#include "MatrixProductCommon.h"
// Since LLVM doesn't support dynamic dispatching, force either always MMA or VSX
@ -2423,6 +2427,7 @@ void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Co
pack(blockA, lhs, depth, rows, stride, offset);
}
#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
{
@ -2450,6 +2455,7 @@ void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode
dhs_pack<double, Index, DataMapper, Packet2d, RowMajor, PanelMode, false> pack;
pack(blockB, rhs, depth, cols, stride, offset);
}
#endif
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
@ -2478,6 +2484,7 @@ void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Con
dhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode, true> pack;
pack(blockA, lhs, depth, rows, stride, offset);
}
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
{
@ -2506,6 +2513,7 @@ void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet,
pack(blockA, lhs, depth, rows, stride, offset);
}
#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
{
@ -2533,6 +2541,7 @@ void gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
dhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode, false> pack;
pack(blockB, rhs, depth, cols, stride, offset);
}
#endif
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>