From 96aeffb0131e6e0df44bd376b3db693392cdf9c7 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Wed, 17 Nov 2021 18:16:04 -0800 Subject: [PATCH] Make the new TensorIO implementation work with TensorMap with const elements. --- unsupported/Eigen/CXX11/src/Tensor/TensorIO.h | 4 ++-- unsupported/test/cxx11_tensor_io.cpp | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h b/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h index 3e95f95fa..b958f6abb 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h @@ -186,7 +186,7 @@ namespace internal { template struct TensorPrinter { static void run(std::ostream& s, const Tensor& _t, const TensorIOFormat& fmt) { - typedef typename Tensor::Scalar Scalar; + typedef typename internal::remove_const::type Scalar; typedef typename Tensor::Index Index; static const int layout = Tensor::Layout; // backwards compatibility case: print tensor after reshaping to matrix of size dim(0) x @@ -195,7 +195,7 @@ struct TensorPrinter { const Index total_size = internal::array_prod(_t.dimensions()); if (total_size > 0) { const Index first_dim = Eigen::internal::array_get<0>(_t.dimensions()); - Map > matrix(const_cast(_t.data()), first_dim, + Map > matrix(_t.data(), first_dim, total_size / first_dim); s << matrix; return; diff --git a/unsupported/test/cxx11_tensor_io.cpp b/unsupported/test/cxx11_tensor_io.cpp index 34fcbc3ff..c0ea2a101 100644 --- a/unsupported/test/cxx11_tensor_io.cpp +++ b/unsupported/test/cxx11_tensor_io.cpp @@ -90,6 +90,16 @@ void test_tensor_ostream() { test_tensor_ostream_impl::run(); } +void test_const_tensor_ostream() { + Eigen::Tensor t; + t.setValues(1); + const Eigen::TensorMap, Eigen::Unaligned> t_const( + t.data(), Eigen::DSizes{}); + std::ostringstream os; + os << t_const.format(Eigen::TensorIOFormat::Plain()); + VERIFY(os.str() == "1"); +} + EIGEN_DECLARE_TEST(cxx11_tensor_io) { CALL_SUBTEST((test_tensor_ostream())); CALL_SUBTEST((test_tensor_ostream())); @@ -126,4 +136,7 @@ EIGEN_DECLARE_TEST(cxx11_tensor_io) { CALL_SUBTEST((test_tensor_ostream, 2, Eigen::ColMajor>())); CALL_SUBTEST((test_tensor_ostream, 2, Eigen::ColMajor>())); + + // Test printing TensorMap with const elements. + CALL_SUBTEST((test_const_tensor_ostream())); }