From d586686924c2783f56bd514c9365afeecc3e84f6 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Wed, 30 Jan 2019 16:48:01 +0100 Subject: [PATCH] Workaround lack of support for arbitrary packet-type in Tensor by manually loading half/quarter packets in tensor contraction mapper. --- .../src/Tensor/TensorContractionMapper.h | 29 +++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h index 64dfcd297..142492603 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h @@ -241,8 +241,10 @@ class BaseTensorContractionMapper : public SimpleTensorContractionMapper - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE PacketT load(Index i, Index j) const { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + typename internal::enable_if::size==packet_size,PacketT>::type + load(Index i, Index j) const + { // whole method makes column major assumption // don't need to add offsets for now (because operator handles that) @@ -283,6 +285,29 @@ class BaseTensorContractionMapper : public SimpleTensorContractionMapper(data); } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + typename internal::enable_if::size!=packet_size,PacketT>::type + load(Index i, Index j) const + { + const Index requested_packet_size = internal::unpacket_traits::size; + EIGEN_ALIGN_MAX Scalar data[requested_packet_size]; + + const IndexPair indexPair = this->computeIndexPair(i, j, requested_packet_size - 1); + const Index first = indexPair.first; + const Index lastIdx = indexPair.second; + + data[0] = this->m_tensor.coeff(first); + for (Index k = 1; k < requested_packet_size - 1; k += 2) { + const IndexPair internal_pair = this->computeIndexPair(i + k, j, 1); + data[k] = this->m_tensor.coeff(internal_pair.first); + data[k + 1] = this->m_tensor.coeff(internal_pair.second); + } + data[requested_packet_size - 1] = this->m_tensor.coeff(lastIdx); + + return pload(data); + } + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketT loadPacket(Index i, Index j) const {