mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-15 07:10:37 +08:00
Work around the lack of support for arbitrary packet types in Tensor by manually loading half/quarter packets in the tensor contraction mapper.
This commit is contained in:
parent
eb4c6bb22d
commit
d586686924
@ -241,8 +241,10 @@ class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar,
|
||||
ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
|
||||
|
||||
template <typename PacketT,int AlignmentType>
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE PacketT load(Index i, Index j) const {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
typename internal::enable_if<internal::unpacket_traits<PacketT>::size==packet_size,PacketT>::type
|
||||
load(Index i, Index j) const
|
||||
{
|
||||
// whole method makes column major assumption
|
||||
|
||||
// don't need to add offsets for now (because operator handles that)
|
||||
@ -283,6 +285,29 @@ class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar,
|
||||
return pload<PacketT>(data);
|
||||
}
|
||||
|
||||
template <typename PacketT,int AlignmentType>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
typename internal::enable_if<internal::unpacket_traits<PacketT>::size!=packet_size,PacketT>::type
|
||||
load(Index i, Index j) const
|
||||
{
|
||||
const Index requested_packet_size = internal::unpacket_traits<PacketT>::size;
|
||||
EIGEN_ALIGN_MAX Scalar data[requested_packet_size];
|
||||
|
||||
const IndexPair<Index> indexPair = this->computeIndexPair(i, j, requested_packet_size - 1);
|
||||
const Index first = indexPair.first;
|
||||
const Index lastIdx = indexPair.second;
|
||||
|
||||
data[0] = this->m_tensor.coeff(first);
|
||||
for (Index k = 1; k < requested_packet_size - 1; k += 2) {
|
||||
const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
|
||||
data[k] = this->m_tensor.coeff(internal_pair.first);
|
||||
data[k + 1] = this->m_tensor.coeff(internal_pair.second);
|
||||
}
|
||||
data[requested_packet_size - 1] = this->m_tensor.coeff(lastIdx);
|
||||
|
||||
return pload<PacketT>(data);
|
||||
}
|
||||
|
||||
template <typename PacketT,int AlignmentType>
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE PacketT loadPacket(Index i, Index j) const {
|
||||
|
Loading…
Reference in New Issue
Block a user