mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-27 07:29:52 +08:00
Merged in rmlarsen/eigen1 (pull request PR-441)
Reduce the number of template specializations of classes related to tensor contraction to reduce binary size.
This commit is contained in:
commit
34539c4af4
@ -177,9 +177,9 @@ struct NoOpOutputKernel {
|
||||
*/
|
||||
template <typename Index, typename Scalar>
|
||||
EIGEN_ALWAYS_INLINE void operator()(
|
||||
const OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
|
||||
const TensorContractionParams& params, Index i, Index j, Index num_rows,
|
||||
Index num_cols) const {}
|
||||
const OutputKernel::OutputMapper<Index, Scalar>& /*output_mapper*/,
|
||||
const TensorContractionParams& /*params*/, Index /*i*/,
|
||||
Index /*j*/, Index /*num_rows*/, Index /*num_cols*/) const {}
|
||||
};
|
||||
|
||||
template<typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType = const NoOpOutputKernel>
|
||||
@ -467,42 +467,58 @@ struct TensorContractionEvaluatorBase
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
|
||||
if (this->m_lhs_inner_dim_contiguous) {
|
||||
if (this->m_rhs_inner_dim_contiguous) {
|
||||
if (this->m_rhs_inner_dim_reordered) {
|
||||
static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer);
|
||||
}
|
||||
else {
|
||||
static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (this->m_rhs_inner_dim_reordered) {
|
||||
static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer);
|
||||
}
|
||||
else {
|
||||
static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer);
|
||||
}
|
||||
}
|
||||
#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \
|
||||
if (this->m_lhs_inner_dim_contiguous) { \
|
||||
if (this->m_rhs_inner_dim_contiguous) { \
|
||||
if (this->m_rhs_inner_dim_reordered) { \
|
||||
METHOD<true, true, true, ALIGNMENT>ARGS; \
|
||||
} \
|
||||
else { \
|
||||
METHOD<true, true, false, ALIGNMENT>ARGS; \
|
||||
} \
|
||||
} \
|
||||
else { \
|
||||
if (this->m_rhs_inner_dim_reordered) { \
|
||||
METHOD<true, false, true, ALIGNMENT>ARGS; \
|
||||
} \
|
||||
else { \
|
||||
METHOD<true, false, false, ALIGNMENT>ARGS; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
else { \
|
||||
if (this->m_rhs_inner_dim_contiguous) { \
|
||||
if (this->m_rhs_inner_dim_reordered) { \
|
||||
METHOD<false, true, true, ALIGNMENT>ARGS; \
|
||||
} \
|
||||
else { \
|
||||
METHOD<false, true, false, ALIGNMENT>ARGS; \
|
||||
} \
|
||||
} \
|
||||
else { \
|
||||
if (this->m_rhs_inner_dim_reordered) { \
|
||||
METHOD<false, false, true, ALIGNMENT>ARGS; \
|
||||
} \
|
||||
else { \
|
||||
METHOD<false, false, false, ALIGNMENT>ARGS; \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
else {
|
||||
if (this->m_rhs_inner_dim_contiguous) {
|
||||
if (this->m_rhs_inner_dim_reordered) {
|
||||
static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer);
|
||||
}
|
||||
else {
|
||||
static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (this->m_rhs_inner_dim_reordered) {
|
||||
static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer);
|
||||
}
|
||||
else {
|
||||
static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
|
||||
static_cast<const Derived*>(this)->template evalProduct<Unaligned>(buffer);
|
||||
}
|
||||
|
||||
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
|
||||
bool rhs_inner_dim_reordered, int Alignment>
|
||||
void evalProductSequential(Scalar* buffer) const {
|
||||
if (this->m_j_size == 1) {
|
||||
this->template evalGemv<lhs_inner_dim_contiguous,
|
||||
rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
|
||||
Alignment>(buffer);
|
||||
} else {
|
||||
this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
|
||||
rhs_inner_dim_reordered, Alignment>(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
@ -623,7 +639,7 @@ struct TensorContractionEvaluatorBase
|
||||
OutputMapper output(buffer, m);
|
||||
|
||||
// Sizes of the blocks to load in cache. See the Goto paper for details.
|
||||
internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, 1);
|
||||
internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(k, m, n, 1);
|
||||
const Index kc = blocking.kc();
|
||||
const Index mc = numext::mini(m, blocking.mc());
|
||||
const Index nc = numext::mini(n, blocking.nc());
|
||||
@ -976,14 +992,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
|
||||
Base(op, device) { }
|
||||
|
||||
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
|
||||
EIGEN_DEVICE_FUNC void evalProduct(Scalar* buffer) const {
|
||||
if (this->m_j_size == 1) {
|
||||
this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
|
||||
return;
|
||||
}
|
||||
|
||||
this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
|
||||
template <int Alignment>
|
||||
void evalProduct(Scalar* buffer) const {
|
||||
TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Alignment, (buffer));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -21,13 +21,10 @@ enum {
|
||||
|
||||
|
||||
// Default Blocking Strategy
|
||||
template <typename LhsMapper, typename RhsMapper, typename Index, int ShardingType=ShardByCol>
|
||||
template <typename LhsScalar, typename RhsScalar, typename Index, int ShardingType=ShardByCol>
|
||||
class TensorContractionBlocking {
|
||||
public:
|
||||
|
||||
typedef typename LhsMapper::Scalar LhsScalar;
|
||||
typedef typename RhsMapper::Scalar RhsScalar;
|
||||
|
||||
/*
|
||||
adding EIGEN_DEVICE_FUNC unconditionally to 'TensorContractionBlocking' constructor in `TensorContractionBlocking.h`
|
||||
requires adding EIGEN_DEVICE_FUNC to `computeProductBlockingSizes` in `GeneralBlockPanelKernel.h`
|
||||
@ -41,7 +38,7 @@ class TensorContractionBlocking {
|
||||
../Eigen/src/Core/products/GeneralBlockPanelKernel.h(57): error #2901:
|
||||
dynamic initialization is not supported for function-scope static variables within a __device__/__global__ function
|
||||
*/
|
||||
|
||||
|
||||
#if !defined(EIGEN_HIPCC)
|
||||
EIGEN_DEVICE_FUNC
|
||||
#endif
|
||||
|
@ -71,8 +71,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
TensorEvaluator(const XprType& op, const Device& device) :
|
||||
Base(op, device) {}
|
||||
|
||||
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
|
||||
bool rhs_inner_dim_reordered, int Alignment>
|
||||
template <int Alignment>
|
||||
void evalProduct(Scalar* buffer) const {
|
||||
const Index m = this->m_i_size;
|
||||
const Index n = this->m_j_size;
|
||||
@ -96,39 +95,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
}
|
||||
#endif
|
||||
|
||||
typedef
|
||||
typename internal::remove_const<typename EvalLeftArgType::Scalar>::type
|
||||
LhsScalar;
|
||||
typedef
|
||||
typename internal::remove_const<typename EvalRightArgType::Scalar>::type
|
||||
RhsScalar;
|
||||
typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
|
||||
typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
|
||||
typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
|
||||
typedef internal::TensorContractionInputMapper<
|
||||
LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
|
||||
contract_t, internal::packet_traits<LhsScalar>::size,
|
||||
lhs_inner_dim_contiguous, false, Unaligned>
|
||||
LhsMapper;
|
||||
typedef internal::TensorContractionInputMapper<
|
||||
RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
|
||||
contract_t, internal::packet_traits<RhsScalar>::size,
|
||||
rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
|
||||
RhsMapper;
|
||||
typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
|
||||
typedef internal::gemm_pack_lhs<LhsScalar, Index,
|
||||
typename LhsMapper::SubMapper, Traits::mr,
|
||||
Traits::LhsProgress, ColMajor>
|
||||
LhsPacker;
|
||||
typedef internal::gemm_pack_rhs<
|
||||
RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor>
|
||||
RhsPacker;
|
||||
typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper,
|
||||
Traits::mr, Traits::nr, false, false>
|
||||
GebpKernel;
|
||||
|
||||
|
||||
|
||||
// Compute a set of algorithm parameters:
|
||||
// - kernel block sizes (bm, bn, bk)
|
||||
// - task grain sizes (number of kernels executed per task: gm, gn)
|
||||
@ -158,14 +124,14 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
// Again, we don't know number of threads yet, so we use 2.
|
||||
Index bm, bn, bk;
|
||||
if (shard_by_col) {
|
||||
internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
|
||||
internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index,
|
||||
internal::ShardByCol>
|
||||
blocking(k, m, n, 2);
|
||||
bm = blocking.mc();
|
||||
bn = blocking.nc();
|
||||
bk = blocking.kc();
|
||||
} else {
|
||||
internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
|
||||
internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index,
|
||||
internal::ShardByRow>
|
||||
blocking(k, m, n, 2);
|
||||
bm = blocking.mc();
|
||||
@ -187,29 +153,22 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
if (n == 1) num_threads = 1;
|
||||
|
||||
if (num_threads == 1) {
|
||||
// The single-threaded algorithm should be faster in this case.
|
||||
if (n == 1)
|
||||
this->template evalGemv<lhs_inner_dim_contiguous,
|
||||
rhs_inner_dim_contiguous,
|
||||
rhs_inner_dim_reordered, Alignment>(buffer);
|
||||
else
|
||||
this->template evalGemm<lhs_inner_dim_contiguous,
|
||||
rhs_inner_dim_contiguous,
|
||||
rhs_inner_dim_reordered, Alignment>(buffer);
|
||||
TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential,
|
||||
Unaligned, (buffer));
|
||||
return;
|
||||
}
|
||||
|
||||
// Now that we know number of threads, recalculate sharding and blocking.
|
||||
shard_by_col = shardByCol(m, n, num_threads);
|
||||
if (shard_by_col) {
|
||||
internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
|
||||
internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index,
|
||||
internal::ShardByCol>
|
||||
blocking(k, m, n, num_threads);
|
||||
bm = blocking.mc();
|
||||
bn = blocking.nc();
|
||||
bk = blocking.kc();
|
||||
} else {
|
||||
internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
|
||||
internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index,
|
||||
internal::ShardByRow>
|
||||
blocking(k, m, n, num_threads);
|
||||
bm = blocking.mc();
|
||||
@ -257,34 +216,55 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
// more important in this case.
|
||||
if ((shard_by_col ? nm : nn) == 1) parallel_pack = false;
|
||||
|
||||
LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides,
|
||||
this->m_i_strides, this->m_left_contracting_strides,
|
||||
this->m_k_strides);
|
||||
#define CONTEXT_ARGS \
|
||||
(this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
|
||||
nn0, shard_by_col, parallel_pack) \
|
||||
.run()
|
||||
|
||||
RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides,
|
||||
this->m_j_strides, this->m_right_contracting_strides,
|
||||
this->m_k_strides);
|
||||
TENSOR_CONTRACTION_DISPATCH(Context, Alignment, CONTEXT_ARGS);
|
||||
|
||||
#undef CONTEXT_ARGS
|
||||
|
||||
Context<LhsPacker, RhsPacker, GebpKernel, LhsMapper, RhsMapper,
|
||||
OutputMapper>(this, num_threads, lhs, rhs, buffer, m, n,
|
||||
k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0,
|
||||
shard_by_col, parallel_pack)
|
||||
.run();
|
||||
}
|
||||
|
||||
// Context coordinates a single parallel gemm operation.
|
||||
template <typename LhsPacker, typename RhsPacker, typename GebpKernel,
|
||||
typename LhsMapper, typename RhsMapper, typename OutputMapper>
|
||||
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
|
||||
bool rhs_inner_dim_reordered, int Alignment>
|
||||
class Context {
|
||||
public:
|
||||
Context(const Self* self, int num_threads, LhsMapper& lhs,
|
||||
RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm,
|
||||
Index bn, Index bk, Index nm, Index nn, Index nk, Index gm,
|
||||
Index gn, Index nm0, Index nn0, bool shard_by_col,
|
||||
typedef internal::TensorContractionInputMapper<
|
||||
LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
|
||||
contract_t, internal::packet_traits<LhsScalar>::size,
|
||||
lhs_inner_dim_contiguous, false, Unaligned>
|
||||
LhsMapper;
|
||||
typedef internal::TensorContractionInputMapper<
|
||||
RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
|
||||
contract_t, internal::packet_traits<RhsScalar>::size,
|
||||
rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
|
||||
RhsMapper;
|
||||
typedef internal::gemm_pack_lhs<LhsScalar, Index,
|
||||
typename LhsMapper::SubMapper, Traits::mr,
|
||||
Traits::LhsProgress, ColMajor>
|
||||
LhsPacker;
|
||||
typedef internal::gemm_pack_rhs<
|
||||
RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor>
|
||||
RhsPacker;
|
||||
typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
|
||||
typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper,
|
||||
Traits::mr, Traits::nr, false, false>
|
||||
GebpKernel;
|
||||
|
||||
Context(const Self* self, int num_threads, Scalar* buffer, Index tm, Index tn,
|
||||
Index tk, Index bm, Index bn, Index bk, Index nm, Index nn, Index nk,
|
||||
Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col,
|
||||
bool parallel_pack)
|
||||
: device_(self->m_device),
|
||||
lhs_(lhs),
|
||||
rhs_(rhs),
|
||||
lhs_(self->m_leftImpl, self->m_left_nocontract_strides,
|
||||
self->m_i_strides, self->m_left_contracting_strides,
|
||||
self->m_k_strides),
|
||||
rhs_(self->m_rightImpl, self->m_right_nocontract_strides,
|
||||
self->m_j_strides, self->m_right_contracting_strides,
|
||||
self->m_k_strides),
|
||||
buffer_(buffer),
|
||||
output_(buffer, tm),
|
||||
output_kernel_(self->m_output_kernel),
|
||||
@ -376,8 +356,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
private:
|
||||
Notification done_;
|
||||
const Device& device_;
|
||||
LhsMapper& lhs_;
|
||||
RhsMapper& rhs_;
|
||||
LhsMapper lhs_;
|
||||
RhsMapper rhs_;
|
||||
Scalar* const buffer_;
|
||||
OutputMapper output_;
|
||||
OutputKernelType output_kernel_;
|
||||
|
Loading…
Reference in New Issue
Block a user