mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-24 19:40:45 +08:00
Allow Tensor trace to be passed to a TensorRef.
This commit is contained in:
parent
8e32cbf7da
commit
6579e36eb4
@ -27,6 +27,10 @@ struct traits<TensorTraceOp<Dims, XprType> > : public traits<XprType> {
|
|||||||
typedef std::remove_reference_t<Nested> Nested_;
|
typedef std::remove_reference_t<Nested> Nested_;
|
||||||
static constexpr int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
|
static constexpr int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
|
||||||
static constexpr int Layout = XprTraits::Layout;
|
static constexpr int Layout = XprTraits::Layout;
|
||||||
|
enum {
|
||||||
|
// Trace is read-only.
|
||||||
|
Flags = traits<XprType>::Flags & ~LvalueBit
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename Dims, typename XprType>
|
template <typename Dims, typename XprType>
|
||||||
@ -203,6 +207,8 @@ struct TensorEvaluator<const TensorTraceOp<Dims, ArgType>, Device> {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return nullptr; }
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
|
EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
|
||||||
|
@ -41,6 +41,9 @@ static void test_simple_lvalue_ref() {
|
|||||||
for (int i = 0; i < 6; ++i) {
|
for (int i = 0; i < 6; ++i) {
|
||||||
VERIFY_IS_EQUAL(input(i), -i * 2);
|
VERIFY_IS_EQUAL(input(i), -i * 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TensorRef<const Tensor<int, 1>> ref5(input.trace());
|
||||||
|
VERIFY_IS_EQUAL(ref5[0], input[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void test_simple_rvalue_ref() {
|
static void test_simple_rvalue_ref() {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user