Mirror of https://gitlab.com/libeigen/eigen.git (synced 2024-12-21 07:19:46 +08:00)
Move the evalGemm method into the TensorContractionEvaluatorBase class to make it accessible from both the single and multithreaded contraction evaluators.
commit c8e8f93d6c (parent 1a16fb1532)
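The point of the move is that evalGemm becomes a member of the common base class, so the single-threaded and the multithreaded evaluators can both call one shared implementation instead of each carrying a copy. A minimal sketch of that pattern follows (hypothetical names, not the Eigen classes): a class-template base holding a template member function, and derived evaluators calling it through `this->template` — which is also why the call site in the second hunk below changes.

// Minimal sketch, hypothetical names only: a template member function hoisted
// into a shared base so that several derived evaluators can reuse it.
#include <cstdio>

template <typename Device>
struct ContractionEvaluatorBase {
  // shared implementation, templated the same way evalGemm is
  template <bool Contiguous, int Alignment>
  void evalShared(float* buffer) const {
    buffer[0] = Contiguous ? 1.0f : 2.0f;  // stand-in for the real GEMM work
  }
};

template <typename Device>
struct SingleThreadEvaluator : ContractionEvaluatorBase<Device> {
  void evalTo(float* buffer) const {
    // evalShared lives in a dependent base class, so the call must be
    // qualified with `this->template` to parse as a template-id
    this->template evalShared<true, 0>(buffer);
  }
};

template <typename Device>
struct ThreadPoolEvaluator : ContractionEvaluatorBase<Device> {
  void evalTo(float* buffer) const {
    this->template evalShared<false, 0>(buffer);  // same shared code path
  }
};

struct DefaultDevice {};

int main() {
  float out = 0.0f;
  SingleThreadEvaluator<DefaultDevice>().evalTo(&out);
  std::printf("%f\n", out);  // 1.000000
  ThreadPoolEvaluator<DefaultDevice>().evalTo(&out);
  std::printf("%f\n", out);  // 2.000000
  return 0;
}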
@@ -426,6 +426,99 @@ struct TensorContractionEvaluatorBase
         buffer, resIncr, alpha);
   }
 
+  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
+  EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const {
+    // columns in left side, rows in right side
+    const Index k = this->m_k_size;
+
+    // rows in left side
+    const Index m = this->m_i_size;
+
+    // columns in right side
+    const Index n = this->m_j_size;
+
+    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar)
+    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
+
+    // define mr, nr, and all of my data mapper types
+    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
+    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
+    typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
+
+    const Index nr = Traits::nr;
+    const Index mr = Traits::mr;
+
+    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
+    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
+
+    const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
+    const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;
+
+    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
+                                                   LeftEvaluator, left_nocontract_t,
+                                                   contract_t, lhs_packet_size,
+                                                   lhs_inner_dim_contiguous,
+                                                   false, Unaligned> LhsMapper;
+
+    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
+                                                   RightEvaluator, right_nocontract_t,
+                                                   contract_t, rhs_packet_size,
+                                                   rhs_inner_dim_contiguous,
+                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;
+
+    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
+
+    // Declare GEBP packing and kernel structs
+    internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, mr, Traits::LhsProgress, ColMajor> pack_lhs;
+    internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, nr, ColMajor> pack_rhs;
+
+    internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, mr, nr, false, false> gebp;
+
+    // initialize data mappers
+    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
+                  this->m_left_contracting_strides, this->m_k_strides);
+
+    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
+                  this->m_right_contracting_strides, this->m_k_strides);
+
+    OutputMapper output(buffer, m);
+
+    // Sizes of the blocks to load in cache. See the Goto paper for details.
+    internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, 1);
+    const Index kc = blocking.kc();
+    const Index mc = numext::mini(m, blocking.mc());
+    const Index nc = numext::mini(n, blocking.nc());
+    const Index sizeA = mc * kc;
+    const Index sizeB = kc * nc;
+
+    LhsScalar* blockA = static_cast<LhsScalar *>(this->m_device.allocate(sizeA * sizeof(LhsScalar)));
+    RhsScalar* blockB = static_cast<RhsScalar *>(this->m_device.allocate(sizeB * sizeof(RhsScalar)));
+
+    for(Index i2=0; i2<m; i2+=mc)
+    {
+      const Index actual_mc = numext::mini(i2+mc,m)-i2;
+      for (Index k2 = 0; k2 < k; k2 += kc) {
+        // make sure we don't overshoot right edge of left matrix, then pack vertical panel
+        const Index actual_kc = numext::mini(k2 + kc, k) - k2;
+        pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc, 0, 0);
+
+        // series of horizontal blocks
+        for (Index j2 = 0; j2 < n; j2 += nc) {
+          // make sure we don't overshoot right edge of right matrix, then pack block
+          const Index actual_nc = numext::mini(j2 + nc, n) - j2;
+          pack_rhs(blockB, rhs.getSubMapper(k2, j2), actual_kc, actual_nc, 0, 0);
+
+          // call gebp (matrix kernel)
+          // The parameters here are copied from Eigen's GEMM implementation
+          gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, 1.0, -1, -1, 0, 0);
+        }
+      }
+    }
+
+    this->m_device.deallocate(blockA);
+    this->m_device.deallocate(blockB);
+  }
+
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
     m_leftImpl.cleanup();
     m_rightImpl.cleanup();
@@ -533,100 +626,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
       return;
     }
 
-    evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
+    this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
   }
-
-  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
-  EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const {
-    // columns in left side, rows in right side
-    const Index k = this->m_k_size;
-
-    // rows in left side
-    const Index m = this->m_i_size;
-
-    // columns in right side
-    const Index n = this->m_j_size;
-
-    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar)
-    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
-
-    // define mr, nr, and all of my data mapper types
-    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
-    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
-    typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
-
-    const Index nr = Traits::nr;
-    const Index mr = Traits::mr;
-
-    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
-    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
-
-    const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
-    const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;
-
-    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
-                                                   LeftEvaluator, left_nocontract_t,
-                                                   contract_t, lhs_packet_size,
-                                                   lhs_inner_dim_contiguous,
-                                                   false, Unaligned> LhsMapper;
-
-    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
-                                                   RightEvaluator, right_nocontract_t,
-                                                   contract_t, rhs_packet_size,
-                                                   rhs_inner_dim_contiguous,
-                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;
-
-    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
-
-    // Declare GEBP packing and kernel structs
-    internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, mr, Traits::LhsProgress, ColMajor> pack_lhs;
-    internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, nr, ColMajor> pack_rhs;
-
-    internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, mr, nr, false, false> gebp;
-
-    // initialize data mappers
-    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
-                  this->m_left_contracting_strides, this->m_k_strides);
-
-    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
-                  this->m_right_contracting_strides, this->m_k_strides);
-
-    OutputMapper output(buffer, m);
-
-    // Sizes of the blocks to load in cache. See the Goto paper for details.
-    internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, 1);
-    const Index kc = blocking.kc();
-    const Index mc = numext::mini(m, blocking.mc());
-    const Index nc = numext::mini(n, blocking.nc());
-    const Index sizeA = mc * kc;
-    const Index sizeB = kc * nc;
-
-    LhsScalar* blockA = static_cast<LhsScalar *>(this->m_device.allocate(sizeA * sizeof(LhsScalar)));
-    RhsScalar* blockB = static_cast<RhsScalar *>(this->m_device.allocate(sizeB * sizeof(RhsScalar)));
-
-    for(Index i2=0; i2<m; i2+=mc)
-    {
-      const Index actual_mc = numext::mini(i2+mc,m)-i2;
-      for (Index k2 = 0; k2 < k; k2 += kc) {
-        // make sure we don't overshoot right edge of left matrix, then pack vertical panel
-        const Index actual_kc = numext::mini(k2 + kc, k) - k2;
-        pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc, 0, 0);
-
-        // series of horizontal blocks
-        for (Index j2 = 0; j2 < n; j2 += nc) {
-          // make sure we don't overshoot right edge of right matrix, then pack block
-          const Index actual_nc = numext::mini(j2 + nc, n) - j2;
-          pack_rhs(blockB, rhs.getSubMapper(k2, j2), actual_kc, actual_nc, 0, 0);
-
-          // call gebp (matrix kernel)
-          // The parameters here are copied from Eigen's GEMM implementation
-          gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, 1.0, -1, -1, 0, 0);
-        }
-      }
-    }
-
-    this->m_device.deallocate(blockA);
-    this->m_device.deallocate(blockB);
-  }
 };
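The evalGemm body added above follows the Goto-style blocking that its "See the Goto paper" comment refers to: the output is walked in mc-by-nc tiles, the contraction dimension is consumed in kc-sized slices, and pack_lhs / pack_rhs copy panels into scratch buffers before the gebp micro-kernel runs on each tile. The standalone sketch below (not Eigen code; packing and gebp are replaced by a naive triple loop) shows how the i2 / k2 / j2 loop nest and the actual_mc / actual_kc / actual_nc clamps map onto a plain column-major matrix product.

// Hand-rolled sketch of the same blocking scheme, kept self-contained.
#include <algorithm>
#include <cstdio>
#include <vector>

void blocked_gemm(const std::vector<float>& A,   // m x k, column-major
                  const std::vector<float>& B,   // k x n, column-major
                  std::vector<float>& C,         // m x n, column-major, pre-zeroed
                  int m, int n, int k,
                  int mc, int nc, int kc) {
  for (int i2 = 0; i2 < m; i2 += mc) {
    const int actual_mc = std::min(i2 + mc, m) - i2;
    for (int k2 = 0; k2 < k; k2 += kc) {
      const int actual_kc = std::min(k2 + kc, k) - k2;
      // In evalGemm this is where pack_lhs copies an actual_mc x actual_kc panel of A into blockA.
      for (int j2 = 0; j2 < n; j2 += nc) {
        const int actual_nc = std::min(j2 + nc, n) - j2;
        // In evalGemm this is where pack_rhs copies an actual_kc x actual_nc block of B into blockB
        // and gebp multiplies the packed operands. Naive stand-in:
        for (int j = 0; j < actual_nc; ++j)
          for (int p = 0; p < actual_kc; ++p)
            for (int i = 0; i < actual_mc; ++i)
              C[(i2 + i) + (j2 + j) * m] +=
                  A[(i2 + i) + (k2 + p) * m] * B[(k2 + p) + (j2 + j) * k];
      }
    }
  }
}

int main() {
  const int m = 4, n = 3, k = 5;
  std::vector<float> A(m * k, 1.0f), B(k * n, 2.0f), C(m * n, 0.0f);
  blocked_gemm(A, B, C, m, n, k, /*mc=*/2, /*nc=*/2, /*kc=*/3);
  std::printf("C(0,0) = %f\n", C[0]);  // expect k * 1 * 2 = 10
  return 0;
}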