mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-27 07:29:52 +08:00
Improved support for fixed size tensors
This commit is contained in:
parent
670c71d906
commit
f0ce85b757
@ -375,6 +375,28 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
|
|||||||
resize(dims);
|
resize(dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifndef EIGEN_EMULATE_CXX11_META_H
|
||||||
|
template <typename std::ptrdiff_t... Indices>
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
void resize(const Sizes<Indices...>& dimensions) {
|
||||||
|
array<Index, NumIndices> dims;
|
||||||
|
for (int i = 0; i < NumIndices; ++i) {
|
||||||
|
dims[i] = dimensions[i];
|
||||||
|
}
|
||||||
|
resize(dims);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5>
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
void resize(const Sizes<V1, V2, V3, V4, V5>& dimensions) {
|
||||||
|
array<Index, NumIndices> dims;
|
||||||
|
for (int i = 0; i < NumIndices; ++i) {
|
||||||
|
dims[i] = dimensions[i];
|
||||||
|
}
|
||||||
|
resize(dims);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
||||||
bool checkIndexRange(const array<Index, NumIndices>& indices) const
|
bool checkIndexRange(const array<Index, NumIndices>& indices) const
|
||||||
|
@ -69,6 +69,31 @@ struct fixed_size_tensor_index_linearization_helper<Index, NumIndices, 0, RowMaj
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<typename Index, std::size_t n>
|
||||||
|
struct fixed_size_tensor_index_extraction_helper
|
||||||
|
{
|
||||||
|
template <typename Dimensions> EIGEN_DEVICE_FUNC
|
||||||
|
static inline Index run(const Index index,
|
||||||
|
const Dimensions& dimensions)
|
||||||
|
{
|
||||||
|
const Index mult = (index == n) ? 1 : 0;
|
||||||
|
return array_get<n>(dimensions) * mult +
|
||||||
|
fixed_size_tensor_index_extraction_helper<Index, n - 1>::run(index, dimensions);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Index>
|
||||||
|
struct fixed_size_tensor_index_extraction_helper<Index, 0>
|
||||||
|
{
|
||||||
|
template <typename Dimensions> EIGEN_DEVICE_FUNC
|
||||||
|
static inline Index run(const Index index,
|
||||||
|
const Dimensions& dimensions)
|
||||||
|
{
|
||||||
|
const Index mult = (index == 0) ? 1 : 0;
|
||||||
|
return array_get<0>(dimensions) * mult;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
|
|
||||||
@ -99,6 +124,10 @@ struct Sizes : internal::numeric_list<std::ptrdiff_t, Indices...> {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t operator[] (const int index) const {
|
||||||
|
return internal::fixed_size_tensor_index_extraction_helper<std::ptrdiff_t, Base::count - 1>::run(index, *this);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T> Sizes& operator = (const T& /*other*/) {
|
template <typename T> Sizes& operator = (const T& /*other*/) {
|
||||||
// add assertion failure if the size of other is different
|
// add assertion failure if the size of other is different
|
||||||
return *this;
|
return *this;
|
||||||
@ -114,10 +143,12 @@ struct Sizes : internal::numeric_list<std::ptrdiff_t, Indices...> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
template <typename std::ptrdiff_t... Indices>
|
template <typename std::ptrdiff_t... Indices>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_prod(const Sizes<Indices...>&) {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_prod(const Sizes<Indices...>&) {
|
||||||
return Sizes<Indices...>::total_size;
|
return Sizes<Indices...>::total_size;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
@ -166,6 +197,24 @@ template <std::size_t V1=0, std::size_t V2=0, std::size_t V3=0, std::size_t V4=0
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex operator[] (const int index) const {
|
||||||
|
switch (index) {
|
||||||
|
case 0:
|
||||||
|
return internal::get<0, Base>::value;
|
||||||
|
case 1:
|
||||||
|
return internal::get<1, Base>::value;
|
||||||
|
case 2:
|
||||||
|
return internal::get<2, Base>::value;
|
||||||
|
case 3:
|
||||||
|
return internal::get<3, Base>::value;
|
||||||
|
case 4:
|
||||||
|
return internal::get<4, Base>::value;
|
||||||
|
default:
|
||||||
|
eigen_assert(false && "index overflow");
|
||||||
|
return static_cast<std::size_t>(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T> Sizes& operator = (const T&) {
|
template <typename T> Sizes& operator = (const T&) {
|
||||||
// to do: check the size of other
|
// to do: check the size of other
|
||||||
return *this;
|
return *this;
|
||||||
@ -181,10 +230,12 @@ template <std::size_t V1=0, std::size_t V2=0, std::size_t V3=0, std::size_t V4=0
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5>
|
template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_prod(const Sizes<V1, V2, V3, V4, V5>&) {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_prod(const Sizes<V1, V2, V3, V4, V5>&) {
|
||||||
return Sizes<V1, V2, V3, V4, V5>::total_size;
|
return Sizes<V1, V2, V3, V4, V5>::total_size;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user