mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Work around the lack of support for arbitrary packet types in Tensor by manually loading half/quarter packets in the tensor contraction mapper.
This commit is contained in:
parent
eb4c6bb22d
commit
d586686924
@ -241,8 +241,10 @@ class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar,
|
|||||||
ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
|
ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
|
||||||
|
|
||||||
template <typename PacketT,int AlignmentType>
|
template <typename PacketT,int AlignmentType>
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
EIGEN_STRONG_INLINE PacketT load(Index i, Index j) const {
|
typename internal::enable_if<internal::unpacket_traits<PacketT>::size==packet_size,PacketT>::type
|
||||||
|
load(Index i, Index j) const
|
||||||
|
{
|
||||||
// whole method makes column major assumption
|
// whole method makes column major assumption
|
||||||
|
|
||||||
// don't need to add offsets for now (because operator handles that)
|
// don't need to add offsets for now (because operator handles that)
|
||||||
@ -283,6 +285,29 @@ class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar,
|
|||||||
return pload<PacketT>(data);
|
return pload<PacketT>(data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Overload of load() selected (via enable_if) when the requested PacketT holds
// a different number of scalars than packet_size. Instead of a direct packet
// read, it gathers the coefficients one by one into a local scalar buffer and
// then builds the packet from that buffer with pload.
// Parameters i and j are forwarded to computeIndexPair; their exact meaning
// (presumably row / column of the mapped matrix view) is defined by the
// parent mapper -- confirm there.
template <typename PacketT,int AlignmentType>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
typename internal::enable_if<internal::unpacket_traits<PacketT>::size!=packet_size,PacketT>::type
|
||||||
|
load(Index i, Index j) const
|
||||||
|
{
|
||||||
|
// Number of scalars held by the requested packet type.
const Index requested_packet_size = internal::unpacket_traits<PacketT>::size;
|
||||||
|
// Staging buffer handed to pload below (EIGEN_ALIGN_MAX presumably applies
// the alignment an aligned packet load needs -- confirm in Eigen's macros).
EIGEN_ALIGN_MAX Scalar data[requested_packet_size];
|
||||||
|
|
||||||
|
// One call yields the tensor indices for the first and the last slot of the
// packet (distance requested_packet_size - 1 from the first).
const IndexPair<Index> indexPair = this->computeIndexPair(i, j, requested_packet_size - 1);
|
||||||
|
const Index first = indexPair.first;
|
||||||
|
const Index lastIdx = indexPair.second;
|
||||||
|
|
||||||
|
data[0] = this->m_tensor.coeff(first);
|
||||||
|
// Interior slots are filled two at a time: each computeIndexPair call with
// distance 1 gives the indices for slots k and k + 1.
for (Index k = 1; k < requested_packet_size - 1; k += 2) {
|
||||||
|
const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
|
||||||
|
data[k] = this->m_tensor.coeff(internal_pair.first);
|
||||||
|
data[k + 1] = this->m_tensor.coeff(internal_pair.second);
|
||||||
|
}
|
||||||
|
// The final slot always comes from the index computed up front.
data[requested_packet_size - 1] = this->m_tensor.coeff(lastIdx);
|
||||||
|
|
||||||
|
return pload<PacketT>(data);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename PacketT,int AlignmentType>
|
template <typename PacketT,int AlignmentType>
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE PacketT loadPacket(Index i, Index j) const {
|
EIGEN_STRONG_INLINE PacketT loadPacket(Index i, Index j) const {
|
||||||
|
Loading…
Reference in New Issue
Block a user