Fix expression evaluation heuristic for TensorSliceOp
commit 3cd148f983
parent 23b958818e
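The change below makes the slicing evaluator's memcpy fast path aware of block evaluation: MemcpyTriggerForSlicing gains the expression's BlockAccess trait as a template parameter, and its call operator now receives the total number of coefficients in addition to the length of the contiguous runs, declining the memcpy path for large block-accessible expressions. Here is a minimal, self-contained sketch of the new CPU-side heuristic; the 2 * numThreads() threshold and the 32*1024 cutoff come from the diff, while FakeCpuDevice, MemcpyTriggerSketch and the sample numbers are illustrative stand-ins rather than Eigen's actual types.

#include <cstddef>
#include <iostream>

// Hypothetical stand-in for a CPU device that can report its thread count;
// Eigen's real device types are not reproduced here.
struct FakeCpuDevice {
  int numThreads() const { return 8; }
};

// Sketch of the updated heuristic: the memcpy fast path is taken only when
// the contiguous runs are long enough AND block evaluation is not preferable.
template <typename Index, typename Device, bool BlockAccess>
struct MemcpyTriggerSketch {
  explicit MemcpyTriggerSketch(const Device& device)
      : threshold_(2 * device.numThreads()) {}

  bool operator()(Index total, Index contiguous) const {
    // Large expressions that support block access are better served by block
    // evaluation than by many small memcpy calls (cutoff taken from the diff).
    const bool prefer_block_evaluation = BlockAccess && total > 32 * 1024;
    return !prefer_block_evaluation && contiguous > threshold_;
  }

 private:
  Index threshold_;
};

int main() {
  FakeCpuDevice device;
  MemcpyTriggerSketch<std::ptrdiff_t, FakeCpuDevice, true>  with_blocks(device);
  MemcpyTriggerSketch<std::ptrdiff_t, FakeCpuDevice, false> without_blocks(device);

  // 64K coefficients in runs of 1024: block evaluation wins over memcpy when
  // block access is available, memcpy wins when it is not.
  std::cout << with_blocks(64 * 1024, 1024) << "\n";     // prints 0
  std::cout << without_blocks(64 * 1024, 1024) << "\n";  // prints 1
  // Short contiguous runs never trigger memcpy (threshold is 2 * 8 threads).
  std::cout << with_blocks(1024, 8) << "\n";             // prints 0
  return 0;
}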
@@ -479,9 +479,12 @@ class TensorSlicingOp : public TensorBase<TensorSlicingOp<StartIndices, Sizes, X
 // Fixme: figure out the exact threshold
 namespace {
-template <typename Index, typename Device> struct MemcpyTriggerForSlicing {
+template <typename Index, typename Device, bool BlockAccess> struct MemcpyTriggerForSlicing {
   EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const Device& device) : threshold_(2 * device.numThreads()) { }
-  EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > threshold_; }
+  EIGEN_DEVICE_FUNC bool operator ()(Index total, Index contiguous) const {
+    const bool prefer_block_evaluation = BlockAccess && total > 32*1024;
+    return !prefer_block_evaluation && contiguous > threshold_;
+  }
 
  private:
   Index threshold_;
 };
@@ -490,18 +493,18 @@ template <typename Index, typename Device> struct MemcpyTriggerForSlicing {
 // It is very expensive to start the memcpy kernel on GPU: we therefore only
 // use it for large copies.
 #ifdef EIGEN_USE_GPU
-template <typename Index> struct MemcpyTriggerForSlicing<Index, GpuDevice> {
+template <typename Index, bool BlockAccess> struct MemcpyTriggerForSlicing<Index, GpuDevice, BlockAccess> {
   EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const GpuDevice&) { }
-  EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > 4*1024*1024; }
+  EIGEN_DEVICE_FUNC bool operator ()(Index total, Index contiguous) const { return contiguous > 4*1024*1024; }
 };
 #endif
 
 // It is very expensive to start the memcpy kernel on GPU: we therefore only
 // use it for large copies.
 #ifdef EIGEN_USE_SYCL
-template <typename Index> struct MemcpyTriggerForSlicing<Index, Eigen::SyclDevice> {
+template <typename Index, bool BlockAccess> struct MemcpyTriggerForSlicing<Index, Eigen::SyclDevice, BlockAccess> {
   EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const SyclDevice&) { }
-  EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > 4*1024*1024; }
+  EIGEN_DEVICE_FUNC bool operator ()(Index total, Index contiguous) const { return contiguous > 4*1024*1024; }
 };
 #endif
 
@@ -592,8 +595,7 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
     m_impl.evalSubExprsIfNeeded(NULL);
     if (!NumTraits<typename internal::remove_const<Scalar>::type>::RequireInitialization
-        && data && m_impl.data()
-        && !BlockAccess) {
+        && data && m_impl.data()) {
       Index contiguous_values = 1;
       if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
         for (int i = 0; i < NumDims; ++i) {
@@ -611,8 +613,8 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
         }
       }
       // Use memcpy if it's going to be faster than using the regular evaluation.
-      const MemcpyTriggerForSlicing<Index, Device> trigger(m_device);
-      if (trigger(contiguous_values)) {
+      const MemcpyTriggerForSlicing<Index, Device, BlockAccess> trigger(m_device);
+      if (trigger(internal::array_prod(dimensions()), contiguous_values)) {
         EvaluatorPointerType src = (EvaluatorPointerType)m_impl.data();
         for (Index i = 0; i < internal::array_prod(dimensions()); i += contiguous_values) {
           Index offset = srcCoeff(i);
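Taken together, the evaluator hunks replace the blanket && !BlockAccess guard in evalSubExprsIfNeeded with a cost-based decision: the contiguous run length is still computed as before, but the trigger now also receives internal::array_prod(dimensions()), so the memcpy fast path is skipped only when the expression both supports block access and is large (more than 32*1024 coefficients), instead of whenever block access is available. The GPU and SYCL specializations accept the new BlockAccess parameter but ignore it, keeping their 4*1024*1024 contiguous-size threshold unchanged.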