mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-02-23 18:20:47 +08:00
Added support for evaluation of tensor shuffling operations as lvalues
This commit is contained in:
parent
f50548e86a
commit
d43f737b4a
@ -222,19 +222,19 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
||||
return TensorSlicingOp<const StartIndices, const Sizes, const Derived>(derived(), startIndices, sizes);
|
||||
}
|
||||
template <typename PaddingDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
TensorPaddingOp<const PaddingDimensions, Derived>
|
||||
const TensorPaddingOp<const PaddingDimensions, const Derived>
|
||||
pad(const PaddingDimensions& padding) const {
|
||||
return TensorPaddingOp<const PaddingDimensions, Derived>(derived(), padding);
|
||||
return TensorPaddingOp<const PaddingDimensions, const Derived>(derived(), padding);
|
||||
}
|
||||
template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
TensorShufflingOp<const Shuffle, Derived>
|
||||
const TensorShufflingOp<const Shuffle, const Derived>
|
||||
shuffle(const Shuffle& shuffle) const {
|
||||
return TensorShufflingOp<const Shuffle, Derived>(derived(), shuffle);
|
||||
return TensorShufflingOp<const Shuffle, const Derived>(derived(), shuffle);
|
||||
}
|
||||
template <typename Strides> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
TensorStridingOp<const Strides, Derived>
|
||||
const TensorStridingOp<const Strides, const Derived>
|
||||
stride(const Strides& strides) const {
|
||||
return TensorStridingOp<const Strides, Derived>(derived(), strides);
|
||||
return TensorStridingOp<const Strides, const Derived>(derived(), strides);
|
||||
}
|
||||
|
||||
// Force the evaluation of the expression.
|
||||
@ -244,6 +244,7 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
||||
}
|
||||
|
||||
protected:
|
||||
template <typename Scalar, std::size_t NumIndices, int Options> friend class Tensor;
|
||||
template <typename OtherDerived, int AccessLevel> friend class TensorBase;
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast<const Derived*>(this); }
|
||||
@ -258,6 +259,7 @@ class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyA
|
||||
typedef Scalar CoeffReturnType;
|
||||
typedef typename internal::packet_traits<Scalar>::type PacketReturnType;
|
||||
|
||||
template <typename Scalar, std::size_t NumIndices, int Options> friend class Tensor;
|
||||
template <typename OtherDerived, int AccessLevel> friend class TensorBase;
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
@ -293,6 +295,11 @@ class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyA
|
||||
slice(const StartIndices& startIndices, const Sizes& sizes) const {
|
||||
return TensorSlicingOp<const StartIndices, const Sizes, Derived>(derived(), startIndices, sizes);
|
||||
}
|
||||
template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
TensorShufflingOp<const Shuffle, Derived>
|
||||
shuffle(const Shuffle& shuffle) const {
|
||||
return TensorShufflingOp<const Shuffle, Derived>(derived(), shuffle);
|
||||
}
|
||||
|
||||
// Select the device on which to evaluate the expression.
|
||||
template <typename DeviceType>
|
||||
|
@ -48,7 +48,7 @@ struct nested<TensorShufflingOp<Shuffle, XprType>, 1, typename eval<TensorShuffl
|
||||
|
||||
|
||||
template<typename Shuffle, typename XprType>
|
||||
class TensorShufflingOp : public TensorBase<TensorShufflingOp<Shuffle, XprType>, WriteAccessors>
|
||||
class TensorShufflingOp : public TensorBase<TensorShufflingOp<Shuffle, XprType> >
|
||||
{
|
||||
public:
|
||||
typedef typename Eigen::internal::traits<TensorShufflingOp>::Scalar Scalar;
|
||||
@ -94,33 +94,38 @@ struct TensorEvaluator<const TensorShufflingOp<Shuffle, ArgType>, Device>
|
||||
typedef typename XprType::Index Index;
|
||||
static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
|
||||
typedef DSizes<Index, NumDims> Dimensions;
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
|
||||
enum {
|
||||
IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/false,
|
||||
PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/false,
|
||||
IsAligned = true,
|
||||
PacketAccess = (internal::packet_traits<Scalar>::size > 1),
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
|
||||
: m_impl(op.expression(), device), m_shuffle(op.shuffle())
|
||||
: m_impl(op.expression(), device)
|
||||
{
|
||||
const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
|
||||
const Shuffle& shuffle = op.shuffle();
|
||||
for (int i = 0; i < NumDims; ++i) {
|
||||
m_dimensions[i] = input_dims[m_shuffle[i]];
|
||||
m_dimensions[i] = input_dims[shuffle[i]];
|
||||
}
|
||||
|
||||
array<Index, NumDims> inputStrides;
|
||||
|
||||
for (int i = 0; i < NumDims; ++i) {
|
||||
if (i > 0) {
|
||||
m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
|
||||
inputStrides[i] = inputStrides[i-1] * input_dims[i-1];
|
||||
m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
|
||||
} else {
|
||||
m_inputStrides[0] = 1;
|
||||
inputStrides[0] = 1;
|
||||
m_outputStrides[0] = 1;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < NumDims; ++i) {
|
||||
m_inputStrides[i] = inputStrides[shuffle[i]];
|
||||
}
|
||||
}
|
||||
|
||||
// typedef typename XprType::Index Index;
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||
|
||||
@ -136,33 +141,90 @@ struct TensorEvaluator<const TensorShufflingOp<Shuffle, ArgType>, Device>
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
|
||||
{
|
||||
Index inputIndex = 0;
|
||||
for (int i = NumDims - 1; i > 0; --i) {
|
||||
const Index idx = index / m_outputStrides[i];
|
||||
inputIndex += idx * m_inputStrides[m_shuffle[i]];
|
||||
index -= idx * m_outputStrides[i];
|
||||
}
|
||||
inputIndex += index * m_inputStrides[m_shuffle[0]];
|
||||
return m_impl.coeff(inputIndex);
|
||||
return m_impl.coeff(srcCoeff(index));
|
||||
}
|
||||
|
||||
/* template<int LoadMode>
|
||||
template<int LoadMode>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
|
||||
{
|
||||
return m_impl.template packet<LoadMode>(index);
|
||||
}*/
|
||||
const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
|
||||
EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
|
||||
eigen_assert(index+packetSize-1 < dimensions().TotalSize());
|
||||
|
||||
EIGEN_ALIGN_DEFAULT typename internal::remove_const<CoeffReturnType>::type values[packetSize];
|
||||
for (int i = 0; i < packetSize; ++i) {
|
||||
values[i] = coeff(index+i);
|
||||
}
|
||||
PacketReturnType rslt = internal::pload<PacketReturnType>(values);
|
||||
return rslt;
|
||||
}
|
||||
|
||||
Scalar* data() const { return NULL; }
|
||||
|
||||
protected:
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index srcCoeff(Index index) const
|
||||
{
|
||||
Index inputIndex = 0;
|
||||
for (int i = NumDims - 1; i > 0; --i) {
|
||||
const Index idx = index / m_outputStrides[i];
|
||||
inputIndex += idx * m_inputStrides[i];
|
||||
index -= idx * m_outputStrides[i];
|
||||
}
|
||||
return inputIndex + index * m_inputStrides[0];
|
||||
}
|
||||
|
||||
Dimensions m_dimensions;
|
||||
Shuffle m_shuffle;
|
||||
array<Index, NumDims> m_outputStrides;
|
||||
array<Index, NumDims> m_inputStrides;
|
||||
TensorEvaluator<ArgType, Device> m_impl;
|
||||
};
|
||||
|
||||
|
||||
// Eval as lvalue
|
||||
template<typename Shuffle, typename ArgType, typename Device>
|
||||
struct TensorEvaluator<TensorShufflingOp<Shuffle, ArgType>, Device>
|
||||
: public TensorEvaluator<const TensorShufflingOp<Shuffle, ArgType>, Device>
|
||||
{
|
||||
typedef TensorEvaluator<const TensorShufflingOp<Shuffle, ArgType>, Device> Base;
|
||||
|
||||
typedef TensorShufflingOp<Shuffle, ArgType> XprType;
|
||||
typedef typename XprType::Index Index;
|
||||
static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
|
||||
typedef DSizes<Index, NumDims> Dimensions;
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
|
||||
enum {
|
||||
IsAligned = true,
|
||||
PacketAccess = (internal::packet_traits<Scalar>::size > 1),
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
|
||||
: Base(op, device)
|
||||
{ }
|
||||
|
||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index)
|
||||
{
|
||||
return this->m_impl.coeffRef(this->srcCoeff(index));
|
||||
}
|
||||
|
||||
template <int StoreMode> EIGEN_STRONG_INLINE
|
||||
void writePacket(Index index, const PacketReturnType& x)
|
||||
{
|
||||
static const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
|
||||
EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
|
||||
|
||||
EIGEN_ALIGN_DEFAULT typename internal::remove_const<CoeffReturnType>::type values[packetSize];
|
||||
internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
|
||||
for (int i = 0; i < packetSize; ++i) {
|
||||
this->coeffRef(index+i) = values[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_CXX11_TENSOR_TENSOR_SHUFFLING_H
|
||||
|
Loading…
Reference in New Issue
Block a user