mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Fix bug in copy optimization in Tensor slicing.
This commit is contained in:
parent
104e8fa074
commit
bb13d5d917
@ -979,44 +979,51 @@ struct TensorEvaluator<const TensorStridingSlicingOp<StartIndices, StopIndices,
|
|||||||
};
|
};
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
|
||||||
: m_impl(op.expression(), device), m_device(device), m_strides(op.strides()), m_exprStartIndices(op.startIndices()), m_exprStopIndices(op.stopIndices())
|
: m_impl(op.expression(), device),
|
||||||
|
m_device(device),
|
||||||
|
m_strides(op.strides()), m_exprStartIndices(op.startIndices()),
|
||||||
|
m_exprStopIndices(op.stopIndices())
|
||||||
{
|
{
|
||||||
// Handle degenerate intervals by gracefully clamping and allowing m_dimensions to be zero
|
// Handle degenerate intervals by gracefully clamping and allowing m_dimensions to be zero
|
||||||
DSizes<Index,NumDims> startIndicesClamped, stopIndicesClamped;
|
DSizes<Index, NumDims> startIndicesClamped, stopIndicesClamped;
|
||||||
m_is_identity = true;
|
for (ptrdiff_t i = 0; i < internal::array_size<Dimensions>::value; ++i) {
|
||||||
for (Index i = 0; i < internal::array_size<Dimensions>::value; ++i) {
|
|
||||||
if (m_strides[i] != 1 || op.startIndices()[i] != 0 ||
|
|
||||||
op.stopIndices()[i] != (m_impl.dimensions()[i] - 1)) {
|
|
||||||
m_is_identity = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
eigen_assert(m_strides[i] != 0 && "0 stride is invalid");
|
eigen_assert(m_strides[i] != 0 && "0 stride is invalid");
|
||||||
if(m_strides[i]>0){
|
if (m_strides[i] > 0) {
|
||||||
startIndicesClamped[i] = clamp(op.startIndices()[i], 0, m_impl.dimensions()[i]);
|
startIndicesClamped[i] =
|
||||||
stopIndicesClamped[i] = clamp(op.stopIndices()[i], 0, m_impl.dimensions()[i]);
|
clamp(op.startIndices()[i], 0, m_impl.dimensions()[i]);
|
||||||
}else{
|
stopIndicesClamped[i] =
|
||||||
/* implies m_strides[i]<0 by assert */
|
clamp(op.stopIndices()[i], 0, m_impl.dimensions()[i]);
|
||||||
startIndicesClamped[i] = clamp(op.startIndices()[i], -1, m_impl.dimensions()[i] - 1);
|
} else {
|
||||||
stopIndicesClamped[i] = clamp(op.stopIndices()[i], -1, m_impl.dimensions()[i] - 1);
|
/* implies m_strides[i] < 0 by assert */
|
||||||
|
startIndicesClamped[i] =
|
||||||
|
clamp(op.startIndices()[i], -1, m_impl.dimensions()[i] - 1);
|
||||||
|
stopIndicesClamped[i] =
|
||||||
|
clamp(op.stopIndices()[i], -1, m_impl.dimensions()[i] - 1);
|
||||||
}
|
}
|
||||||
m_startIndices[i] = startIndicesClamped[i];
|
m_startIndices[i] = startIndicesClamped[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
|
typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
|
||||||
|
const InputDimensions& input_dims = m_impl.dimensions();
|
||||||
|
|
||||||
// check for degenerate intervals and compute output tensor shape
|
// check for degenerate intervals and compute output tensor shape
|
||||||
bool degenerate = false;;
|
bool degenerate = false;
|
||||||
for(int i = 0; i < NumDims; i++){
|
m_is_identity = true;
|
||||||
|
for (int i = 0; i < NumDims; i++) {
|
||||||
Index interval = stopIndicesClamped[i] - startIndicesClamped[i];
|
Index interval = stopIndicesClamped[i] - startIndicesClamped[i];
|
||||||
if(interval == 0 || ((interval<0) != (m_strides[i]<0))){
|
if (interval == 0 || ((interval < 0) != (m_strides[i] < 0))) {
|
||||||
m_dimensions[i] = 0;
|
m_dimensions[i] = 0;
|
||||||
degenerate = true;
|
degenerate = true;
|
||||||
}else{
|
} else {
|
||||||
m_dimensions[i] = interval / m_strides[i]
|
m_dimensions[i] =
|
||||||
+ (interval % m_strides[i] != 0 ? 1 : 0);
|
(interval / m_strides[i]) + (interval % m_strides[i] != 0 ? 1 : 0);
|
||||||
eigen_assert(m_dimensions[i] >= 0);
|
eigen_assert(m_dimensions[i] >= 0);
|
||||||
}
|
}
|
||||||
|
if (m_strides[i] != 1 || interval != m_impl.dimensions()[i]) {
|
||||||
|
m_is_identity = false;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Strides output_dims = m_dimensions;
|
Strides output_dims = m_dimensions;
|
||||||
|
|
||||||
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
|
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
|
||||||
|
Loading…
Reference in New Issue
Block a user