mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-18 14:34:17 +08:00
Can now use the tensor 'reverse' operation as a lvalue
This commit is contained in:
parent
2fffe69b1b
commit
57154fdb32
@ -549,6 +549,11 @@ class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyA
|
||||
chip(const Index offset, const Index dim) const {
|
||||
return TensorChippingOp<Dynamic, Derived>(derived(), offset, dim);
|
||||
}
|
||||
template <typename ReverseDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
TensorReverseOp<const ReverseDimensions, Derived>
|
||||
reverse(const ReverseDimensions& rev) const {
|
||||
return TensorReverseOp<const ReverseDimensions, Derived>(derived(), rev);
|
||||
}
|
||||
template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
TensorShufflingOp<const Shuffle, Derived>
|
||||
shuffle(const Shuffle& shuffle) const {
|
||||
|
@ -49,12 +49,9 @@ struct nested<TensorReverseOp<ReverseDimensions, XprType>, 1,
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
|
||||
|
||||
|
||||
template<typename ReverseDimensions, typename XprType>
|
||||
class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions,
|
||||
XprType>, ReadOnlyAccessors>
|
||||
XprType>, WriteAccessors>
|
||||
{
|
||||
public:
|
||||
typedef typename Eigen::internal::traits<TensorReverseOp>::Scalar Scalar;
|
||||
@ -67,8 +64,8 @@ class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions,
|
||||
StorageKind;
|
||||
typedef typename Eigen::internal::traits<TensorReverseOp>::Index Index;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorReverseOp(const XprType& expr,
|
||||
const ReverseDimensions& reverse_dims)
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorReverseOp(
|
||||
const XprType& expr, const ReverseDimensions& reverse_dims)
|
||||
: m_xpr(expr), m_reverse_dims(reverse_dims) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
@ -78,12 +75,30 @@ class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions,
|
||||
const typename internal::remove_all<typename XprType::Nested>::type&
|
||||
expression() const { return m_xpr; }
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE TensorReverseOp& operator = (const TensorReverseOp& other)
|
||||
{
|
||||
typedef TensorAssignOp<TensorReverseOp, const TensorReverseOp> Assign;
|
||||
Assign assign(*this, other);
|
||||
internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<typename OtherDerived>
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE TensorReverseOp& operator = (const OtherDerived& other)
|
||||
{
|
||||
typedef TensorAssignOp<TensorReverseOp, const OtherDerived> Assign;
|
||||
Assign assign(*this, other);
|
||||
internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
|
||||
return *this;
|
||||
}
|
||||
|
||||
protected:
|
||||
typename XprType::Nested m_xpr;
|
||||
const ReverseDimensions m_reverse_dims;
|
||||
};
|
||||
|
||||
|
||||
// Eval as rvalue
|
||||
template<typename ReverseDimensions, typename ArgType, typename Device>
|
||||
struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device>
|
||||
@ -134,8 +149,8 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
|
||||
m_impl.cleanup();
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
|
||||
{
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index reverseIndex(
|
||||
Index index) const {
|
||||
eigen_assert(index < dimensions().TotalSize());
|
||||
Index inputIndex = 0;
|
||||
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
|
||||
@ -152,7 +167,6 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
|
||||
} else {
|
||||
inputIndex += index;
|
||||
}
|
||||
return m_impl.coeff(inputIndex);
|
||||
} else {
|
||||
for (int i = 0; i < NumDims - 1; ++i) {
|
||||
Index idx = index / m_strides[i];
|
||||
@ -167,8 +181,13 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
|
||||
} else {
|
||||
inputIndex += index;
|
||||
}
|
||||
return m_impl.coeff(inputIndex);
|
||||
}
|
||||
return inputIndex;
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(
|
||||
Index index) const {
|
||||
return m_impl.coeff(reverseIndex(index));
|
||||
}
|
||||
|
||||
template<int LoadMode>
|
||||
@ -199,9 +218,57 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
|
||||
ReverseDimensions m_reverse;
|
||||
};
|
||||
|
||||
// Eval as lvalue
|
||||
|
||||
template <typename ReverseDimensions, typename ArgType, typename Device>
|
||||
struct TensorEvaluator<TensorReverseOp<ReverseDimensions, ArgType>, Device>
|
||||
: public TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>,
|
||||
Device> {
|
||||
typedef TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>,
|
||||
Device> Base;
|
||||
typedef TensorReverseOp<ReverseDimensions, ArgType> XprType;
|
||||
typedef typename XprType::Index Index;
|
||||
static const int NumDims = internal::array_size<ReverseDimensions>::value;
|
||||
typedef DSizes<Index, NumDims> Dimensions;
|
||||
|
||||
enum {
|
||||
IsAligned = false,
|
||||
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
|
||||
Layout = TensorEvaluator<ArgType, Device>::Layout,
|
||||
CoordAccess = false, // to be implemented
|
||||
};
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op,
|
||||
const Device& device)
|
||||
: Base(op, device) {}
|
||||
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
const Dimensions& dimensions() const { return this->m_dimensions; }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
|
||||
return this->m_impl.coeffRef(this->reverseIndex(index));
|
||||
}
|
||||
|
||||
template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
void writePacket(Index index, const PacketReturnType& x) {
|
||||
const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
|
||||
EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
|
||||
eigen_assert(index+packetSize-1 < dimensions().TotalSize());
|
||||
|
||||
// This code is pilfered from TensorMorphing.h
|
||||
EIGEN_ALIGN_DEFAULT CoeffReturnType values[packetSize];
|
||||
internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
|
||||
for (int i = 0; i < packetSize; ++i) {
|
||||
this->coeffRef(index+i) = values[i];
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
} // end namespace Eigen
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_CXX11_TENSOR_TENSOR_REVERSE_H
|
||||
|
@ -94,7 +94,7 @@ static void test_simple_reverse()
|
||||
|
||||
|
||||
template <int DataLayout>
|
||||
static void test_expr_reverse()
|
||||
static void test_expr_reverse(bool LValue)
|
||||
{
|
||||
Tensor<float, 4, DataLayout> tensor(2,3,5,7);
|
||||
tensor.setRandom();
|
||||
@ -105,9 +105,12 @@ static void test_expr_reverse()
|
||||
dim_rev[2] = false;
|
||||
dim_rev[3] = true;
|
||||
|
||||
|
||||
Tensor<float, 4, DataLayout> expected;
|
||||
expected = tensor.reverse(dim_rev);
|
||||
Tensor<float, 4, DataLayout> expected(2, 3, 5, 7);
|
||||
if (LValue) {
|
||||
expected.reverse(dim_rev) = tensor;
|
||||
} else {
|
||||
expected = tensor.reverse(dim_rev);
|
||||
}
|
||||
|
||||
Tensor<float, 4, DataLayout> result(2,3,5,7);
|
||||
|
||||
@ -117,8 +120,13 @@ static void test_expr_reverse()
|
||||
array<ptrdiff_t, 4> dst_slice_start{{0,0,0,0}};
|
||||
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
result.slice(dst_slice_start, dst_slice_dim) =
|
||||
tensor.slice(src_slice_start, src_slice_dim).reverse(dim_rev);
|
||||
if (LValue) {
|
||||
result.slice(dst_slice_start, dst_slice_dim).reverse(dim_rev) =
|
||||
tensor.slice(src_slice_start, src_slice_dim);
|
||||
} else {
|
||||
result.slice(dst_slice_start, dst_slice_dim) =
|
||||
tensor.slice(src_slice_start, src_slice_dim).reverse(dim_rev);
|
||||
}
|
||||
src_slice_start[2] += 1;
|
||||
dst_slice_start[2] += 1;
|
||||
}
|
||||
@ -141,8 +149,13 @@ static void test_expr_reverse()
|
||||
dst_slice_start[2] = 0;
|
||||
result.setRandom();
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
result.slice(dst_slice_start, dst_slice_dim) =
|
||||
tensor.reverse(dim_rev).slice(dst_slice_start, dst_slice_dim);
|
||||
if (LValue) {
|
||||
result.slice(dst_slice_start, dst_slice_dim).reverse(dim_rev) =
|
||||
tensor.slice(dst_slice_start, dst_slice_dim);
|
||||
} else {
|
||||
result.slice(dst_slice_start, dst_slice_dim) =
|
||||
tensor.reverse(dim_rev).slice(dst_slice_start, dst_slice_dim);
|
||||
}
|
||||
dst_slice_start[2] += 1;
|
||||
}
|
||||
|
||||
@ -162,6 +175,8 @@ void test_cxx11_tensor_reverse()
|
||||
{
|
||||
CALL_SUBTEST(test_simple_reverse<ColMajor>());
|
||||
CALL_SUBTEST(test_simple_reverse<RowMajor>());
|
||||
CALL_SUBTEST(test_expr_reverse<ColMajor>());
|
||||
CALL_SUBTEST(test_expr_reverse<RowMajor>());
|
||||
CALL_SUBTEST(test_expr_reverse<ColMajor>(true));
|
||||
CALL_SUBTEST(test_expr_reverse<RowMajor>(true));
|
||||
CALL_SUBTEST(test_expr_reverse<ColMajor>(false));
|
||||
CALL_SUBTEST(test_expr_reverse<RowMajor>(false));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user