mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Rewrote Eigen::dimensions_match to prevent a static assertion when the rank of the tensors is different.
This commit is contained in:
parent
e94f9eb637
commit
7a39439904
@ -401,15 +401,21 @@ template <std::size_t n, std::size_t V1, std::size_t V2, std::size_t V3, std::si
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
template <typename Dims1, typename Dims2, size_t n>
|
template <typename Dims1, typename Dims2, size_t n, size_t m>
|
||||||
struct sizes_match_up_to_dim {
|
struct sizes_match_up_to_dim {
|
||||||
|
static inline bool run(Dims1& dims1, Dims2& dims2) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
template <typename Dims1, typename Dims2, size_t n>
|
||||||
|
struct sizes_match_up_to_dim<Dims1, Dims2, n, n> {
|
||||||
static inline bool run(Dims1& dims1, Dims2& dims2) {
|
static inline bool run(Dims1& dims1, Dims2& dims2) {
|
||||||
return (array_get<n>(dims1) == array_get<n>(dims2)) &
|
return (array_get<n>(dims1) == array_get<n>(dims2)) &
|
||||||
sizes_match_up_to_dim<Dims1, Dims2, n-1>::run(dims1, dims2);
|
sizes_match_up_to_dim<Dims1, Dims2, n-1, n-1>::run(dims1, dims2);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
template <typename Dims1, typename Dims2>
|
template <typename Dims1, typename Dims2>
|
||||||
struct sizes_match_up_to_dim<Dims1, Dims2, 0> {
|
struct sizes_match_up_to_dim<Dims1, Dims2, 0, 0> {
|
||||||
static inline bool run(Dims1& dims1, Dims2& dims2) {
|
static inline bool run(Dims1& dims1, Dims2& dims2) {
|
||||||
return (array_get<0>(dims1) == array_get<0>(dims2));
|
return (array_get<0>(dims1) == array_get<0>(dims2));
|
||||||
}
|
}
|
||||||
@ -420,10 +426,7 @@ struct sizes_match_up_to_dim<Dims1, Dims2, 0> {
|
|||||||
|
|
||||||
template <typename Dims1, typename Dims2>
|
template <typename Dims1, typename Dims2>
|
||||||
bool dimensions_match(Dims1& dims1, Dims2& dims2) {
|
bool dimensions_match(Dims1& dims1, Dims2& dims2) {
|
||||||
if (static_cast<size_t>(internal::array_size<Dims1>::value) != static_cast<size_t>(internal::array_size<Dims2>::value)) {
|
return internal::sizes_match_up_to_dim<Dims1, Dims2, internal::array_size<Dims1>::value-1, internal::array_size<Dims2>::value-1>::run(dims1, dims2);
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return internal::sizes_match_up_to_dim<Dims1, Dims2, internal::array_size<Dims1>::value-1>::run(dims1, dims2);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
@ -43,6 +43,10 @@ static void test_match()
|
|||||||
Eigen::DSizes<int, 3> dyn(2,3,7);
|
Eigen::DSizes<int, 3> dyn(2,3,7);
|
||||||
Eigen::Sizes<2,3,7> stat;
|
Eigen::Sizes<2,3,7> stat;
|
||||||
VERIFY_IS_EQUAL(Eigen::dimensions_match(dyn, stat), true);
|
VERIFY_IS_EQUAL(Eigen::dimensions_match(dyn, stat), true);
|
||||||
|
|
||||||
|
Eigen::DSizes<int, 3> dyn1(2,3,7);
|
||||||
|
Eigen::DSizes<int, 2> dyn2(2,3);
|
||||||
|
VERIFY_IS_EQUAL(Eigen::dimensions_match(dyn, stat), false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user