Fix TensorRef details

This commit is contained in:
Antonio Sánchez 2025-02-14 18:33:26 +00:00
parent 22cd7307dd
commit 9c211430b5
5 changed files with 146 additions and 69 deletions

View File

@ -27,6 +27,10 @@ struct traits<TensorBroadcastingOp<Broadcast, XprType>> : public traits<XprType>
static constexpr int NumDimensions = XprTraits::NumDimensions;
static constexpr int Layout = XprTraits::Layout;
typedef typename XprTraits::PointerType PointerType;
enum {
// Broadcast is read-only.
Flags = traits<XprType>::Flags & ~LvalueBit
};
};
template <typename Broadcast, typename XprType>

View File

@ -52,7 +52,13 @@ class TensorLazyEvaluatorReadOnly
typedef TensorEvaluator<Expr, Device> EvalType;
TensorLazyEvaluatorReadOnly(const Expr& expr, const Device& device) : m_impl(expr, device), m_dummy(Scalar(0)) {
m_dims = m_impl.dimensions();
EIGEN_STATIC_ASSERT(
internal::array_size<Dimensions>::value == internal::array_size<typename EvalType::Dimensions>::value,
"Dimension sizes must match.");
const auto& other_dims = m_impl.dimensions();
for (std::size_t i = 0; i < m_dims.size(); ++i) {
m_dims[i] = other_dims[i];
}
m_impl.evalSubExprsIfNeeded(NULL);
}
virtual ~TensorLazyEvaluatorReadOnly() { m_impl.cleanup(); }
@ -86,14 +92,12 @@ class TensorLazyEvaluatorWritable : public TensorLazyEvaluatorReadOnly<Dimension
EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) { return this->m_impl.coeffRef(index); }
};
template <typename Dimensions, typename Expr, typename Device>
class TensorLazyEvaluator : public std::conditional_t<bool(internal::is_lvalue<Expr>::value),
TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> > {
template <typename Dimensions, typename Expr, typename Device, bool IsWritable>
class TensorLazyEvaluator : public std::conditional_t<IsWritable, TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device>> {
public:
typedef std::conditional_t<bool(internal::is_lvalue<Expr>::value),
TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >
typedef std::conditional_t<IsWritable, TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device>>
Base;
typedef typename Base::Scalar Scalar;
@ -101,24 +105,15 @@ class TensorLazyEvaluator : public std::conditional_t<bool(internal::is_lvalue<E
virtual ~TensorLazyEvaluator() {}
};
} // namespace internal
/** \class TensorRef
* \ingroup CXX11_Tensor_Module
*
* \brief A reference to a tensor expression
* The expression will be evaluated lazily (as much as possible).
*
*/
template <typename PlainObjectType>
class TensorRef : public TensorBase<TensorRef<PlainObjectType> > {
template <typename Derived>
class TensorRefBase : public TensorBase<Derived> {
public:
typedef TensorRef<PlainObjectType> Self;
typedef typename traits<Derived>::PlainObjectType PlainObjectType;
typedef typename PlainObjectType::Base Base;
typedef typename Eigen::internal::nested<Self>::type Nested;
typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind;
typedef typename internal::traits<PlainObjectType>::Index Index;
typedef typename internal::traits<PlainObjectType>::Scalar Scalar;
typedef typename Eigen::internal::nested<Derived>::type Nested;
typedef typename traits<PlainObjectType>::StorageKind StorageKind;
typedef typename traits<PlainObjectType>::Index Index;
typedef typename traits<PlainObjectType>::Scalar Scalar;
typedef typename NumTraits<Scalar>::Real RealScalar;
typedef typename Base::CoeffReturnType CoeffReturnType;
typedef Scalar* PointerType;
@ -138,33 +133,17 @@ class TensorRef : public TensorBase<TensorRef<PlainObjectType> > {
};
//===- Tensor block evaluation strategy (see TensorBlock.h) -----------===//
typedef internal::TensorBlockNotImplemented TensorBlock;
typedef TensorBlockNotImplemented TensorBlock;
//===------------------------------------------------------------------===//
EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) {}
EIGEN_STRONG_INLINE TensorRefBase() : m_evaluator(NULL) {}
template <typename Expression>
EIGEN_STRONG_INLINE TensorRef(const Expression& expr)
: m_evaluator(new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice())) {
m_evaluator->incrRefCount();
}
template <typename Expression>
EIGEN_STRONG_INLINE TensorRef& operator=(const Expression& expr) {
unrefEvaluator();
m_evaluator = new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice());
m_evaluator->incrRefCount();
return *this;
}
~TensorRef() { unrefEvaluator(); }
TensorRef(const TensorRef& other) : TensorBase<TensorRef<PlainObjectType> >(other), m_evaluator(other.m_evaluator) {
TensorRefBase(const TensorRefBase& other) : TensorBase<Derived>(other), m_evaluator(other.m_evaluator) {
eigen_assert(m_evaluator->refCount() > 0);
m_evaluator->incrRefCount();
}
TensorRef& operator=(const TensorRef& other) {
TensorRefBase& operator=(const TensorRefBase& other) {
if (this != &other) {
unrefEvaluator();
m_evaluator = other.m_evaluator;
@ -174,6 +153,28 @@ class TensorRef : public TensorBase<TensorRef<PlainObjectType> > {
return *this;
}
template <typename Expression,
typename EnableIf = std::enable_if_t<!std::is_same<std::decay_t<Expression>, Derived>::value>>
EIGEN_STRONG_INLINE TensorRefBase(const Expression& expr)
: m_evaluator(new TensorLazyEvaluator<Dimensions, Expression, DefaultDevice,
/*IsWritable=*/!std::is_const<PlainObjectType>::value &&
bool(is_lvalue<Expression>::value)>(expr, DefaultDevice())) {
m_evaluator->incrRefCount();
}
template <typename Expression,
typename EnableIf = std::enable_if_t<!std::is_same<std::decay_t<Expression>, Derived>::value>>
EIGEN_STRONG_INLINE TensorRefBase& operator=(const Expression& expr) {
unrefEvaluator();
m_evaluator = new TensorLazyEvaluator < Dimensions, Expression, DefaultDevice,
/*IsWritable=*/!std::is_const<PlainObjectType>::value&& bool(is_lvalue<Expression>::value) >
(expr, DefaultDevice());
m_evaluator->incrRefCount();
return *this;
}
~TensorRefBase() { unrefEvaluator(); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); }
@ -188,12 +189,6 @@ class TensorRef : public TensorBase<TensorRef<PlainObjectType> > {
const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
return coeff(indices);
}
template <typename... IndexTypes>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices) {
const std::size_t num_indices = (sizeof...(otherIndices) + 1);
const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
return coeffRef(indices);
}
template <std::size_t NumIndices>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(const array<Index, NumIndices>& indices) const {
@ -212,6 +207,70 @@ class TensorRef : public TensorBase<TensorRef<PlainObjectType> > {
}
return m_evaluator->coeff(index);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const { return m_evaluator->coeff(index); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return m_evaluator->coeffRef(index); }
protected:
TensorLazyBaseEvaluator<Dimensions, Scalar>* evaluator() { return m_evaluator; }
private:
EIGEN_STRONG_INLINE void unrefEvaluator() {
if (m_evaluator) {
m_evaluator->decrRefCount();
if (m_evaluator->refCount() == 0) {
delete m_evaluator;
}
}
}
TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator;
};
} // namespace internal
/**
* \ingroup CXX11_Tensor_Module
*
* \brief A reference to a tensor expression
* The expression will be evaluated lazily (as much as possible).
*
*/
template <typename PlainObjectType>
class TensorRef : public internal::TensorRefBase<TensorRef<PlainObjectType>> {
typedef internal::TensorRefBase<TensorRef<PlainObjectType>> Base;
public:
using Scalar = typename Base::Scalar;
using Dimensions = typename Base::Dimensions;
EIGEN_STRONG_INLINE TensorRef() : Base() {}
template <typename Expression>
EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : Base(expr) {
EIGEN_STATIC_ASSERT(internal::is_lvalue<Expression>::value,
"Expression must be mutable to create a mutable TensorRef<Expression>. Did you mean "
"TensorRef<const Expression>?)");
}
template <typename Expression>
EIGEN_STRONG_INLINE TensorRef& operator=(const Expression& expr) {
EIGEN_STATIC_ASSERT(internal::is_lvalue<Expression>::value,
"Expression must be mutable to create a mutable TensorRef<Expression>. Did you mean "
"TensorRef<const Expression>?)");
return Base::operator=(expr).derived();
}
TensorRef& operator=(const TensorRef& other) { return Base::operator=(other).derived(); }
template <typename... IndexTypes>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices) {
const std::size_t num_indices = (sizeof...(otherIndices) + 1);
const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
return coeffRef(indices);
}
template <std::size_t NumIndices>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices) {
const Dimensions& dims = this->dimensions();
@ -227,24 +286,37 @@ class TensorRef : public TensorBase<TensorRef<PlainObjectType> > {
index = index * dims[i] + indices[i];
}
}
return m_evaluator->coeffRef(index);
return Base::evaluator()->coeffRef(index);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const { return m_evaluator->coeff(index); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return Base::evaluator()->coeffRef(index); }
};
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return m_evaluator->coeffRef(index); }
/**
* \ingroup CXX11_Tensor_Module
*
* \brief A reference to a constant tensor expression
* The expression will be evaluated lazily (as much as possible).
*
*/
template <typename PlainObjectType>
class TensorRef<const PlainObjectType> : public internal::TensorRefBase<TensorRef<const PlainObjectType>> {
typedef internal::TensorRefBase<TensorRef<const PlainObjectType>> Base;
private:
EIGEN_STRONG_INLINE void unrefEvaluator() {
if (m_evaluator) {
m_evaluator->decrRefCount();
if (m_evaluator->refCount() == 0) {
delete m_evaluator;
}
}
public:
EIGEN_STRONG_INLINE TensorRef() : Base() {}
template <typename Expression>
EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : Base(expr) {}
template <typename Expression>
EIGEN_STRONG_INLINE TensorRef& operator=(const Expression& expr) {
return Base::operator=(expr).derived();
}
internal::TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator;
TensorRef(const TensorRef& other) : Base(other) {}
TensorRef& operator=(const TensorRef& other) { return Base::operator=(other).derived(); }
};
// evaluator for rvalues

View File

@ -95,8 +95,9 @@ struct traits<TensorMap<PlainObjectType, Options_, MakePointer_> > : public trai
typedef typename MakePointer<Scalar>::Type PointerType;
};
template <typename PlainObjectType>
struct traits<TensorRef<PlainObjectType> > : public traits<PlainObjectType> {
template <typename PlainObjectType_>
struct traits<TensorRef<PlainObjectType_> > : public traits<PlainObjectType_> {
typedef PlainObjectType_ PlainObjectType;
typedef traits<PlainObjectType> BaseTraits;
typedef typename BaseTraits::Scalar Scalar;
typedef typename BaseTraits::StorageKind StorageKind;

View File

@ -138,7 +138,7 @@ static void test_const_slice() {
TensorMap<Tensor<const T, 1>> m(b, 1);
DSizes<DenseIndex, 1> offsets;
offsets[0] = 0;
TensorRef<Tensor<const T, 1>> slice_ref(m.slice(offsets, m.dimensions()));
TensorRef<const Tensor<const T, 1>> slice_ref(m.slice(offsets, m.dimensions()));
VERIFY_IS_EQUAL(slice_ref(0), 42);
}

View File

@ -49,8 +49,8 @@ static void test_simple_rvalue_ref() {
Tensor<int, 1> input2(6);
input2.setRandom();
TensorRef<Tensor<int, 1>> ref3(input1 + input2);
TensorRef<Tensor<int, 1>> ref4 = input1 + input2;
TensorRef<const Tensor<int, 1>> ref3(input1 + input2);
TensorRef<const Tensor<int, 1>> ref4 = input1 + input2;
VERIFY_IS_NOT_EQUAL(ref3.data(), input1.data());
VERIFY_IS_NOT_EQUAL(ref4.data(), input1.data());
@ -144,7 +144,7 @@ static void test_ref_in_expr() {
Tensor<float, 3> result(3, 5, 7);
result.setRandom();
TensorRef<Tensor<float, 3>> result_ref(result);
TensorRef<const Tensor<float, 3>> result_ref(result);
Tensor<float, 3> bias(3, 5, 7);
bias.setRandom();
@ -192,7 +192,7 @@ static void test_nested_ops_with_ref() {
paddings[2] = std::make_pair(3, 4);
paddings[3] = std::make_pair(0, 0);
DSizes<Eigen::DenseIndex, 4> shuffle_dims(0, 1, 2, 3);
TensorRef<Tensor<const float, 4>> ref(m.pad(paddings));
TensorRef<const Tensor<const float, 4>> ref(m.pad(paddings));
array<std::pair<ptrdiff_t, ptrdiff_t>, 4> trivial;
trivial[0] = std::make_pair(0, 0);
trivial[1] = std::make_pair(0, 0);