mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-15 07:10:37 +08:00
Added primitives to compare tensor dimensions
This commit is contained in:
parent
9b7a6f0122
commit
40bb98e76a
@ -210,6 +210,60 @@ struct DSizes : array<DenseIndex, NumDims> {
|
||||
};
|
||||
|
||||
|
||||
namespace internal {
|
||||
|
||||
template <typename DenseIndex, std::size_t NumDims> struct array_size<const DSizes<DenseIndex, NumDims> > {
|
||||
static const size_t value = NumDims;
|
||||
};
|
||||
template <typename DenseIndex, std::size_t NumDims> struct array_size<DSizes<DenseIndex, NumDims> > {
|
||||
static const size_t value = NumDims;
|
||||
};
|
||||
#ifndef EIGEN_EMULATE_CXX11_META_H
|
||||
template <typename std::size_t... Indices> struct array_size<const Sizes<Indices...> > {
|
||||
static const size_t value = Sizes<Indices...>::count;
|
||||
};
|
||||
template <typename std::size_t... Indices> struct array_size<Sizes<Indices...> > {
|
||||
static const size_t value = Sizes<Indices...>::count;
|
||||
};
|
||||
#else
|
||||
template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5> struct array_size<const Sizes<V1,V2,V3,V4,V5> > {
|
||||
static const size_t value = Sizes<V1,V2,V3,V4,V5>::count;
|
||||
};
|
||||
template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5> struct array_size<Sizes<V1,V2,V3,V4,V5> > {
|
||||
static const size_t value = Sizes<V1,V2,V3,V4,V5>::count;
|
||||
};
|
||||
template <std::size_t n, std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_get(const Sizes<V1,V2,V3,V4,V5>& a) {
|
||||
return get<n, typename Sizes<V1,V2,V3,V4,V5>::Base>::value;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
template <typename Dims1, typename Dims2, size_t n>
|
||||
struct sizes_match_up_to_dim {
|
||||
static inline bool run(Dims1& dims1, Dims2& dims2) {
|
||||
return (array_get<n>(dims1) == array_get<n>(dims2)) &
|
||||
sizes_match_up_to_dim<Dims1, Dims2, n-1>::run(dims1, dims2);
|
||||
}
|
||||
};
|
||||
template <typename Dims1, typename Dims2>
|
||||
struct sizes_match_up_to_dim<Dims1, Dims2, 0> {
|
||||
static inline bool run(Dims1& dims1, Dims2& dims2) {
|
||||
return (array_get<0>(dims1) == array_get<0>(dims2));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Dims1, typename Dims2>
|
||||
bool dimensions_match(Dims1& dims1, Dims2& dims2) {
|
||||
if (array_size<Dims1>::value != array_size<Dims2>::value) {
|
||||
return false;
|
||||
}
|
||||
return sizes_match_up_to_dim<Dims1, Dims2, array_size<Dims1>::value-1>::run(dims1, dims2);
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H
|
||||
|
Loading…
Reference in New Issue
Block a user