mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-24 19:40:45 +08:00
Allow Tensor trace to be passed to a TensorRef.
This commit is contained in:
parent
8e32cbf7da
commit
6579e36eb4
unsupported
@ -27,6 +27,10 @@ struct traits<TensorTraceOp<Dims, XprType> > : public traits<XprType> {
|
||||
typedef std::remove_reference_t<Nested> Nested_;
|
||||
static constexpr int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
|
||||
static constexpr int Layout = XprTraits::Layout;
|
||||
enum {
|
||||
// Trace is read-only.
|
||||
Flags = traits<XprType>::Flags & ~LvalueBit
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Dims, typename XprType>
|
||||
@ -203,6 +207,8 @@ struct TensorEvaluator<const TensorTraceOp<Dims, ArgType>, Device> {
|
||||
return true;
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return nullptr; }
|
||||
|
||||
EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
|
||||
|
@ -41,6 +41,9 @@ static void test_simple_lvalue_ref() {
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
VERIFY_IS_EQUAL(input(i), -i * 2);
|
||||
}
|
||||
|
||||
TensorRef<const Tensor<int, 1>> ref5(input.trace());
|
||||
VERIFY_IS_EQUAL(ref5[0], input[0]);
|
||||
}
|
||||
|
||||
static void test_simple_rvalue_ref() {
|
||||
|
Loading…
x
Reference in New Issue
Block a user