mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-12 14:25:16 +08:00
Reimplemented the Tensor stream output.
This commit is contained in:
parent
2b9297196c
commit
f73c95c032
@ -76,6 +76,8 @@
|
||||
#include "src/Tensor/TensorIntDiv.h"
|
||||
#include "src/Tensor/TensorGlobalFunctions.h"
|
||||
|
||||
#include "src/Tensor/TensorIO.h"
|
||||
|
||||
#include "src/Tensor/TensorBase.h"
|
||||
#include "src/Tensor/TensorBlock.h"
|
||||
|
||||
@ -129,7 +131,7 @@
|
||||
#include "src/Tensor/TensorMap.h"
|
||||
#include "src/Tensor/TensorRef.h"
|
||||
|
||||
#include "src/Tensor/TensorIO.h"
|
||||
|
||||
|
||||
#include "../../../Eigen/src/Core/util/ReenableStupidWarnings.h"
|
||||
|
||||
|
@ -1794,6 +1794,45 @@ but you can easily cast the tensors to floats to do the division:
|
||||
|
||||
TODO
|
||||
|
||||
## Tensor Printing
|
||||
Tensors can be printed into a stream object (e.g. `std::cout`) using different formatting options.
|
||||
|
||||
Eigen::Tensor<float, 3> tensor3d = {4, 3, 2};
|
||||
tensor3d.setValues( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}, {{13, 14}, {15, 16}, {17, 18}}, {{19, 20}, {21, 22}, {23, 24}}} );
|
||||
std::cout << tensor3d.format(Eigen::TensorIOFormat::Plain()) << std::endl;
|
||||
==>
|
||||
1 2
|
||||
3 4
|
||||
5 6
|
||||
|
||||
7 8
|
||||
9 10
|
||||
11 12
|
||||
|
||||
13 14
|
||||
15 16
|
||||
17 18
|
||||
|
||||
19 20
|
||||
21 22
|
||||
23 24
|
||||
|
||||
|
||||
In the example, we used the predefined format `Eigen::TensorIOFormat::Plain`.
|
||||
Here is the list of all predefined formats from which you can choose:
|
||||
- `Eigen::TensorIOFormat::Plain()` for a plain output without braces. Different submatrices are separated by a blank line.
|
||||
- `Eigen::TensorIOFormat::Numpy()` for numpy-like output.
|
||||
- `Eigen::TensorIOFormat::Native()` for a `c++` like output which can be directly copy-pasted to setValues().
|
||||
- `Eigen::TensorIOFormat::Legacy()` for a backwards compatible printing of tensors.
|
||||
|
||||
If you send the tensor directly to the stream the default format is called which is `Eigen::IOFormats::Plain()`.
|
||||
|
||||
You can define your own format by explicitly providing a `Eigen::TensorIOFormat` class instance. Here, you can specify:
|
||||
- The overall prefix and suffix with `std::string tenPrefix` and `std::string tenSuffix`
|
||||
- The prefix, separator and suffix for each new element, row, matrix, 3d subtensor, ... with `std::vector<std::string> prefix`, `std::vector<std::string> separator` and `std::vector<std::string> suffix`. Note that the first entry in each of the vectors refer to the last dimension of the tensor, e.g. `separator[0]` will be printed between adjacent elements, `separator[1]` will be printed between adjacent matrices, ...
|
||||
- `char fill`: character which will be placed if the elements are aligned.
|
||||
- `int precision`
|
||||
- `int flags`: an OR-ed combination of flags, the default value is 0, the only currently available flag is `Eigen::DontAlignCols` which allows to disable the alignment of columns, resulting in faster code.
|
||||
|
||||
## Representation of scalar values
|
||||
|
||||
|
@ -962,6 +962,11 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
||||
return TensorForcedEvalOp<const Derived>(derived());
|
||||
}
|
||||
|
||||
// Returns a formatted tensor ready for printing to a stream
|
||||
inline const TensorWithFormat<Derived,DerivedTraits::Layout,DerivedTraits::NumDimensions> format(const TensorIOFormat& fmt) const {
|
||||
return TensorWithFormat<Derived,DerivedTraits::Layout,DerivedTraits::NumDimensions>(derived(), fmt);
|
||||
}
|
||||
|
||||
#ifdef EIGEN_READONLY_TENSORBASE_PLUGIN
|
||||
#include EIGEN_READONLY_TENSORBASE_PLUGIN
|
||||
#endif
|
||||
|
@ -14,68 +14,361 @@
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
struct TensorIOFormat;
|
||||
|
||||
namespace internal {
|
||||
|
||||
// Print the tensor as a 2d matrix
|
||||
template <typename Tensor, int Rank>
|
||||
struct TensorPrinter {
|
||||
static void run (std::ostream& os, const Tensor& tensor) {
|
||||
typedef typename internal::remove_const<typename Tensor::Scalar>::type Scalar;
|
||||
typedef typename Tensor::Index Index;
|
||||
const Index total_size = internal::array_prod(tensor.dimensions());
|
||||
if (total_size > 0) {
|
||||
const Index first_dim = Eigen::internal::array_get<0>(tensor.dimensions());
|
||||
static const int layout = Tensor::Layout;
|
||||
Map<const Array<Scalar, Dynamic, Dynamic, layout> > matrix(const_cast<Scalar*>(tensor.data()), first_dim, total_size/first_dim);
|
||||
os << matrix;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Print the tensor as a vector
|
||||
template <typename Tensor>
|
||||
struct TensorPrinter<Tensor, 1> {
|
||||
static void run (std::ostream& os, const Tensor& tensor) {
|
||||
typedef typename internal::remove_const<typename Tensor::Scalar>::type Scalar;
|
||||
typedef typename Tensor::Index Index;
|
||||
const Index total_size = internal::array_prod(tensor.dimensions());
|
||||
if (total_size > 0) {
|
||||
Map<const Array<Scalar, Dynamic, 1> > array(const_cast<Scalar*>(tensor.data()), total_size);
|
||||
os << array;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Print the tensor as a scalar
|
||||
template <typename Tensor>
|
||||
struct TensorPrinter<Tensor, 0> {
|
||||
static void run (std::ostream& os, const Tensor& tensor) {
|
||||
os << tensor.coeff(0);
|
||||
}
|
||||
};
|
||||
template <typename Tensor, std::size_t rank>
|
||||
struct TensorPrinter;
|
||||
}
|
||||
|
||||
struct TensorIOFormat {
|
||||
TensorIOFormat(const std::vector<std::string>& _separator, const std::vector<std::string>& _prefix,
|
||||
const std::vector<std::string>& _suffix, int _precision = StreamPrecision, int _flags = 0,
|
||||
const std::string& _tenPrefix = "", const std::string& _tenSuffix = "", const char _fill = ' ')
|
||||
: tenPrefix(_tenPrefix),
|
||||
tenSuffix(_tenSuffix),
|
||||
prefix(_prefix),
|
||||
suffix(_suffix),
|
||||
separator(_separator),
|
||||
fill(_fill),
|
||||
precision(_precision),
|
||||
flags(_flags) {
|
||||
init_spacer();
|
||||
}
|
||||
|
||||
TensorIOFormat(int _precision = StreamPrecision, int _flags = 0, const std::string& _tenPrefix = "",
|
||||
const std::string& _tenSuffix = "", const char _fill = ' ')
|
||||
: tenPrefix(_tenPrefix), tenSuffix(_tenSuffix), fill(_fill), precision(_precision), flags(_flags) {
|
||||
// default values of prefix, suffix and separator
|
||||
prefix = {"", "["};
|
||||
suffix = {"", "]"};
|
||||
separator = {", ", "\n"};
|
||||
|
||||
init_spacer();
|
||||
}
|
||||
|
||||
void init_spacer() {
|
||||
if ((flags & DontAlignCols)) return;
|
||||
spacer.resize(prefix.size());
|
||||
spacer[0] = "";
|
||||
int i = int(tenPrefix.length()) - 1;
|
||||
while (i >= 0 && tenPrefix[i] != '\n') {
|
||||
spacer[0] += ' ';
|
||||
i--;
|
||||
}
|
||||
|
||||
for (std::size_t k = 1; k < prefix.size(); k++) {
|
||||
int i = int(prefix[k].length()) - 1;
|
||||
while (i >= 0 && prefix[k][i] != '\n') {
|
||||
spacer[k] += ' ';
|
||||
i--;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static inline const TensorIOFormat Numpy() {
|
||||
std::vector<std::string> prefix = {"", "["};
|
||||
std::vector<std::string> suffix = {"", "]"};
|
||||
std::vector<std::string> separator = {" ", "\n"};
|
||||
return TensorIOFormat(separator, prefix, suffix, StreamPrecision, 0, "[", "]");
|
||||
}
|
||||
|
||||
static inline const TensorIOFormat Plain() {
|
||||
std::vector<std::string> separator = {" ", "\n", "\n", ""};
|
||||
std::vector<std::string> prefix = {""};
|
||||
std::vector<std::string> suffix = {""};
|
||||
return TensorIOFormat(separator, prefix, suffix, StreamPrecision, 0, "", "", ' ');
|
||||
}
|
||||
|
||||
static inline const TensorIOFormat Native() {
|
||||
std::vector<std::string> separator = {", ", ",\n", "\n"};
|
||||
std::vector<std::string> prefix = {"", "{"};
|
||||
std::vector<std::string> suffix = {"", "}"};
|
||||
return TensorIOFormat(separator, prefix, suffix, StreamPrecision, 0, "{", "}", ' ');
|
||||
}
|
||||
|
||||
static inline const TensorIOFormat Legacy() {
|
||||
TensorIOFormat LegacyFormat(StreamPrecision, 0, "", "", ' ');
|
||||
LegacyFormat.legacy_bit = true;
|
||||
return LegacyFormat;
|
||||
}
|
||||
|
||||
std::string tenPrefix;
|
||||
std::string tenSuffix;
|
||||
std::vector<std::string> prefix;
|
||||
std::vector<std::string> suffix;
|
||||
std::vector<std::string> separator;
|
||||
char fill;
|
||||
int precision;
|
||||
int flags;
|
||||
std::vector<std::string> spacer{};
|
||||
bool legacy_bit = false;
|
||||
};
|
||||
|
||||
template <typename T, int Layout, int rank>
|
||||
class TensorWithFormat;
|
||||
// specialize for Layout=ColMajor, Layout=RowMajor and rank=0.
|
||||
template <typename T, int rank>
|
||||
class TensorWithFormat<T, RowMajor, rank> {
|
||||
public:
|
||||
TensorWithFormat(const T& tensor, const TensorIOFormat& format) : t_tensor(tensor), t_format(format) {}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, RowMajor, rank>& wf) {
|
||||
// Evaluate the expression if needed
|
||||
typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator;
|
||||
TensorForcedEvalOp<const T> eval = wf.t_tensor.eval();
|
||||
Evaluator tensor(eval, DefaultDevice());
|
||||
tensor.evalSubExprsIfNeeded(NULL);
|
||||
internal::TensorPrinter<Evaluator, rank>::run(os, tensor, wf.t_format);
|
||||
// Cleanup.
|
||||
tensor.cleanup();
|
||||
return os;
|
||||
}
|
||||
|
||||
protected:
|
||||
T t_tensor;
|
||||
TensorIOFormat t_format;
|
||||
};
|
||||
|
||||
template <typename T, int rank>
|
||||
class TensorWithFormat<T, ColMajor, rank> {
|
||||
public:
|
||||
TensorWithFormat(const T& tensor, const TensorIOFormat& format) : t_tensor(tensor), t_format(format) {}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, ColMajor, rank>& wf) {
|
||||
// Switch to RowMajor storage and print afterwards
|
||||
typedef typename T::Index Index;
|
||||
std::array<Index, rank> shuffle;
|
||||
std::array<Index, rank> id;
|
||||
std::iota(id.begin(), id.end(), Index(0));
|
||||
std::copy(id.begin(), id.end(), shuffle.rbegin());
|
||||
auto tensor_row_major = wf.t_tensor.swap_layout().shuffle(shuffle);
|
||||
|
||||
// Evaluate the expression if needed
|
||||
typedef TensorEvaluator<const TensorForcedEvalOp<const decltype(tensor_row_major)>, DefaultDevice> Evaluator;
|
||||
TensorForcedEvalOp<const decltype(tensor_row_major)> eval = tensor_row_major.eval();
|
||||
Evaluator tensor(eval, DefaultDevice());
|
||||
tensor.evalSubExprsIfNeeded(NULL);
|
||||
internal::TensorPrinter<Evaluator, rank>::run(os, tensor, wf.t_format);
|
||||
// Cleanup.
|
||||
tensor.cleanup();
|
||||
return os;
|
||||
}
|
||||
|
||||
protected:
|
||||
T t_tensor;
|
||||
TensorIOFormat t_format;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator << (std::ostream& os, const TensorBase<T, ReadOnlyAccessors>& expr) {
|
||||
typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator;
|
||||
typedef typename Evaluator::Dimensions Dimensions;
|
||||
class TensorWithFormat<T, ColMajor, 0> {
|
||||
public:
|
||||
TensorWithFormat(const T& tensor, const TensorIOFormat& format) : t_tensor(tensor), t_format(format) {}
|
||||
|
||||
// Evaluate the expression if needed
|
||||
TensorForcedEvalOp<const T> eval = expr.eval();
|
||||
Evaluator tensor(eval, DefaultDevice());
|
||||
tensor.evalSubExprsIfNeeded(NULL);
|
||||
friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, ColMajor, 0>& wf) {
|
||||
// Evaluate the expression if needed
|
||||
typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator;
|
||||
TensorForcedEvalOp<const T> eval = wf.t_tensor.eval();
|
||||
Evaluator tensor(eval, DefaultDevice());
|
||||
tensor.evalSubExprsIfNeeded(NULL);
|
||||
internal::TensorPrinter<Evaluator, 0>::run(os, tensor, wf.t_format);
|
||||
// Cleanup.
|
||||
tensor.cleanup();
|
||||
return os;
|
||||
}
|
||||
|
||||
// Print the result
|
||||
static const int rank = internal::array_size<Dimensions>::value;
|
||||
internal::TensorPrinter<Evaluator, rank>::run(os, tensor);
|
||||
protected:
|
||||
T t_tensor;
|
||||
TensorIOFormat t_format;
|
||||
};
|
||||
|
||||
// Cleanup.
|
||||
tensor.cleanup();
|
||||
return os;
|
||||
namespace internal {
|
||||
template <typename Tensor, std::size_t rank>
|
||||
struct TensorPrinter {
|
||||
static void run(std::ostream& s, const Tensor& _t, const TensorIOFormat& fmt) {
|
||||
typedef typename Tensor::Scalar Scalar;
|
||||
typedef typename Tensor::Index Index;
|
||||
static const int layout = Tensor::Layout;
|
||||
// backwards compatibility case: print tensor after reshaping to matrix of size dim(0) x
|
||||
// (dim(1)*dim(2)*...*dim(rank-1)).
|
||||
if (fmt.legacy_bit) {
|
||||
const Index total_size = internal::array_prod(_t.dimensions());
|
||||
if (total_size > 0) {
|
||||
const Index first_dim = Eigen::internal::array_get<0>(_t.dimensions());
|
||||
Map<const Array<Scalar, Dynamic, Dynamic, layout> > matrix(const_cast<Scalar*>(_t.data()), first_dim,
|
||||
total_size / first_dim);
|
||||
s << matrix;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
assert(layout == RowMajor);
|
||||
typedef typename conditional<is_same<Scalar, char>::value || is_same<Scalar, unsigned char>::value ||
|
||||
is_same<Scalar, numext::int8_t>::value || is_same<Scalar, numext::uint8_t>::value,
|
||||
int,
|
||||
typename conditional<is_same<Scalar, std::complex<char> >::value ||
|
||||
is_same<Scalar, std::complex<unsigned char> >::value ||
|
||||
is_same<Scalar, std::complex<numext::int8_t> >::value ||
|
||||
is_same<Scalar, std::complex<numext::uint8_t> >::value,
|
||||
std::complex<int>, const Scalar&>::type>::type PrintType;
|
||||
|
||||
const Index total_size = array_prod(_t.dimensions());
|
||||
|
||||
std::streamsize explicit_precision;
|
||||
if (fmt.precision == StreamPrecision) {
|
||||
explicit_precision = 0;
|
||||
} else if (fmt.precision == FullPrecision) {
|
||||
if (NumTraits<Scalar>::IsInteger) {
|
||||
explicit_precision = 0;
|
||||
} else {
|
||||
explicit_precision = significant_decimals_impl<Scalar>::run();
|
||||
}
|
||||
} else {
|
||||
explicit_precision = fmt.precision;
|
||||
}
|
||||
|
||||
std::streamsize old_precision = 0;
|
||||
if (explicit_precision) old_precision = s.precision(explicit_precision);
|
||||
|
||||
Index width = 0;
|
||||
|
||||
bool align_cols = !(fmt.flags & DontAlignCols);
|
||||
if (align_cols) {
|
||||
// compute the largest width
|
||||
for (Index i = 0; i < total_size; i++) {
|
||||
std::stringstream sstr;
|
||||
sstr.copyfmt(s);
|
||||
sstr << static_cast<PrintType>(_t.data()[i]);
|
||||
width = std::max<Index>(width, Index(sstr.str().length()));
|
||||
}
|
||||
}
|
||||
std::streamsize old_width = s.width();
|
||||
char old_fill_character = s.fill();
|
||||
|
||||
s << fmt.tenPrefix;
|
||||
for (Index i = 0; i < total_size; i++) {
|
||||
std::array<bool, rank> is_at_end{};
|
||||
std::array<bool, rank> is_at_begin{};
|
||||
|
||||
// is the ith element the end of an coeff (always true), of a row, of a matrix, ...?
|
||||
for (std::size_t k = 0; k < rank; k++) {
|
||||
if ((i + 1) % (std::accumulate(_t.dimensions().rbegin(), _t.dimensions().rbegin() + k, 1,
|
||||
std::multiplies<Index>())) ==
|
||||
0) {
|
||||
is_at_end[k] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// is the ith element the begin of an coeff (always true), of a row, of a matrix, ...?
|
||||
for (std::size_t k = 0; k < rank; k++) {
|
||||
if (i % (std::accumulate(_t.dimensions().rbegin(), _t.dimensions().rbegin() + k, 1,
|
||||
std::multiplies<Index>())) ==
|
||||
0) {
|
||||
is_at_begin[k] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// do we have a line break?
|
||||
bool is_at_begin_after_newline = false;
|
||||
for (std::size_t k = 0; k < rank; k++) {
|
||||
if (is_at_begin[k]) {
|
||||
std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
|
||||
if (fmt.separator[separator_index].find('\n') != std::string::npos) {
|
||||
is_at_begin_after_newline = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool is_at_end_before_newline = false;
|
||||
for (std::size_t k = 0; k < rank; k++) {
|
||||
if (is_at_end[k]) {
|
||||
std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
|
||||
if (fmt.separator[separator_index].find('\n') != std::string::npos) {
|
||||
is_at_end_before_newline = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::stringstream suffix, prefix, separator;
|
||||
for (std::size_t k = 0; k < rank; k++) {
|
||||
std::size_t suffix_index = (k < fmt.suffix.size()) ? k : fmt.suffix.size() - 1;
|
||||
if (is_at_end[k]) {
|
||||
suffix << fmt.suffix[suffix_index];
|
||||
}
|
||||
}
|
||||
for (std::size_t k = 0; k < rank; k++) {
|
||||
std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
|
||||
if (is_at_end[k] and
|
||||
(!is_at_end_before_newline or fmt.separator[separator_index].find('\n') != std::string::npos)) {
|
||||
separator << fmt.separator[separator_index];
|
||||
}
|
||||
}
|
||||
for (std::size_t k = 0; k < rank; k++) {
|
||||
std::size_t spacer_index = (k < fmt.spacer.size()) ? k : fmt.spacer.size() - 1;
|
||||
if (i != 0 and is_at_begin_after_newline and (!is_at_begin[k] or k == 0)) {
|
||||
prefix << fmt.spacer[spacer_index];
|
||||
}
|
||||
}
|
||||
for (int k = rank - 1; k >= 0; k--) {
|
||||
std::size_t prefix_index = (static_cast<std::size_t>(k) < fmt.prefix.size()) ? k : fmt.prefix.size() - 1;
|
||||
if (is_at_begin[k]) {
|
||||
prefix << fmt.prefix[prefix_index];
|
||||
}
|
||||
}
|
||||
|
||||
s << prefix.str();
|
||||
if (width) {
|
||||
s.fill(fmt.fill);
|
||||
s.width(width);
|
||||
s << std::right;
|
||||
}
|
||||
s << _t.data()[i];
|
||||
s << suffix.str();
|
||||
if (i < total_size - 1) {
|
||||
s << separator.str();
|
||||
}
|
||||
}
|
||||
s << fmt.tenSuffix;
|
||||
if (explicit_precision) s.precision(old_precision);
|
||||
if (width) {
|
||||
s.fill(old_fill_character);
|
||||
s.width(old_width);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tensor>
|
||||
struct TensorPrinter<Tensor, 0> {
|
||||
static void run(std::ostream& s, const Tensor& _t, const TensorIOFormat& fmt) {
|
||||
typedef typename Tensor::Scalar Scalar;
|
||||
|
||||
std::streamsize explicit_precision;
|
||||
if (fmt.precision == StreamPrecision) {
|
||||
explicit_precision = 0;
|
||||
} else if (fmt.precision == FullPrecision) {
|
||||
if (NumTraits<Scalar>::IsInteger) {
|
||||
explicit_precision = 0;
|
||||
} else {
|
||||
explicit_precision = significant_decimals_impl<Scalar>::run();
|
||||
}
|
||||
} else {
|
||||
explicit_precision = fmt.precision;
|
||||
}
|
||||
|
||||
std::streamsize old_precision = 0;
|
||||
if (explicit_precision) old_precision = s.precision(explicit_precision);
|
||||
|
||||
s << fmt.tenPrefix << _t.coeff(0) << fmt.tenSuffix;
|
||||
if (explicit_precision) s.precision(old_precision);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& s, const TensorBase<T, ReadOnlyAccessors>& t) {
|
||||
s << t.format(TensorIOFormat::Plain());
|
||||
return s;
|
||||
}
|
||||
} // end namespace Eigen
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_CXX11_TENSOR_TENSOR_IO_H
|
||||
#endif // EIGEN_CXX11_TENSOR_TENSOR_IO_H
|
||||
|
@ -6,131 +6,124 @@
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#include "main.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <Eigen/CXX11/Tensor>
|
||||
|
||||
template <typename Scalar, int rank, int Layout>
|
||||
struct test_tensor_ostream_impl {};
|
||||
|
||||
template<int DataLayout>
|
||||
static void test_output_0d()
|
||||
{
|
||||
Tensor<int, 0, DataLayout> tensor;
|
||||
tensor() = 123;
|
||||
|
||||
std::stringstream os;
|
||||
os << tensor;
|
||||
|
||||
std::string expected("123");
|
||||
VERIFY_IS_EQUAL(std::string(os.str()), expected);
|
||||
}
|
||||
|
||||
|
||||
template<int DataLayout>
|
||||
static void test_output_1d()
|
||||
{
|
||||
Tensor<int, 1, DataLayout> tensor(5);
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
tensor(i) = i;
|
||||
template<typename Scalar, int Layout>
|
||||
struct test_tensor_ostream_impl<Scalar, 0, Layout> {
|
||||
static void run() {
|
||||
Eigen::Tensor<Scalar, 0> t;
|
||||
t.setValues(1);
|
||||
std::ostringstream os;
|
||||
os << t.format(Eigen::TensorIOFormat::Plain());
|
||||
VERIFY(os.str() == "1");
|
||||
}
|
||||
};
|
||||
|
||||
std::stringstream os;
|
||||
os << tensor;
|
||||
|
||||
std::string expected("0\n1\n2\n3\n4");
|
||||
VERIFY_IS_EQUAL(std::string(os.str()), expected);
|
||||
|
||||
Eigen::Tensor<double,1,DataLayout> empty_tensor(0);
|
||||
std::stringstream empty_os;
|
||||
empty_os << empty_tensor;
|
||||
std::string empty_string;
|
||||
VERIFY_IS_EQUAL(std::string(empty_os.str()), empty_string);
|
||||
}
|
||||
|
||||
|
||||
template<int DataLayout>
|
||||
static void test_output_2d()
|
||||
{
|
||||
Tensor<int, 2, DataLayout> tensor(5, 3);
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
tensor(i, j) = i*j;
|
||||
}
|
||||
template<typename Scalar, int Layout>
|
||||
struct test_tensor_ostream_impl<Scalar, 1, Layout> {
|
||||
static void run() {
|
||||
Eigen::Tensor<Scalar, 1> t = {3};
|
||||
t.setValues({1, 2, 3});
|
||||
std::ostringstream os;
|
||||
os << t.format(Eigen::TensorIOFormat::Plain());
|
||||
VERIFY(os.str() == "1 2 3");
|
||||
}
|
||||
};
|
||||
|
||||
std::stringstream os;
|
||||
os << tensor;
|
||||
|
||||
std::string expected("0 0 0\n0 1 2\n0 2 4\n0 3 6\n0 4 8");
|
||||
VERIFY_IS_EQUAL(std::string(os.str()), expected);
|
||||
}
|
||||
|
||||
|
||||
template<int DataLayout>
|
||||
static void test_output_expr()
|
||||
{
|
||||
Tensor<int, 1, DataLayout> tensor1(5);
|
||||
Tensor<int, 1, DataLayout> tensor2(5);
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
tensor1(i) = i;
|
||||
tensor2(i) = 7;
|
||||
template<typename Scalar, int Layout>
|
||||
struct test_tensor_ostream_impl<Scalar, 2, Layout> {
|
||||
static void run() {
|
||||
Eigen::Tensor<Scalar, 2> t = {3, 2};
|
||||
t.setValues({{1, 2}, {3, 4}, {5, 6}});
|
||||
std::ostringstream os;
|
||||
os << t.format(Eigen::TensorIOFormat::Plain());
|
||||
VERIFY(os.str() == "1 2\n3 4\n5 6");
|
||||
}
|
||||
};
|
||||
|
||||
std::stringstream os;
|
||||
os << tensor1 + tensor2;
|
||||
|
||||
std::string expected(" 7\n 8\n 9\n10\n11");
|
||||
VERIFY_IS_EQUAL(std::string(os.str()), expected);
|
||||
}
|
||||
|
||||
|
||||
template<int DataLayout>
|
||||
static void test_output_string()
|
||||
{
|
||||
Tensor<std::string, 2, DataLayout> tensor(5, 3);
|
||||
tensor.setConstant(std::string("foo"));
|
||||
|
||||
std::cout << tensor << std::endl;
|
||||
|
||||
std::stringstream os;
|
||||
os << tensor;
|
||||
|
||||
std::string expected("foo foo foo\nfoo foo foo\nfoo foo foo\nfoo foo foo\nfoo foo foo");
|
||||
VERIFY_IS_EQUAL(std::string(os.str()), expected);
|
||||
}
|
||||
|
||||
|
||||
template<int DataLayout>
|
||||
static void test_output_const()
|
||||
{
|
||||
Tensor<int, 1, DataLayout> tensor(5);
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
tensor(i) = i;
|
||||
template<typename Scalar, int Layout>
|
||||
struct test_tensor_ostream_impl<Scalar, 3, Layout> {
|
||||
static void run() {
|
||||
Eigen::Tensor<Scalar, 3> t = {4, 3, 2};
|
||||
t.setValues({{{1, 2}, {3, 4}, {5, 6}},
|
||||
{{7, 8}, {9, 10}, {11, 12}},
|
||||
{{13, 14}, {15, 16}, {17, 18}},
|
||||
{{19, 20}, {21, 22}, {23, 24}}});
|
||||
std::ostringstream os;
|
||||
os << t.format(Eigen::TensorIOFormat::Plain());
|
||||
VERIFY(os.str() == " 1 2\n 3 4\n 5 6\n\n 7 8\n 9 10\n11 12\n\n13 14\n15 16\n17 18\n\n19 20\n21 22\n23 24");
|
||||
}
|
||||
};
|
||||
|
||||
TensorMap<Tensor<const int, 1, DataLayout> > tensor_map(tensor.data(), 5);
|
||||
template<int Layout>
|
||||
struct test_tensor_ostream_impl<bool, 2, Layout> {
|
||||
static void run() {
|
||||
Eigen::Tensor<bool, 2> t = {3, 2};
|
||||
t.setValues({{false, true}, {true, false}, {false, false}});
|
||||
std::ostringstream os;
|
||||
os << t.format(Eigen::TensorIOFormat::Plain());
|
||||
VERIFY(os.str() == "0 1\n1 0\n0 0");
|
||||
}
|
||||
};
|
||||
|
||||
std::stringstream os;
|
||||
os << tensor_map;
|
||||
template<typename Scalar, int Layout>
|
||||
struct test_tensor_ostream_impl<std::complex<Scalar>, 2, Layout> {
|
||||
static void run() {
|
||||
Eigen::Tensor<std::complex<Scalar>, 2> t = {3, 2};
|
||||
t.setValues({{std::complex<Scalar>(1, 2), std::complex<Scalar>(12, 3)},
|
||||
{std::complex<Scalar>(-4, 2), std::complex<Scalar>(0, 5)},
|
||||
{std::complex<Scalar>(-1, 4), std::complex<Scalar>(5, 27)}});
|
||||
std::ostringstream os;
|
||||
os << t.format(Eigen::TensorIOFormat::Plain());
|
||||
VERIFY(os.str() == " (1,2) (12,3)\n(-4,2) (0,5)\n(-1,4) (5,27)");
|
||||
}
|
||||
};
|
||||
|
||||
std::string expected("0\n1\n2\n3\n4");
|
||||
VERIFY_IS_EQUAL(std::string(os.str()), expected);
|
||||
template <typename Scalar, int rank, int Layout>
|
||||
void test_tensor_ostream() {
|
||||
test_tensor_ostream_impl<Scalar, rank, Layout>::run();
|
||||
}
|
||||
|
||||
EIGEN_DECLARE_TEST(cxx11_tensor_io) {
|
||||
CALL_SUBTEST((test_tensor_ostream<float, 0, Eigen::ColMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<float, 1, Eigen::ColMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<float, 2, Eigen::ColMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<float, 3, Eigen::ColMajor>()));
|
||||
|
||||
EIGEN_DECLARE_TEST(cxx11_tensor_io)
|
||||
{
|
||||
CALL_SUBTEST(test_output_0d<ColMajor>());
|
||||
CALL_SUBTEST(test_output_0d<RowMajor>());
|
||||
CALL_SUBTEST(test_output_1d<ColMajor>());
|
||||
CALL_SUBTEST(test_output_1d<RowMajor>());
|
||||
CALL_SUBTEST(test_output_2d<ColMajor>());
|
||||
CALL_SUBTEST(test_output_2d<RowMajor>());
|
||||
CALL_SUBTEST(test_output_expr<ColMajor>());
|
||||
CALL_SUBTEST(test_output_expr<RowMajor>());
|
||||
CALL_SUBTEST(test_output_string<ColMajor>());
|
||||
CALL_SUBTEST(test_output_string<RowMajor>());
|
||||
CALL_SUBTEST(test_output_const<ColMajor>());
|
||||
CALL_SUBTEST(test_output_const<RowMajor>());
|
||||
CALL_SUBTEST((test_tensor_ostream<double, 0, Eigen::ColMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<double, 1, Eigen::ColMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<double, 2, Eigen::ColMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<double, 3, Eigen::ColMajor>()));
|
||||
|
||||
CALL_SUBTEST((test_tensor_ostream<int, 0, Eigen::ColMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<int, 1, Eigen::ColMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<int, 2, Eigen::ColMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<int, 3, Eigen::ColMajor>()));
|
||||
|
||||
CALL_SUBTEST((test_tensor_ostream<float, 0, Eigen::RowMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<float, 1, Eigen::RowMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<float, 2, Eigen::RowMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<float, 3, Eigen::RowMajor>()));
|
||||
|
||||
CALL_SUBTEST((test_tensor_ostream<double, 0, Eigen::RowMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<double, 1, Eigen::RowMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<double, 2, Eigen::RowMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<double, 3, Eigen::RowMajor>()));
|
||||
|
||||
CALL_SUBTEST((test_tensor_ostream<int, 0, Eigen::RowMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<int, 1, Eigen::RowMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<int, 2, Eigen::RowMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<int, 3, Eigen::RowMajor>()));
|
||||
|
||||
CALL_SUBTEST((test_tensor_ostream<bool, 2, Eigen::ColMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<bool, 2, Eigen::RowMajor>()));
|
||||
|
||||
CALL_SUBTEST((test_tensor_ostream<std::complex<double>, 2, Eigen::ColMajor>()));
|
||||
CALL_SUBTEST((test_tensor_ostream<std::complex<float>, 2, Eigen::ColMajor>()));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user