Started to add support for tensors of rank 0

This commit is contained in:
Benoit Steiner 2015-10-26 14:29:26 -07:00
parent 1f4c98abb1
commit 1c8312c811
4 changed files with 57 additions and 3 deletions

View File

@ -140,6 +140,12 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
}
#endif
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff() const
{
EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
return m_storage.data()[0];
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(Index index) const
{
eigen_internal_assert(index >= 0 && index < size());
@ -174,6 +180,12 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
}
#endif
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef()
{
EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
return m_storage.data()[0];
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index)
{
eigen_internal_assert(index >= 0 && index < size());
@ -234,6 +246,12 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
return coeff(index);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()() const
{
EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
return coeff();
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator[](Index index) const
{
// The bracket operator is only for vectors, use the parenthesis operator instead.
@ -295,6 +313,12 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
return coeffRef(index);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()()
{
EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
return coeffRef();
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator[](Index index)
{
// The bracket operator is only for vectors, use the parenthesis operator instead
@ -433,6 +457,13 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
resize(dims);
}
EIGEN_DEVICE_FUNC
void resize()
{
EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
// Nothing to do: rank 0 tensors have fixed size
}
/** Custom Dimension */
#ifdef EIGEN_HAS_SFINAE
template<typename CustomDimension,

View File

@ -33,7 +33,10 @@ std::ostream& operator << (std::ostream& os, const TensorBase<T, ReadOnlyAccesso
const Index total_size = internal::array_prod(tensor.dimensions());
// Print the tensor as a 1d vector or a 2d matrix.
if (internal::array_size<Dimensions>::value == 1) {
static const int rank = internal::array_size<Dimensions>::value;
if (rank == 0) {
os << tensor.coeff(0);
} else if (rank == 1) {
Map<const Array<Scalar, Dynamic, 1> > array(const_cast<Scalar*>(tensor.data()), total_size);
os << array;
} else {

View File

@ -55,6 +55,18 @@ struct Initializer<Derived, 1> {
}
};
template <typename Derived>
struct Initializer<Derived, 0> {
typedef typename traits<Derived>::Scalar InitList;
static void run(TensorEvaluator<Derived, DefaultDevice>& tensor,
Eigen::array<typename traits<Derived>::Index, traits<Derived>::NumDimensions>*/* indices*/,
const InitList& v) {
tensor.coeffRef(0) = v;
}
};
template <typename Derived, int N>
void initialize_tensor(TensorEvaluator<Derived, DefaultDevice>& tensor,
const typename Initializer<Derived, traits<Derived>::NumDimensions>::InitList& vals) {

View File

@ -71,7 +71,11 @@ class TensorStorage<T, DSizes<IndexType, NumIndices_>, Options_>
typedef DSizes<IndexType, NumIndices_> Dimensions;
typedef TensorStorage<T, DSizes<IndexType, NumIndices_>, Options_> Self;
EIGEN_DEVICE_FUNC TensorStorage() : m_data(0), m_dimensions() {}
EIGEN_DEVICE_FUNC TensorStorage() : m_data(0), m_dimensions() {
if (NumIndices_ == 0) {
m_data = internal::conditional_aligned_new_auto<T,(Options_&DontAlign)==0>(1);
}
}
EIGEN_DEVICE_FUNC TensorStorage(internal::constructor_without_unaligned_array_assert)
: m_data(0), m_dimensions(internal::template repeat<NumIndices_, Index>(0)) {}
EIGEN_DEVICE_FUNC TensorStorage(Index size, const array<Index, NumIndices_>& dimensions)
@ -101,13 +105,17 @@ class TensorStorage<T, DSizes<IndexType, NumIndices_>, Options_>
EIGEN_DEVICE_FUNC void resize(Index size, const array<Index, NumIndices_>& nbDimensions)
{
eigen_assert(size >= 1);
const Index currentSz = internal::array_prod(m_dimensions);
if(size != currentSz)
{
internal::conditional_aligned_delete_auto<T,(Options_&DontAlign)==0>(m_data, currentSz);
if (size)
m_data = internal::conditional_aligned_new_auto<T,(Options_&DontAlign)==0>(size);
else
else if (NumIndices_ == 0) {
m_data = internal::conditional_aligned_new_auto<T,(Options_&DontAlign)==0>(1);
}
else
m_data = 0;
EIGEN_INTERNAL_DENSE_STORAGE_CTOR_PLUGIN
}