Workaround lack of support for arbitrary packet-type in Tensor by manually loading half/quarter packets in tensor contraction mapper.

This commit is contained in:
Gael Guennebaud 2019-01-30 16:48:01 +01:00
parent eb4c6bb22d
commit d586686924

View File

@ -241,8 +241,10 @@ class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar,
ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
template <typename PacketT,int AlignmentType>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE PacketT load(Index i, Index j) const {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
typename internal::enable_if<internal::unpacket_traits<PacketT>::size==packet_size,PacketT>::type
load(Index i, Index j) const
{
// whole method makes column major assumption
// don't need to add offsets for now (because operator handles that)
@ -283,6 +285,29 @@ class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar,
return pload<PacketT>(data);
}
template <typename PacketT,int AlignmentType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
typename internal::enable_if<internal::unpacket_traits<PacketT>::size!=packet_size,PacketT>::type
load(Index i, Index j) const
{
const Index requested_packet_size = internal::unpacket_traits<PacketT>::size;
EIGEN_ALIGN_MAX Scalar data[requested_packet_size];
const IndexPair<Index> indexPair = this->computeIndexPair(i, j, requested_packet_size - 1);
const Index first = indexPair.first;
const Index lastIdx = indexPair.second;
data[0] = this->m_tensor.coeff(first);
for (Index k = 1; k < requested_packet_size - 1; k += 2) {
const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
data[k] = this->m_tensor.coeff(internal_pair.first);
data[k + 1] = this->m_tensor.coeff(internal_pair.second);
}
data[requested_packet_size - 1] = this->m_tensor.coeff(lastIdx);
return pload<PacketT>(data);
}
template <typename PacketT,int AlignmentType>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE PacketT loadPacket(Index i, Index j) const {