diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h b/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h index 38a833f82..3db692ac6 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h @@ -17,34 +17,58 @@ template<> struct significant_decimals_impl : significant_decimals_default_impl {}; -} -template -std::ostream& operator << (std::ostream& os, const TensorBase& expr) { - // Evaluate the expression if needed - TensorForcedEvalOp eval = expr.eval(); - TensorEvaluator, DefaultDevice> tensor(eval, DefaultDevice()); - tensor.evalSubExprsIfNeeded(NULL); - - typedef typename internal::remove_const::type Scalar; - typedef typename T::Index Index; - typedef typename TensorEvaluator, DefaultDevice>::Dimensions Dimensions; - const Index total_size = internal::array_prod(tensor.dimensions()); - - // Print the tensor as a 1d vector or a 2d matrix. - static const int rank = internal::array_size::value; - if (rank == 0) { - os << tensor.coeff(0); - } else if (rank == 1) { - Map > array(const_cast(tensor.data()), total_size); - os << array; - } else { +// Print the tensor as a 2d matrix +template +struct TensorPrinter { + static void run (std::ostream& os, const Tensor& tensor) { + typedef typename internal::remove_const::type Scalar; + typedef typename Tensor::Index Index; + const Index total_size = internal::array_prod(tensor.dimensions()); const Index first_dim = Eigen::internal::array_get<0>(tensor.dimensions()); - static const int layout = TensorEvaluator, DefaultDevice>::Layout; + static const int layout = Tensor::Layout; Map > matrix(const_cast(tensor.data()), first_dim, total_size/first_dim); os << matrix; } +}; + + +// Print the tensor as a vector +template +struct TensorPrinter { + static void run (std::ostream& os, const Tensor& tensor) { + typedef typename internal::remove_const::type Scalar; + typedef typename Tensor::Index Index; + const Index total_size = internal::array_prod(tensor.dimensions()); + Map > array(const_cast(tensor.data()), total_size); + os << array; + } +}; + + +// Print the tensor as a scalar +template +struct TensorPrinter { + static void run (std::ostream& os, const Tensor& tensor) { + os << tensor.coeff(0); + } +}; +} + +template +std::ostream& operator << (std::ostream& os, const TensorBase& expr) { + typedef TensorEvaluator, DefaultDevice> Evaluator; + typedef typename Evaluator::Dimensions Dimensions; + + // Evaluate the expression if needed + TensorForcedEvalOp eval = expr.eval(); + Evaluator tensor(eval, DefaultDevice()); + tensor.evalSubExprsIfNeeded(NULL); + + // Print the result + static const int rank = internal::array_size::value; + internal::TensorPrinter::run(os, tensor); // Cleanup. tensor.cleanup(); diff --git a/unsupported/test/cxx11_tensor_io.cpp b/unsupported/test/cxx11_tensor_io.cpp index 8bbcf7089..8267dcadd 100644 --- a/unsupported/test/cxx11_tensor_io.cpp +++ b/unsupported/test/cxx11_tensor_io.cpp @@ -13,6 +13,20 @@ #include +template +static void test_output_0d() +{ + Tensor tensor; + tensor() = 123; + + std::stringstream os; + os << tensor; + + std::string expected("123"); + VERIFY_IS_EQUAL(std::string(os.str()), expected); +} + + template static void test_output_1d() { @@ -101,6 +115,8 @@ static void test_output_const() void test_cxx11_tensor_io() { + CALL_SUBTEST(test_output_0d()); + CALL_SUBTEST(test_output_0d()); CALL_SUBTEST(test_output_1d()); CALL_SUBTEST(test_output_1d()); CALL_SUBTEST(test_output_2d());