Make the new TensorIO implementation work with TensorMap with const elements.

This commit is contained in:
Rasmus Munk Larsen 2021-11-17 18:16:04 -08:00
parent 824d06eb36
commit 96aeffb013
2 changed files with 15 additions and 2 deletions

View File

@ -186,7 +186,7 @@ namespace internal {
template <typename Tensor, std::size_t rank> template <typename Tensor, std::size_t rank>
struct TensorPrinter { struct TensorPrinter {
static void run(std::ostream& s, const Tensor& _t, const TensorIOFormat& fmt) { static void run(std::ostream& s, const Tensor& _t, const TensorIOFormat& fmt) {
typedef typename Tensor::Scalar Scalar; typedef typename internal::remove_const<typename Tensor::Scalar>::type Scalar;
typedef typename Tensor::Index Index; typedef typename Tensor::Index Index;
static const int layout = Tensor::Layout; static const int layout = Tensor::Layout;
// backwards compatibility case: print tensor after reshaping to matrix of size dim(0) x // backwards compatibility case: print tensor after reshaping to matrix of size dim(0) x
@ -195,7 +195,7 @@ struct TensorPrinter {
const Index total_size = internal::array_prod(_t.dimensions()); const Index total_size = internal::array_prod(_t.dimensions());
if (total_size > 0) { if (total_size > 0) {
const Index first_dim = Eigen::internal::array_get<0>(_t.dimensions()); const Index first_dim = Eigen::internal::array_get<0>(_t.dimensions());
Map<const Array<Scalar, Dynamic, Dynamic, layout> > matrix(const_cast<Scalar*>(_t.data()), first_dim, Map<const Array<Scalar, Dynamic, Dynamic, layout> > matrix(_t.data(), first_dim,
total_size / first_dim); total_size / first_dim);
s << matrix; s << matrix;
return; return;

View File

@ -90,6 +90,16 @@ void test_tensor_ostream() {
test_tensor_ostream_impl<Scalar, rank, Layout>::run(); test_tensor_ostream_impl<Scalar, rank, Layout>::run();
} }
void test_const_tensor_ostream() {
Eigen::Tensor<float, 0> t;
t.setValues(1);
const Eigen::TensorMap<Eigen::Tensor<const float, 0, Eigen::RowMajor>, Eigen::Unaligned> t_const(
t.data(), Eigen::DSizes<Eigen::DenseIndex, 0>{});
std::ostringstream os;
os << t_const.format(Eigen::TensorIOFormat::Plain());
VERIFY(os.str() == "1");
}
EIGEN_DECLARE_TEST(cxx11_tensor_io) { EIGEN_DECLARE_TEST(cxx11_tensor_io) {
CALL_SUBTEST((test_tensor_ostream<float, 0, Eigen::ColMajor>())); CALL_SUBTEST((test_tensor_ostream<float, 0, Eigen::ColMajor>()));
CALL_SUBTEST((test_tensor_ostream<float, 1, Eigen::ColMajor>())); CALL_SUBTEST((test_tensor_ostream<float, 1, Eigen::ColMajor>()));
@ -126,4 +136,7 @@ EIGEN_DECLARE_TEST(cxx11_tensor_io) {
CALL_SUBTEST((test_tensor_ostream<std::complex<double>, 2, Eigen::ColMajor>())); CALL_SUBTEST((test_tensor_ostream<std::complex<double>, 2, Eigen::ColMajor>()));
CALL_SUBTEST((test_tensor_ostream<std::complex<float>, 2, Eigen::ColMajor>())); CALL_SUBTEST((test_tensor_ostream<std::complex<float>, 2, Eigen::ColMajor>()));
// Test printing TensorMap with const elements.
CALL_SUBTEST((test_const_tensor_ostream()));
} }