mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-31 19:00:35 +08:00
Fix TensorRef details
This commit is contained in:
parent
22cd7307dd
commit
9c211430b5
@ -27,6 +27,10 @@ struct traits<TensorBroadcastingOp<Broadcast, XprType>> : public traits<XprType>
|
||||
static constexpr int NumDimensions = XprTraits::NumDimensions;
|
||||
static constexpr int Layout = XprTraits::Layout;
|
||||
typedef typename XprTraits::PointerType PointerType;
|
||||
enum {
|
||||
// Broadcast is read-only.
|
||||
Flags = traits<XprType>::Flags & ~LvalueBit
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Broadcast, typename XprType>
|
||||
|
@ -52,7 +52,13 @@ class TensorLazyEvaluatorReadOnly
|
||||
typedef TensorEvaluator<Expr, Device> EvalType;
|
||||
|
||||
TensorLazyEvaluatorReadOnly(const Expr& expr, const Device& device) : m_impl(expr, device), m_dummy(Scalar(0)) {
|
||||
m_dims = m_impl.dimensions();
|
||||
EIGEN_STATIC_ASSERT(
|
||||
internal::array_size<Dimensions>::value == internal::array_size<typename EvalType::Dimensions>::value,
|
||||
"Dimension sizes must match.");
|
||||
const auto& other_dims = m_impl.dimensions();
|
||||
for (std::size_t i = 0; i < m_dims.size(); ++i) {
|
||||
m_dims[i] = other_dims[i];
|
||||
}
|
||||
m_impl.evalSubExprsIfNeeded(NULL);
|
||||
}
|
||||
virtual ~TensorLazyEvaluatorReadOnly() { m_impl.cleanup(); }
|
||||
@ -86,14 +92,12 @@ class TensorLazyEvaluatorWritable : public TensorLazyEvaluatorReadOnly<Dimension
|
||||
EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) { return this->m_impl.coeffRef(index); }
|
||||
};
|
||||
|
||||
template <typename Dimensions, typename Expr, typename Device>
|
||||
class TensorLazyEvaluator : public std::conditional_t<bool(internal::is_lvalue<Expr>::value),
|
||||
TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
|
||||
TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> > {
|
||||
template <typename Dimensions, typename Expr, typename Device, bool IsWritable>
|
||||
class TensorLazyEvaluator : public std::conditional_t<IsWritable, TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
|
||||
TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device>> {
|
||||
public:
|
||||
typedef std::conditional_t<bool(internal::is_lvalue<Expr>::value),
|
||||
TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
|
||||
TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >
|
||||
typedef std::conditional_t<IsWritable, TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
|
||||
TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device>>
|
||||
Base;
|
||||
typedef typename Base::Scalar Scalar;
|
||||
|
||||
@ -101,24 +105,15 @@ class TensorLazyEvaluator : public std::conditional_t<bool(internal::is_lvalue<E
|
||||
virtual ~TensorLazyEvaluator() {}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
/** \class TensorRef
|
||||
* \ingroup CXX11_Tensor_Module
|
||||
*
|
||||
* \brief A reference to a tensor expression
|
||||
* The expression will be evaluated lazily (as much as possible).
|
||||
*
|
||||
*/
|
||||
template <typename PlainObjectType>
|
||||
class TensorRef : public TensorBase<TensorRef<PlainObjectType> > {
|
||||
template <typename Derived>
|
||||
class TensorRefBase : public TensorBase<Derived> {
|
||||
public:
|
||||
typedef TensorRef<PlainObjectType> Self;
|
||||
typedef typename traits<Derived>::PlainObjectType PlainObjectType;
|
||||
typedef typename PlainObjectType::Base Base;
|
||||
typedef typename Eigen::internal::nested<Self>::type Nested;
|
||||
typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind;
|
||||
typedef typename internal::traits<PlainObjectType>::Index Index;
|
||||
typedef typename internal::traits<PlainObjectType>::Scalar Scalar;
|
||||
typedef typename Eigen::internal::nested<Derived>::type Nested;
|
||||
typedef typename traits<PlainObjectType>::StorageKind StorageKind;
|
||||
typedef typename traits<PlainObjectType>::Index Index;
|
||||
typedef typename traits<PlainObjectType>::Scalar Scalar;
|
||||
typedef typename NumTraits<Scalar>::Real RealScalar;
|
||||
typedef typename Base::CoeffReturnType CoeffReturnType;
|
||||
typedef Scalar* PointerType;
|
||||
@ -138,33 +133,17 @@ class TensorRef : public TensorBase<TensorRef<PlainObjectType> > {
|
||||
};
|
||||
|
||||
//===- Tensor block evaluation strategy (see TensorBlock.h) -----------===//
|
||||
typedef internal::TensorBlockNotImplemented TensorBlock;
|
||||
typedef TensorBlockNotImplemented TensorBlock;
|
||||
//===------------------------------------------------------------------===//
|
||||
|
||||
EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) {}
|
||||
EIGEN_STRONG_INLINE TensorRefBase() : m_evaluator(NULL) {}
|
||||
|
||||
template <typename Expression>
|
||||
EIGEN_STRONG_INLINE TensorRef(const Expression& expr)
|
||||
: m_evaluator(new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice())) {
|
||||
m_evaluator->incrRefCount();
|
||||
}
|
||||
|
||||
template <typename Expression>
|
||||
EIGEN_STRONG_INLINE TensorRef& operator=(const Expression& expr) {
|
||||
unrefEvaluator();
|
||||
m_evaluator = new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice());
|
||||
m_evaluator->incrRefCount();
|
||||
return *this;
|
||||
}
|
||||
|
||||
~TensorRef() { unrefEvaluator(); }
|
||||
|
||||
TensorRef(const TensorRef& other) : TensorBase<TensorRef<PlainObjectType> >(other), m_evaluator(other.m_evaluator) {
|
||||
TensorRefBase(const TensorRefBase& other) : TensorBase<Derived>(other), m_evaluator(other.m_evaluator) {
|
||||
eigen_assert(m_evaluator->refCount() > 0);
|
||||
m_evaluator->incrRefCount();
|
||||
}
|
||||
|
||||
TensorRef& operator=(const TensorRef& other) {
|
||||
TensorRefBase& operator=(const TensorRefBase& other) {
|
||||
if (this != &other) {
|
||||
unrefEvaluator();
|
||||
m_evaluator = other.m_evaluator;
|
||||
@ -174,6 +153,28 @@ class TensorRef : public TensorBase<TensorRef<PlainObjectType> > {
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename Expression,
|
||||
typename EnableIf = std::enable_if_t<!std::is_same<std::decay_t<Expression>, Derived>::value>>
|
||||
EIGEN_STRONG_INLINE TensorRefBase(const Expression& expr)
|
||||
: m_evaluator(new TensorLazyEvaluator<Dimensions, Expression, DefaultDevice,
|
||||
/*IsWritable=*/!std::is_const<PlainObjectType>::value &&
|
||||
bool(is_lvalue<Expression>::value)>(expr, DefaultDevice())) {
|
||||
m_evaluator->incrRefCount();
|
||||
}
|
||||
|
||||
template <typename Expression,
|
||||
typename EnableIf = std::enable_if_t<!std::is_same<std::decay_t<Expression>, Derived>::value>>
|
||||
EIGEN_STRONG_INLINE TensorRefBase& operator=(const Expression& expr) {
|
||||
unrefEvaluator();
|
||||
m_evaluator = new TensorLazyEvaluator < Dimensions, Expression, DefaultDevice,
|
||||
/*IsWritable=*/!std::is_const<PlainObjectType>::value&& bool(is_lvalue<Expression>::value) >
|
||||
(expr, DefaultDevice());
|
||||
m_evaluator->incrRefCount();
|
||||
return *this;
|
||||
}
|
||||
|
||||
~TensorRefBase() { unrefEvaluator(); }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); }
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; }
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); }
|
||||
@ -188,12 +189,6 @@ class TensorRef : public TensorBase<TensorRef<PlainObjectType> > {
|
||||
const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
|
||||
return coeff(indices);
|
||||
}
|
||||
template <typename... IndexTypes>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices) {
|
||||
const std::size_t num_indices = (sizeof...(otherIndices) + 1);
|
||||
const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
|
||||
return coeffRef(indices);
|
||||
}
|
||||
|
||||
template <std::size_t NumIndices>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(const array<Index, NumIndices>& indices) const {
|
||||
@ -212,6 +207,70 @@ class TensorRef : public TensorBase<TensorRef<PlainObjectType> > {
|
||||
}
|
||||
return m_evaluator->coeff(index);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const { return m_evaluator->coeff(index); }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return m_evaluator->coeffRef(index); }
|
||||
|
||||
protected:
|
||||
TensorLazyBaseEvaluator<Dimensions, Scalar>* evaluator() { return m_evaluator; }
|
||||
|
||||
private:
|
||||
EIGEN_STRONG_INLINE void unrefEvaluator() {
|
||||
if (m_evaluator) {
|
||||
m_evaluator->decrRefCount();
|
||||
if (m_evaluator->refCount() == 0) {
|
||||
delete m_evaluator;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
/**
|
||||
* \ingroup CXX11_Tensor_Module
|
||||
*
|
||||
* \brief A reference to a tensor expression
|
||||
* The expression will be evaluated lazily (as much as possible).
|
||||
*
|
||||
*/
|
||||
template <typename PlainObjectType>
|
||||
class TensorRef : public internal::TensorRefBase<TensorRef<PlainObjectType>> {
|
||||
typedef internal::TensorRefBase<TensorRef<PlainObjectType>> Base;
|
||||
|
||||
public:
|
||||
using Scalar = typename Base::Scalar;
|
||||
using Dimensions = typename Base::Dimensions;
|
||||
|
||||
EIGEN_STRONG_INLINE TensorRef() : Base() {}
|
||||
|
||||
template <typename Expression>
|
||||
EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : Base(expr) {
|
||||
EIGEN_STATIC_ASSERT(internal::is_lvalue<Expression>::value,
|
||||
"Expression must be mutable to create a mutable TensorRef<Expression>. Did you mean "
|
||||
"TensorRef<const Expression>?)");
|
||||
}
|
||||
|
||||
template <typename Expression>
|
||||
EIGEN_STRONG_INLINE TensorRef& operator=(const Expression& expr) {
|
||||
EIGEN_STATIC_ASSERT(internal::is_lvalue<Expression>::value,
|
||||
"Expression must be mutable to create a mutable TensorRef<Expression>. Did you mean "
|
||||
"TensorRef<const Expression>?)");
|
||||
return Base::operator=(expr).derived();
|
||||
}
|
||||
|
||||
TensorRef& operator=(const TensorRef& other) { return Base::operator=(other).derived(); }
|
||||
|
||||
template <typename... IndexTypes>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices) {
|
||||
const std::size_t num_indices = (sizeof...(otherIndices) + 1);
|
||||
const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
|
||||
return coeffRef(indices);
|
||||
}
|
||||
|
||||
template <std::size_t NumIndices>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices) {
|
||||
const Dimensions& dims = this->dimensions();
|
||||
@ -227,24 +286,37 @@ class TensorRef : public TensorBase<TensorRef<PlainObjectType> > {
|
||||
index = index * dims[i] + indices[i];
|
||||
}
|
||||
}
|
||||
return m_evaluator->coeffRef(index);
|
||||
return Base::evaluator()->coeffRef(index);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const { return m_evaluator->coeff(index); }
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return Base::evaluator()->coeffRef(index); }
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return m_evaluator->coeffRef(index); }
|
||||
/**
|
||||
* \ingroup CXX11_Tensor_Module
|
||||
*
|
||||
* \brief A reference to a constant tensor expression
|
||||
* The expression will be evaluated lazily (as much as possible).
|
||||
*
|
||||
*/
|
||||
template <typename PlainObjectType>
|
||||
class TensorRef<const PlainObjectType> : public internal::TensorRefBase<TensorRef<const PlainObjectType>> {
|
||||
typedef internal::TensorRefBase<TensorRef<const PlainObjectType>> Base;
|
||||
|
||||
private:
|
||||
EIGEN_STRONG_INLINE void unrefEvaluator() {
|
||||
if (m_evaluator) {
|
||||
m_evaluator->decrRefCount();
|
||||
if (m_evaluator->refCount() == 0) {
|
||||
delete m_evaluator;
|
||||
}
|
||||
}
|
||||
public:
|
||||
EIGEN_STRONG_INLINE TensorRef() : Base() {}
|
||||
|
||||
template <typename Expression>
|
||||
EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : Base(expr) {}
|
||||
|
||||
template <typename Expression>
|
||||
EIGEN_STRONG_INLINE TensorRef& operator=(const Expression& expr) {
|
||||
return Base::operator=(expr).derived();
|
||||
}
|
||||
|
||||
internal::TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator;
|
||||
TensorRef(const TensorRef& other) : Base(other) {}
|
||||
|
||||
TensorRef& operator=(const TensorRef& other) { return Base::operator=(other).derived(); }
|
||||
};
|
||||
|
||||
// evaluator for rvalues
|
||||
|
@ -95,8 +95,9 @@ struct traits<TensorMap<PlainObjectType, Options_, MakePointer_> > : public trai
|
||||
typedef typename MakePointer<Scalar>::Type PointerType;
|
||||
};
|
||||
|
||||
template <typename PlainObjectType>
|
||||
struct traits<TensorRef<PlainObjectType> > : public traits<PlainObjectType> {
|
||||
template <typename PlainObjectType_>
|
||||
struct traits<TensorRef<PlainObjectType_> > : public traits<PlainObjectType_> {
|
||||
typedef PlainObjectType_ PlainObjectType;
|
||||
typedef traits<PlainObjectType> BaseTraits;
|
||||
typedef typename BaseTraits::Scalar Scalar;
|
||||
typedef typename BaseTraits::StorageKind StorageKind;
|
||||
|
@ -138,7 +138,7 @@ static void test_const_slice() {
|
||||
TensorMap<Tensor<const T, 1>> m(b, 1);
|
||||
DSizes<DenseIndex, 1> offsets;
|
||||
offsets[0] = 0;
|
||||
TensorRef<Tensor<const T, 1>> slice_ref(m.slice(offsets, m.dimensions()));
|
||||
TensorRef<const Tensor<const T, 1>> slice_ref(m.slice(offsets, m.dimensions()));
|
||||
VERIFY_IS_EQUAL(slice_ref(0), 42);
|
||||
}
|
||||
|
||||
|
@ -49,8 +49,8 @@ static void test_simple_rvalue_ref() {
|
||||
Tensor<int, 1> input2(6);
|
||||
input2.setRandom();
|
||||
|
||||
TensorRef<Tensor<int, 1>> ref3(input1 + input2);
|
||||
TensorRef<Tensor<int, 1>> ref4 = input1 + input2;
|
||||
TensorRef<const Tensor<int, 1>> ref3(input1 + input2);
|
||||
TensorRef<const Tensor<int, 1>> ref4 = input1 + input2;
|
||||
|
||||
VERIFY_IS_NOT_EQUAL(ref3.data(), input1.data());
|
||||
VERIFY_IS_NOT_EQUAL(ref4.data(), input1.data());
|
||||
@ -144,7 +144,7 @@ static void test_ref_in_expr() {
|
||||
|
||||
Tensor<float, 3> result(3, 5, 7);
|
||||
result.setRandom();
|
||||
TensorRef<Tensor<float, 3>> result_ref(result);
|
||||
TensorRef<const Tensor<float, 3>> result_ref(result);
|
||||
|
||||
Tensor<float, 3> bias(3, 5, 7);
|
||||
bias.setRandom();
|
||||
@ -192,7 +192,7 @@ static void test_nested_ops_with_ref() {
|
||||
paddings[2] = std::make_pair(3, 4);
|
||||
paddings[3] = std::make_pair(0, 0);
|
||||
DSizes<Eigen::DenseIndex, 4> shuffle_dims(0, 1, 2, 3);
|
||||
TensorRef<Tensor<const float, 4>> ref(m.pad(paddings));
|
||||
TensorRef<const Tensor<const float, 4>> ref(m.pad(paddings));
|
||||
array<std::pair<ptrdiff_t, ptrdiff_t>, 4> trivial;
|
||||
trivial[0] = std::make_pair(0, 0);
|
||||
trivial[1] = std::make_pair(0, 0);
|
||||
|
Loading…
x
Reference in New Issue
Block a user