Merged in rmlarsen/eigen1 (pull request PR-441)

Reduce the number of template specializations of classes related to tensor contraction in order to reduce binary size.
Gael Guennebaud 2018-07-30 11:26:24 +00:00
commit 34539c4af4
3 changed files with 109 additions and 121 deletions
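The patch works by hoisting the runtime dispatch over the three contraction flags (lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered) into the single TENSOR_CONTRACTION_DISPATCH macro introduced below, so each evaluator templates its entry point only on Alignment and the eight flag combinations are instantiated once, in evalProductSequential. A minimal standalone sketch of that pattern follows; it is illustrative only, and the names Kernel, run, eval and KERNEL_DISPATCH are made up for this example rather than taken from Eigen:

// Illustrative sketch of the dispatch pattern, not Eigen code: runtime flags
// are mapped to compile-time bools in one place, so the heavy code path is
// instantiated once per flag combination rather than once per calling site.
#include <cstdio>

#define KERNEL_DISPATCH(METHOD, ARGS)                       \
  if (this->lhs_contiguous) {                               \
    if (this->rhs_reordered) { METHOD<true, true> ARGS; }   \
    else                     { METHOD<true, false> ARGS; }  \
  } else {                                                  \
    if (this->rhs_reordered) { METHOD<false, true> ARGS; }  \
    else                     { METHOD<false, false> ARGS; } \
  }

struct Kernel {
  bool lhs_contiguous = false;
  bool rhs_reordered = false;

  // Heavy code path: one instantiation per <bool, bool> combination.
  template <bool LhsContiguous, bool RhsReordered>
  void run(double* out) const {
    std::printf("run<%d, %d> writing to %p\n",
                int(LhsContiguous), int(RhsReordered), static_cast<void*>(out));
  }

  // Single dispatch point; every caller reuses it instead of repeating the
  // nested if/else ladder (and the instantiations it forces) locally.
  void eval(double* out) const { KERNEL_DISPATCH(run, (out)); }
};

int main() {
  double buf[1] = {0.0};
  Kernel k;
  k.lhs_contiguous = true;
  k.eval(buf);  // resolves to run<true, false>
  return 0;
}

In the real patch the macro also carries an ALIGNMENT parameter, and the method name is passed as `this->template evalProductSequential`, which is why the argument list is supplied as a separate parenthesized ARGS token.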

@@ -177,9 +177,9 @@ struct NoOpOutputKernel {
*/
template <typename Index, typename Scalar>
EIGEN_ALWAYS_INLINE void operator()(
const OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
const TensorContractionParams& params, Index i, Index j, Index num_rows,
Index num_cols) const {}
const OutputKernel::OutputMapper<Index, Scalar>& /*output_mapper*/,
const TensorContractionParams& /*params*/, Index /*i*/,
Index /*j*/, Index /*num_rows*/, Index /*num_cols*/) const {}
};
template<typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType = const NoOpOutputKernel>
@@ -467,42 +467,58 @@ struct TensorContractionEvaluatorBase
}
}
EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
if (this->m_lhs_inner_dim_contiguous) {
if (this->m_rhs_inner_dim_contiguous) {
if (this->m_rhs_inner_dim_reordered) {
static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer);
}
else {
static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer);
}
}
else {
if (this->m_rhs_inner_dim_reordered) {
static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer);
}
else {
static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer);
}
}
#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \
if (this->m_lhs_inner_dim_contiguous) { \
if (this->m_rhs_inner_dim_contiguous) { \
if (this->m_rhs_inner_dim_reordered) { \
METHOD<true, true, true, ALIGNMENT>ARGS; \
} \
else { \
METHOD<true, true, false, ALIGNMENT>ARGS; \
} \
} \
else { \
if (this->m_rhs_inner_dim_reordered) { \
METHOD<true, false, true, ALIGNMENT>ARGS; \
} \
else { \
METHOD<true, false, false, ALIGNMENT>ARGS; \
} \
} \
} \
else { \
if (this->m_rhs_inner_dim_contiguous) { \
if (this->m_rhs_inner_dim_reordered) { \
METHOD<false, true, true, ALIGNMENT>ARGS; \
} \
else { \
METHOD<false, true, false, ALIGNMENT>ARGS; \
} \
} \
else { \
if (this->m_rhs_inner_dim_reordered) { \
METHOD<false, false, true, ALIGNMENT>ARGS; \
} \
else { \
METHOD<false, false, false, ALIGNMENT>ARGS; \
} \
} \
}
else {
if (this->m_rhs_inner_dim_contiguous) {
if (this->m_rhs_inner_dim_reordered) {
static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer);
}
else {
static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer);
}
}
else {
if (this->m_rhs_inner_dim_reordered) {
static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer);
}
else {
static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer);
}
}
EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
static_cast<const Derived*>(this)->template evalProduct<Unaligned>(buffer);
}
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered, int Alignment>
void evalProductSequential(Scalar* buffer) const {
if (this->m_j_size == 1) {
this->template evalGemv<lhs_inner_dim_contiguous,
rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
Alignment>(buffer);
} else {
this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
rhs_inner_dim_reordered, Alignment>(buffer);
}
}
@@ -623,7 +639,7 @@ struct TensorContractionEvaluatorBase
OutputMapper output(buffer, m);
// Sizes of the blocks to load in cache. See the Goto paper for details.
internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, 1);
internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(k, m, n, 1);
const Index kc = blocking.kc();
const Index mc = numext::mini(m, blocking.mc());
const Index nc = numext::mini(n, blocking.nc());
@@ -976,14 +992,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
Base(op, device) { }
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
EIGEN_DEVICE_FUNC void evalProduct(Scalar* buffer) const {
if (this->m_j_size == 1) {
this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
return;
}
this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
template <int Alignment>
void evalProduct(Scalar* buffer) const {
TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Alignment, (buffer));
}
};
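For reference, the default evaluator's evalProduct<Alignment> above simply forwards to the macro; the call TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Alignment, (buffer)) expands, schematically, to the same nested flag test that evalTo used to spell out inline (expansion sketch only, not a separate runnable snippet):

// Schematic expansion of
//   TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential,
//                               Alignment, (buffer));
// inside evalProduct<Alignment>:
if (this->m_lhs_inner_dim_contiguous) {
  if (this->m_rhs_inner_dim_contiguous) {
    if (this->m_rhs_inner_dim_reordered) {
      this->template evalProductSequential<true, true, true, Alignment>(buffer);
    } else {
      this->template evalProductSequential<true, true, false, Alignment>(buffer);
    }
  } else {
    // ... <true, false, true> and <true, false, false> ...
  }
} else {
  // ... the four combinations with lhs_inner_dim_contiguous == false ...
}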

@@ -21,13 +21,10 @@ enum {
// Default Blocking Strategy
template <typename LhsMapper, typename RhsMapper, typename Index, int ShardingType=ShardByCol>
template <typename LhsScalar, typename RhsScalar, typename Index, int ShardingType=ShardByCol>
class TensorContractionBlocking {
public:
typedef typename LhsMapper::Scalar LhsScalar;
typedef typename RhsMapper::Scalar RhsScalar;
/*
adding EIGEN_DEVICE_FUNC unconditionally to 'TensorContractionBlocking' constructor in `TensorContractionBlocking.h`
requires adding EIGEN_DEVICE_FUNC to `computeProductBlockingSizes` in `GeneralBlockPanelKernel.h`
@@ -41,7 +38,7 @@ class TensorContractionBlocking {
../Eigen/src/Core/products/GeneralBlockPanelKernel.h(57): error #2901:
dynamic initialization is not supported for function-scope static variables within a __device__/__global__ function
*/
#if !defined(EIGEN_HIPCC)
EIGEN_DEVICE_FUNC
#endif
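The second file narrows TensorContractionBlocking from mapper-type template parameters down to the scalar types it actually uses, so one instantiation of the blocking logic is shared by every input-mapper configuration. A minimal sketch of that kind of narrowing is below; the class name, constructor signature, and the block-size rule are invented for illustration and are not Eigen's implementation:

// Illustrative sketch, not Eigen code: narrowing a helper's template
// parameters from mapper types to the scalar types it actually needs means
// one instantiation per scalar pair instead of one per mapper combination.

// Before (schematic): instantiated once per (LhsMapper, RhsMapper) pair,
// even though only LhsMapper::Scalar and RhsMapper::Scalar are used.
//
// After: instantiated once per (LhsScalar, RhsScalar) pair.
template <typename LhsScalar, typename RhsScalar, typename Index>
class Blocking {
 public:
  Blocking(Index k, Index m, Index n) : kc_(k), mc_(m), nc_(n) {
    // Stand-in for the real cache-aware sizing: cap kc by the number of
    // (lhs, rhs) element pairs that fit an arbitrary 256 KiB budget.
    const Index budget =
        Index(256 * 1024) / Index(sizeof(LhsScalar) + sizeof(RhsScalar));
    if (kc_ > budget) kc_ = budget;
  }
  Index kc() const { return kc_; }
  Index mc() const { return mc_; }
  Index nc() const { return nc_; }

 private:
  Index kc_, mc_, nc_;
};

// Hypothetical usage:
//   Blocking<float, float, long> blocking(k, m, n);
//   long kc = blocking.kc();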

@@ -71,8 +71,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
TensorEvaluator(const XprType& op, const Device& device) :
Base(op, device) {}
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered, int Alignment>
template <int Alignment>
void evalProduct(Scalar* buffer) const {
const Index m = this->m_i_size;
const Index n = this->m_j_size;
@@ -96,39 +95,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
}
#endif
typedef
typename internal::remove_const<typename EvalLeftArgType::Scalar>::type
LhsScalar;
typedef
typename internal::remove_const<typename EvalRightArgType::Scalar>::type
RhsScalar;
typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
typedef internal::TensorContractionInputMapper<
LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
contract_t, internal::packet_traits<LhsScalar>::size,
lhs_inner_dim_contiguous, false, Unaligned>
LhsMapper;
typedef internal::TensorContractionInputMapper<
RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
contract_t, internal::packet_traits<RhsScalar>::size,
rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
RhsMapper;
typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
typedef internal::gemm_pack_lhs<LhsScalar, Index,
typename LhsMapper::SubMapper, Traits::mr,
Traits::LhsProgress, ColMajor>
LhsPacker;
typedef internal::gemm_pack_rhs<
RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor>
RhsPacker;
typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper,
Traits::mr, Traits::nr, false, false>
GebpKernel;
// Compute a set of algorithm parameters:
// - kernel block sizes (bm, bn, bk)
// - task grain sizes (number of kernels executed per task: gm, gn)
@@ -158,14 +124,14 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// Again, we don't know number of threads yet, so we use 2.
Index bm, bn, bk;
if (shard_by_col) {
internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index,
internal::ShardByCol>
blocking(k, m, n, 2);
bm = blocking.mc();
bn = blocking.nc();
bk = blocking.kc();
} else {
internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index,
internal::ShardByRow>
blocking(k, m, n, 2);
bm = blocking.mc();
@@ -187,29 +153,22 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
if (n == 1) num_threads = 1;
if (num_threads == 1) {
// The single-threaded algorithm should be faster in this case.
if (n == 1)
this->template evalGemv<lhs_inner_dim_contiguous,
rhs_inner_dim_contiguous,
rhs_inner_dim_reordered, Alignment>(buffer);
else
this->template evalGemm<lhs_inner_dim_contiguous,
rhs_inner_dim_contiguous,
rhs_inner_dim_reordered, Alignment>(buffer);
TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential,
Unaligned, (buffer));
return;
}
// Now that we know number of threads, recalculate sharding and blocking.
shard_by_col = shardByCol(m, n, num_threads);
if (shard_by_col) {
internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index,
internal::ShardByCol>
blocking(k, m, n, num_threads);
bm = blocking.mc();
bn = blocking.nc();
bk = blocking.kc();
} else {
internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index,
internal::ShardByRow>
blocking(k, m, n, num_threads);
bm = blocking.mc();
@@ -257,34 +216,55 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// more important in this case.
if ((shard_by_col ? nm : nn) == 1) parallel_pack = false;
LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides,
this->m_i_strides, this->m_left_contracting_strides,
this->m_k_strides);
#define CONTEXT_ARGS \
(this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
nn0, shard_by_col, parallel_pack) \
.run()
RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides,
this->m_j_strides, this->m_right_contracting_strides,
this->m_k_strides);
TENSOR_CONTRACTION_DISPATCH(Context, Alignment, CONTEXT_ARGS);
#undef CONTEXT_ARGS
Context<LhsPacker, RhsPacker, GebpKernel, LhsMapper, RhsMapper,
OutputMapper>(this, num_threads, lhs, rhs, buffer, m, n,
k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0,
shard_by_col, parallel_pack)
.run();
}
// Context coordinates a single parallel gemm operation.
template <typename LhsPacker, typename RhsPacker, typename GebpKernel,
typename LhsMapper, typename RhsMapper, typename OutputMapper>
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered, int Alignment>
class Context {
public:
Context(const Self* self, int num_threads, LhsMapper& lhs,
RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm,
Index bn, Index bk, Index nm, Index nn, Index nk, Index gm,
Index gn, Index nm0, Index nn0, bool shard_by_col,
typedef internal::TensorContractionInputMapper<
LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
contract_t, internal::packet_traits<LhsScalar>::size,
lhs_inner_dim_contiguous, false, Unaligned>
LhsMapper;
typedef internal::TensorContractionInputMapper<
RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
contract_t, internal::packet_traits<RhsScalar>::size,
rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
RhsMapper;
typedef internal::gemm_pack_lhs<LhsScalar, Index,
typename LhsMapper::SubMapper, Traits::mr,
Traits::LhsProgress, ColMajor>
LhsPacker;
typedef internal::gemm_pack_rhs<
RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor>
RhsPacker;
typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper,
Traits::mr, Traits::nr, false, false>
GebpKernel;
Context(const Self* self, int num_threads, Scalar* buffer, Index tm, Index tn,
Index tk, Index bm, Index bn, Index bk, Index nm, Index nn, Index nk,
Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col,
bool parallel_pack)
: device_(self->m_device),
lhs_(lhs),
rhs_(rhs),
lhs_(self->m_leftImpl, self->m_left_nocontract_strides,
self->m_i_strides, self->m_left_contracting_strides,
self->m_k_strides),
rhs_(self->m_rightImpl, self->m_right_nocontract_strides,
self->m_j_strides, self->m_right_contracting_strides,
self->m_k_strides),
buffer_(buffer),
output_(buffer, tm),
output_kernel_(self->m_output_kernel),
@@ -376,8 +356,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
private:
Notification done_;
const Device& device_;
LhsMapper& lhs_;
RhsMapper& rhs_;
LhsMapper lhs_;
RhsMapper rhs_;
Scalar* const buffer_;
OutputMapper output_;
OutputKernelType output_kernel_;
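For completeness, the user-visible code path exercised by these changes is an ordinary tensor contraction evaluated on a thread-pool device, which goes through the threaded evaluator (TensorContractionThreadPool.h) and its Context class. A minimal usage sketch, assuming the unsupported Tensor module is available; tensor sizes and thread counts are arbitrary:

// Minimal usage sketch of the contraction path touched by this commit
// (threaded evaluation via ThreadPoolDevice).
#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 2> a(64, 128), b(128, 32);
  a.setRandom();
  b.setRandom();

  // Contract a's second dimension with b's first (a matrix product).
  Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};

  Eigen::ThreadPool pool(4);
  Eigen::ThreadPoolDevice device(&pool, /*num_cores=*/4);

  Eigen::Tensor<float, 2> c(64, 32);
  // Threaded evaluation dispatches through evalProduct<Alignment> and the
  // Context machinery shown in the diff above.
  c.device(device) = a.contract(b, dims);
  return 0;
}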