mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Added more tests to cover tensor reductions
This commit is contained in:
parent
9dfdbd7e56
commit
5a6ea4edf6
@ -37,7 +37,11 @@ template <typename T> struct SumReducer
|
|||||||
return accum;
|
return accum;
|
||||||
}
|
}
|
||||||
template <typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizePacket(const T saccum, const Packet& vaccum) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet finalizePacket(const Packet& vaccum) const {
|
||||||
|
return vaccum;
|
||||||
|
}
|
||||||
|
template <typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
|
||||||
return saccum + predux(vaccum);
|
return saccum + predux(vaccum);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -45,16 +49,16 @@ template <typename T> struct SumReducer
|
|||||||
template <typename T> struct MeanReducer
|
template <typename T> struct MeanReducer
|
||||||
{
|
{
|
||||||
static const bool PacketAccess = true;
|
static const bool PacketAccess = true;
|
||||||
MeanReducer() : count_(0) { }
|
MeanReducer() : scalarCount_(0), packetCount_(0) { }
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) {
|
||||||
(*accum) += t;
|
(*accum) += t;
|
||||||
count_++;
|
scalarCount_++;
|
||||||
}
|
}
|
||||||
template <typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, Packet* accum) {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, Packet* accum) {
|
||||||
(*accum) = padd<Packet>(*accum, p);
|
(*accum) = padd<Packet>(*accum, p);
|
||||||
count_ += packet_traits<Packet>::size;
|
packetCount_++;
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
|
||||||
@ -65,15 +69,20 @@ template <typename T> struct MeanReducer
|
|||||||
return pset1<Packet>(0);
|
return pset1<Packet>(0);
|
||||||
}
|
}
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
|
||||||
return accum / count_;
|
return accum / scalarCount_;
|
||||||
}
|
}
|
||||||
template <typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizePacket(const T saccum, const Packet& vaccum) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet finalizePacket(const Packet& vaccum) const {
|
||||||
return (saccum + predux(vaccum)) / count_;
|
return pdiv(vaccum, pset1<Packet>(packetCount_));
|
||||||
|
}
|
||||||
|
template <typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
|
||||||
|
return (saccum + predux(vaccum)) / (scalarCount_ + packetCount_ * packet_traits<Packet>::size);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
int count_;
|
int scalarCount_;
|
||||||
|
int packetCount_;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T> struct MaxReducer
|
template <typename T> struct MaxReducer
|
||||||
@ -99,7 +108,11 @@ template <typename T> struct MaxReducer
|
|||||||
return accum;
|
return accum;
|
||||||
}
|
}
|
||||||
template <typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizePacket(const T saccum, const Packet& vaccum) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet finalizePacket(const Packet& vaccum) const {
|
||||||
|
return vaccum;
|
||||||
|
}
|
||||||
|
template <typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
|
||||||
return (std::max)(saccum, predux_max(vaccum));
|
return (std::max)(saccum, predux_max(vaccum));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -127,7 +140,11 @@ template <typename T> struct MinReducer
|
|||||||
return accum;
|
return accum;
|
||||||
}
|
}
|
||||||
template <typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizePacket(const T saccum, const Packet& vaccum) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet finalizePacket(const Packet& vaccum) const {
|
||||||
|
return vaccum;
|
||||||
|
}
|
||||||
|
template <typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
|
||||||
return (std::min)(saccum, predux_min(vaccum));
|
return (std::min)(saccum, predux_min(vaccum));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -156,7 +173,11 @@ template <typename T> struct ProdReducer
|
|||||||
return accum;
|
return accum;
|
||||||
}
|
}
|
||||||
template <typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizePacket(const T saccum, const Packet& vaccum) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet finalizePacket(const Packet& vaccum) const {
|
||||||
|
return vaccum;
|
||||||
|
}
|
||||||
|
template <typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
|
||||||
return saccum * predux_mul(vaccum);
|
return saccum * predux_mul(vaccum);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -181,7 +181,7 @@ template<typename FirstType, typename... OtherTypes> size_t array_prod(const Ind
|
|||||||
result *= sizes[i];
|
result *= sizes[i];
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
};
|
||||||
|
|
||||||
template<typename FirstType, typename... OtherTypes> struct array_size<IndexList<FirstType, OtherTypes...> > {
|
template<typename FirstType, typename... OtherTypes> struct array_size<IndexList<FirstType, OtherTypes...> > {
|
||||||
static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
|
static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
|
||||||
@ -307,6 +307,52 @@ struct index_statically_ne<const IndexList<FirstType, OtherTypes...> > {
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct index_statically_gt {
|
||||||
|
constexpr bool operator() (DenseIndex, DenseIndex) const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename... OtherTypes>
|
||||||
|
struct index_statically_gt<IndexList<FirstType, OtherTypes...> > {
|
||||||
|
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
|
||||||
|
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
|
||||||
|
IndexList<FirstType, OtherTypes...>()[i] > value;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename... OtherTypes>
|
||||||
|
struct index_statically_gt<const IndexList<FirstType, OtherTypes...> > {
|
||||||
|
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
|
||||||
|
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
|
||||||
|
IndexList<FirstType, OtherTypes...>()[i] > value;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct index_statically_lt {
|
||||||
|
constexpr bool operator() (DenseIndex, DenseIndex) const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename... OtherTypes>
|
||||||
|
struct index_statically_lt<IndexList<FirstType, OtherTypes...> > {
|
||||||
|
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
|
||||||
|
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
|
||||||
|
IndexList<FirstType, OtherTypes...>()[i] < value;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename... OtherTypes>
|
||||||
|
struct index_statically_lt<const IndexList<FirstType, OtherTypes...> > {
|
||||||
|
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
|
||||||
|
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
|
||||||
|
IndexList<FirstType, OtherTypes...>()[i] < value;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
@ -351,6 +397,20 @@ struct index_statically_ne {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct index_statically_gt {
|
||||||
|
EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct index_statically_lt {
|
||||||
|
EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
|
@ -369,6 +369,37 @@ static void test_innermost_first_dims() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int DataLayout>
|
||||||
|
static void test_reduce_middle_dims() {
|
||||||
|
Tensor<float, 4, DataLayout> in(72, 53, 97, 113);
|
||||||
|
Tensor<float, 2, DataLayout> out(72, 53);
|
||||||
|
in.setRandom();
|
||||||
|
|
||||||
|
// Reduce on the innermost dimensions.
|
||||||
|
#if __cplusplus <= 199711L
|
||||||
|
array<int, 2> reduction_axis;
|
||||||
|
reduction_axis[0] = 1;
|
||||||
|
reduction_axis[1] = 2;
|
||||||
|
#else
|
||||||
|
// This triggers the use of packets for RowMajor.
|
||||||
|
Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2>> reduction_axis;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
out = in.maximum(reduction_axis);
|
||||||
|
|
||||||
|
for (int i = 0; i < 72; ++i) {
|
||||||
|
for (int j = 0; j < 113; ++j) {
|
||||||
|
float expected = -1e10f;
|
||||||
|
for (int k = 0; k < 53; ++k) {
|
||||||
|
for (int l = 0; l < 97; ++l) {
|
||||||
|
expected = (std::max)(expected, in(i, k, l, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
VERIFY_IS_APPROX(out(i, j), expected);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void test_cxx11_tensor_reduction() {
|
void test_cxx11_tensor_reduction() {
|
||||||
CALL_SUBTEST(test_simple_reductions<ColMajor>());
|
CALL_SUBTEST(test_simple_reductions<ColMajor>());
|
||||||
CALL_SUBTEST(test_simple_reductions<RowMajor>());
|
CALL_SUBTEST(test_simple_reductions<RowMajor>());
|
||||||
@ -380,8 +411,10 @@ void test_cxx11_tensor_reduction() {
|
|||||||
CALL_SUBTEST(test_tensor_maps<RowMajor>());
|
CALL_SUBTEST(test_tensor_maps<RowMajor>());
|
||||||
CALL_SUBTEST(test_static_dims<ColMajor>());
|
CALL_SUBTEST(test_static_dims<ColMajor>());
|
||||||
CALL_SUBTEST(test_static_dims<RowMajor>());
|
CALL_SUBTEST(test_static_dims<RowMajor>());
|
||||||
CALL_SUBTEST(test_innermost_last_dims<RowMajor>());
|
|
||||||
CALL_SUBTEST(test_innermost_last_dims<ColMajor>());
|
CALL_SUBTEST(test_innermost_last_dims<ColMajor>());
|
||||||
CALL_SUBTEST(test_innermost_first_dims<RowMajor>());
|
CALL_SUBTEST(test_innermost_last_dims<RowMajor>());
|
||||||
CALL_SUBTEST(test_innermost_first_dims<ColMajor>());
|
CALL_SUBTEST(test_innermost_first_dims<ColMajor>());
|
||||||
|
CALL_SUBTEST(test_innermost_first_dims<RowMajor>());
|
||||||
|
CALL_SUBTEST(test_reduce_middle_dims<ColMajor>());
|
||||||
|
CALL_SUBTEST(test_reduce_middle_dims<RowMajor>());
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user