Fix bug in copy optimization in Tensor slicing.

This commit is contained in:
Eugene Zhulenev 2018-09-28 14:34:42 -07:00
parent 104e8fa074
commit bb13d5d917

View File

@ -979,44 +979,51 @@ struct TensorEvaluator<const TensorStridingSlicingOp<StartIndices, StopIndices,
};
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
: m_impl(op.expression(), device), m_device(device), m_strides(op.strides()), m_exprStartIndices(op.startIndices()), m_exprStopIndices(op.stopIndices())
: m_impl(op.expression(), device),
m_device(device),
m_strides(op.strides()), m_exprStartIndices(op.startIndices()),
m_exprStopIndices(op.stopIndices())
{
// Handle degenerate intervals by gracefully clamping and allowing m_dimensions to be zero
DSizes<Index,NumDims> startIndicesClamped, stopIndicesClamped;
m_is_identity = true;
for (Index i = 0; i < internal::array_size<Dimensions>::value; ++i) {
if (m_strides[i] != 1 || op.startIndices()[i] != 0 ||
op.stopIndices()[i] != (m_impl.dimensions()[i] - 1)) {
m_is_identity = false;
}
DSizes<Index, NumDims> startIndicesClamped, stopIndicesClamped;
for (ptrdiff_t i = 0; i < internal::array_size<Dimensions>::value; ++i) {
eigen_assert(m_strides[i] != 0 && "0 stride is invalid");
if(m_strides[i]>0){
startIndicesClamped[i] = clamp(op.startIndices()[i], 0, m_impl.dimensions()[i]);
stopIndicesClamped[i] = clamp(op.stopIndices()[i], 0, m_impl.dimensions()[i]);
}else{
/* implies m_strides[i]<0 by assert */
startIndicesClamped[i] = clamp(op.startIndices()[i], -1, m_impl.dimensions()[i] - 1);
stopIndicesClamped[i] = clamp(op.stopIndices()[i], -1, m_impl.dimensions()[i] - 1);
if (m_strides[i] > 0) {
startIndicesClamped[i] =
clamp(op.startIndices()[i], 0, m_impl.dimensions()[i]);
stopIndicesClamped[i] =
clamp(op.stopIndices()[i], 0, m_impl.dimensions()[i]);
} else {
/* implies m_strides[i] < 0 by assert */
startIndicesClamped[i] =
clamp(op.startIndices()[i], -1, m_impl.dimensions()[i] - 1);
stopIndicesClamped[i] =
clamp(op.stopIndices()[i], -1, m_impl.dimensions()[i] - 1);
}
m_startIndices[i] = startIndicesClamped[i];
}
const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
const InputDimensions& input_dims = m_impl.dimensions();
// check for degenerate intervals and compute output tensor shape
bool degenerate = false;;
for(int i = 0; i < NumDims; i++){
bool degenerate = false;
m_is_identity = true;
for (int i = 0; i < NumDims; i++) {
Index interval = stopIndicesClamped[i] - startIndicesClamped[i];
if(interval == 0 || ((interval<0) != (m_strides[i]<0))){
if (interval == 0 || ((interval < 0) != (m_strides[i] < 0))) {
m_dimensions[i] = 0;
degenerate = true;
}else{
m_dimensions[i] = interval / m_strides[i]
+ (interval % m_strides[i] != 0 ? 1 : 0);
} else {
m_dimensions[i] =
(interval / m_strides[i]) + (interval % m_strides[i] != 0 ? 1 : 0);
eigen_assert(m_dimensions[i] >= 0);
}
if (m_strides[i] != 1 || interval != m_impl.dimensions()[i]) {
m_is_identity = false;
}
}
Strides output_dims = m_dimensions;
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {