Fix bug in copy optimization in Tensor slicing.

This commit is contained in:
Eugene Zhulenev 2018-09-28 14:34:42 -07:00
parent 104e8fa074
commit bb13d5d917

View File

@ -979,44 +979,51 @@ struct TensorEvaluator<const TensorStridingSlicingOp<StartIndices, StopIndices,
}; };
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
: m_impl(op.expression(), device), m_device(device), m_strides(op.strides()), m_exprStartIndices(op.startIndices()), m_exprStopIndices(op.stopIndices()) : m_impl(op.expression(), device),
m_device(device),
m_strides(op.strides()), m_exprStartIndices(op.startIndices()),
m_exprStopIndices(op.stopIndices())
{ {
// Handle degenerate intervals by gracefully clamping and allowing m_dimensions to be zero // Handle degenerate intervals by gracefully clamping and allowing m_dimensions to be zero
DSizes<Index,NumDims> startIndicesClamped, stopIndicesClamped; DSizes<Index, NumDims> startIndicesClamped, stopIndicesClamped;
m_is_identity = true; for (ptrdiff_t i = 0; i < internal::array_size<Dimensions>::value; ++i) {
for (Index i = 0; i < internal::array_size<Dimensions>::value; ++i) {
if (m_strides[i] != 1 || op.startIndices()[i] != 0 ||
op.stopIndices()[i] != (m_impl.dimensions()[i] - 1)) {
m_is_identity = false;
}
eigen_assert(m_strides[i] != 0 && "0 stride is invalid"); eigen_assert(m_strides[i] != 0 && "0 stride is invalid");
if(m_strides[i]>0){ if (m_strides[i] > 0) {
startIndicesClamped[i] = clamp(op.startIndices()[i], 0, m_impl.dimensions()[i]); startIndicesClamped[i] =
stopIndicesClamped[i] = clamp(op.stopIndices()[i], 0, m_impl.dimensions()[i]); clamp(op.startIndices()[i], 0, m_impl.dimensions()[i]);
}else{ stopIndicesClamped[i] =
/* implies m_strides[i]<0 by assert */ clamp(op.stopIndices()[i], 0, m_impl.dimensions()[i]);
startIndicesClamped[i] = clamp(op.startIndices()[i], -1, m_impl.dimensions()[i] - 1); } else {
stopIndicesClamped[i] = clamp(op.stopIndices()[i], -1, m_impl.dimensions()[i] - 1); /* implies m_strides[i] < 0 by assert */
startIndicesClamped[i] =
clamp(op.startIndices()[i], -1, m_impl.dimensions()[i] - 1);
stopIndicesClamped[i] =
clamp(op.stopIndices()[i], -1, m_impl.dimensions()[i] - 1);
} }
m_startIndices[i] = startIndicesClamped[i]; m_startIndices[i] = startIndicesClamped[i];
} }
const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions(); typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
const InputDimensions& input_dims = m_impl.dimensions();
// check for degenerate intervals and compute output tensor shape // check for degenerate intervals and compute output tensor shape
bool degenerate = false;; bool degenerate = false;
for(int i = 0; i < NumDims; i++){ m_is_identity = true;
for (int i = 0; i < NumDims; i++) {
Index interval = stopIndicesClamped[i] - startIndicesClamped[i]; Index interval = stopIndicesClamped[i] - startIndicesClamped[i];
if(interval == 0 || ((interval<0) != (m_strides[i]<0))){ if (interval == 0 || ((interval < 0) != (m_strides[i] < 0))) {
m_dimensions[i] = 0; m_dimensions[i] = 0;
degenerate = true; degenerate = true;
}else{ } else {
m_dimensions[i] = interval / m_strides[i] m_dimensions[i] =
+ (interval % m_strides[i] != 0 ? 1 : 0); (interval / m_strides[i]) + (interval % m_strides[i] != 0 ? 1 : 0);
eigen_assert(m_dimensions[i] >= 0); eigen_assert(m_dimensions[i] >= 0);
} }
if (m_strides[i] != 1 || interval != m_impl.dimensions()[i]) {
m_is_identity = false;
}
} }
Strides output_dims = m_dimensions; Strides output_dims = m_dimensions;
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {