Can now use the tensor 'reverse' operation as a lvalue

This commit is contained in:
Benoit Steiner 2015-02-26 11:13:42 -08:00
parent 2fffe69b1b
commit 57154fdb32
3 changed files with 110 additions and 23 deletions

View File

@ -549,6 +549,11 @@ class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyA
chip(const Index offset, const Index dim) const {
return TensorChippingOp<Dynamic, Derived>(derived(), offset, dim);
}
template <typename ReverseDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
TensorReverseOp<const ReverseDimensions, Derived>
reverse(const ReverseDimensions& rev) const {
return TensorReverseOp<const ReverseDimensions, Derived>(derived(), rev);
}
template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
TensorShufflingOp<const Shuffle, Derived>
shuffle(const Shuffle& shuffle) const {

View File

@ -49,12 +49,9 @@ struct nested<TensorReverseOp<ReverseDimensions, XprType>, 1,
} // end namespace internal
template<typename ReverseDimensions, typename XprType>
class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions,
XprType>, ReadOnlyAccessors>
XprType>, WriteAccessors>
{
public:
typedef typename Eigen::internal::traits<TensorReverseOp>::Scalar Scalar;
@ -67,8 +64,8 @@ class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions,
StorageKind;
typedef typename Eigen::internal::traits<TensorReverseOp>::Index Index;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorReverseOp(const XprType& expr,
const ReverseDimensions& reverse_dims)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorReverseOp(
const XprType& expr, const ReverseDimensions& reverse_dims)
: m_xpr(expr), m_reverse_dims(reverse_dims) {}
EIGEN_DEVICE_FUNC
@ -78,12 +75,30 @@ class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions,
const typename internal::remove_all<typename XprType::Nested>::type&
expression() const { return m_xpr; }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE TensorReverseOp& operator = (const TensorReverseOp& other)
{
typedef TensorAssignOp<TensorReverseOp, const TensorReverseOp> Assign;
Assign assign(*this, other);
internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
return *this;
}
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE TensorReverseOp& operator = (const OtherDerived& other)
{
typedef TensorAssignOp<TensorReverseOp, const OtherDerived> Assign;
Assign assign(*this, other);
internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
return *this;
}
protected:
typename XprType::Nested m_xpr;
const ReverseDimensions m_reverse_dims;
};
// Eval as rvalue
template<typename ReverseDimensions, typename ArgType, typename Device>
struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device>
@ -134,8 +149,8 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
m_impl.cleanup();
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
{
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index reverseIndex(
Index index) const {
eigen_assert(index < dimensions().TotalSize());
Index inputIndex = 0;
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
@ -152,7 +167,6 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
} else {
inputIndex += index;
}
return m_impl.coeff(inputIndex);
} else {
for (int i = 0; i < NumDims - 1; ++i) {
Index idx = index / m_strides[i];
@ -167,8 +181,13 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
} else {
inputIndex += index;
}
return m_impl.coeff(inputIndex);
}
return inputIndex;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(
Index index) const {
return m_impl.coeff(reverseIndex(index));
}
template<int LoadMode>
@ -199,9 +218,57 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
ReverseDimensions m_reverse;
};
// Eval as lvalue
template <typename ReverseDimensions, typename ArgType, typename Device>
struct TensorEvaluator<TensorReverseOp<ReverseDimensions, ArgType>, Device>
: public TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>,
Device> {
typedef TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>,
Device> Base;
typedef TensorReverseOp<ReverseDimensions, ArgType> XprType;
typedef typename XprType::Index Index;
static const int NumDims = internal::array_size<ReverseDimensions>::value;
typedef DSizes<Index, NumDims> Dimensions;
enum {
IsAligned = false,
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
Layout = TensorEvaluator<ArgType, Device>::Layout,
CoordAccess = false, // to be implemented
};
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op,
const Device& device)
: Base(op, device) {}
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Dimensions& dimensions() const { return this->m_dimensions; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
return this->m_impl.coeffRef(this->reverseIndex(index));
}
template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void writePacket(Index index, const PacketReturnType& x) {
const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
eigen_assert(index+packetSize-1 < dimensions().TotalSize());
// This code is pilfered from TensorMorphing.h
EIGEN_ALIGN_DEFAULT CoeffReturnType values[packetSize];
internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
for (int i = 0; i < packetSize; ++i) {
this->coeffRef(index+i) = values[i];
}
}
};
} // end namespace Eigen
} // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_REVERSE_H

View File

@ -94,7 +94,7 @@ static void test_simple_reverse()
template <int DataLayout>
static void test_expr_reverse()
static void test_expr_reverse(bool LValue)
{
Tensor<float, 4, DataLayout> tensor(2,3,5,7);
tensor.setRandom();
@ -105,9 +105,12 @@ static void test_expr_reverse()
dim_rev[2] = false;
dim_rev[3] = true;
Tensor<float, 4, DataLayout> expected;
expected = tensor.reverse(dim_rev);
Tensor<float, 4, DataLayout> expected(2, 3, 5, 7);
if (LValue) {
expected.reverse(dim_rev) = tensor;
} else {
expected = tensor.reverse(dim_rev);
}
Tensor<float, 4, DataLayout> result(2,3,5,7);
@ -117,8 +120,13 @@ static void test_expr_reverse()
array<ptrdiff_t, 4> dst_slice_start{{0,0,0,0}};
for (int i = 0; i < 5; ++i) {
result.slice(dst_slice_start, dst_slice_dim) =
tensor.slice(src_slice_start, src_slice_dim).reverse(dim_rev);
if (LValue) {
result.slice(dst_slice_start, dst_slice_dim).reverse(dim_rev) =
tensor.slice(src_slice_start, src_slice_dim);
} else {
result.slice(dst_slice_start, dst_slice_dim) =
tensor.slice(src_slice_start, src_slice_dim).reverse(dim_rev);
}
src_slice_start[2] += 1;
dst_slice_start[2] += 1;
}
@ -141,8 +149,13 @@ static void test_expr_reverse()
dst_slice_start[2] = 0;
result.setRandom();
for (int i = 0; i < 5; ++i) {
result.slice(dst_slice_start, dst_slice_dim) =
tensor.reverse(dim_rev).slice(dst_slice_start, dst_slice_dim);
if (LValue) {
result.slice(dst_slice_start, dst_slice_dim).reverse(dim_rev) =
tensor.slice(dst_slice_start, dst_slice_dim);
} else {
result.slice(dst_slice_start, dst_slice_dim) =
tensor.reverse(dim_rev).slice(dst_slice_start, dst_slice_dim);
}
dst_slice_start[2] += 1;
}
@ -162,6 +175,8 @@ void test_cxx11_tensor_reverse()
{
CALL_SUBTEST(test_simple_reverse<ColMajor>());
CALL_SUBTEST(test_simple_reverse<RowMajor>());
CALL_SUBTEST(test_expr_reverse<ColMajor>());
CALL_SUBTEST(test_expr_reverse<RowMajor>());
CALL_SUBTEST(test_expr_reverse<ColMajor>(true));
CALL_SUBTEST(test_expr_reverse<RowMajor>(true));
CALL_SUBTEST(test_expr_reverse<ColMajor>(false));
CALL_SUBTEST(test_expr_reverse<RowMajor>(false));
}