Added support for tensor slicing

This commit is contained in:
Benoit Steiner 2014-07-07 14:08:45 -07:00
parent 47981c5925
commit bc072c5cba

View File

@ -204,11 +204,16 @@ class TensorBase<Derived, ReadOnlyAccessors>
return TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>(derived(), thenTensor.derived(), elseTensor.derived());
}
// Morphing operators (slicing tbd).
// Morphing operators.
template <typename NewDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorReshapingOp<const Derived, const NewDimensions>
const TensorReshapingOp<const NewDimensions, const Derived>
reshape(const NewDimensions& newDimensions) const {
return TensorReshapingOp<const Derived, const NewDimensions>(derived(), newDimensions);
return TensorReshapingOp<const NewDimensions, const Derived>(derived(), newDimensions);
}
template <typename StartIndices, typename Sizes> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorSlicingOp<const StartIndices, const Sizes, const Derived>
slice(const StartIndices& startIndices, const Sizes& sizes) const {
return TensorSlicingOp<const StartIndices, const Sizes, const Derived>(derived(), startIndices, sizes);
}
// Force the evaluation of the expression.
@ -257,6 +262,17 @@ class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyA
return derived() = TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
template <typename NewDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
TensorReshapingOp<const NewDimensions, Derived>
reshape(const NewDimensions& newDimensions) {
return TensorReshapingOp<const NewDimensions, Derived>(derived(), newDimensions);
}
template <typename StartIndices, typename Sizes> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
TensorSlicingOp<const StartIndices, const Sizes, Derived>
slice(const StartIndices& startIndices, const Sizes& sizes) const {
return TensorSlicingOp<const StartIndices, const Sizes, Derived>(derived(), startIndices, sizes);
}
// Select the device on which to evaluate the expression.
template <typename DeviceType>
TensorDevice<Derived, DeviceType> device(const DeviceType& device) {