Reworked the tensor contraction mapper code to make it compile on Android

This commit is contained in:
Benoit Steiner 2015-10-23 09:33:41 -07:00
parent 29c3b7513e
commit a586fdaa91

View File

@ -33,14 +33,14 @@ template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
int packet_size, bool inner_dim_contiguous>
class BaseTensorContractionMapper {
class SimpleTensorContractionMapper {
public:
EIGEN_DEVICE_FUNC
BaseTensorContractionMapper(const Tensor& tensor,
const nocontract_t& nocontract_strides,
const nocontract_t& ij_strides,
const contract_t& contract_strides,
const contract_t& k_strides) :
SimpleTensorContractionMapper(const Tensor& tensor,
const nocontract_t& nocontract_strides,
const nocontract_t& ij_strides,
const contract_t& contract_strides,
const contract_t& k_strides) :
m_tensor(tensor),
m_nocontract_strides(nocontract_strides),
m_ij_strides(ij_strides),
@ -160,104 +160,23 @@ class BaseTensorContractionMapper {
};
template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
int packet_size,
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper;
template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
int packet_size,
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionSubMapper {
size_t packet_size, bool inner_dim_contiguous,
bool inner_dim_reordered, int Alignment>
class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous>
{
public:
typedef typename packet_traits<Scalar>::type Packet;
typedef typename packet_traits<Scalar>::half HalfPacket;
typedef TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
typedef Self LinearMapper;
EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
: m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) { }
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
return m_base_mapper(i + m_vert_offset, m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
return m_base_mapper.loadPacket(i + m_vert_offset, m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
return m_base_mapper.loadPacket(i + m_vert_offset, j + m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
return m_base_mapper.loadHalfPacket(i + m_vert_offset, m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
}
template <typename PacketT, int AlignmentType>
EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT((AlignmentType == Aligned || Alignment == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE);
return loadPacket(i);
}
template <typename Packet>
bool aligned(Index /*i*/) const {
return false;
}
private:
const ParentMapper& m_base_mapper;
const Index m_vert_offset;
const Index m_horiz_offset;
};
template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
int packet_size = (Tensor::PacketAccess ? packet_traits<Scalar>::size : 1),
bool inner_dim_contiguous = false, bool inner_dim_reordered = (side != Lhs), int Alignment=Unaligned>
class TensorContractionInputMapper
: public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> {
public:
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> Base;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
typedef SubMapper VectorMapper;
TensorContractionInputMapper(const Tensor& tensor,
const nocontract_t& nocontract_strides,
const nocontract_t& ij_strides,
const contract_t& contract_strides,
const contract_t& k_strides)
: Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> ParentMapper;
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
return SubMapper(*this, i, j);
}
EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
return VectorMapper(*this, i, j);
}
BaseTensorContractionMapper(const Tensor& tensor,
const nocontract_t& nocontract_strides,
const nocontract_t& ij_strides,
const contract_t& contract_strides,
const contract_t& k_strides) :
ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
typedef typename packet_traits<Scalar>::type Packet;
typedef typename packet_traits<Scalar>::half HalfPacket;
@ -322,35 +241,23 @@ class TensorContractionInputMapper
};
template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment>
: public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> {
bool inner_dim_contiguous,
bool inner_dim_reordered, int Alignment>
class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous>
{
public:
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> Base;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
typedef SubMapper VectorMapper;
TensorContractionInputMapper(const Tensor& tensor,
const nocontract_t& nocontract_strides,
const nocontract_t& ij_strides,
const contract_t& contract_strides,
const contract_t& k_strides)
: Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> ParentMapper;
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
return SubMapper(*this, i, j);
}
EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
return VectorMapper(*this, i, j);
}
BaseTensorContractionMapper(const Tensor& tensor,
const nocontract_t& nocontract_strides,
const nocontract_t& ij_strides,
const contract_t& contract_strides,
const contract_t& k_strides) :
ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
typedef typename packet_traits<Scalar>::type Packet;
EIGEN_DEVICE_FUNC
@ -365,6 +272,106 @@ class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, co
}
};
template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
size_t packet_size,
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper;
template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
size_t packet_size,
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionSubMapper {
public:
typedef typename packet_traits<Scalar>::type Packet;
typedef typename packet_traits<Scalar>::half HalfPacket;
typedef TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
typedef Self LinearMapper;
EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
: m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) { }
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
return m_base_mapper(i + m_vert_offset, m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
return m_base_mapper.loadPacket(i + m_vert_offset, m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
return m_base_mapper.loadPacket(i + m_vert_offset, j + m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
return m_base_mapper.loadHalfPacket(i + m_vert_offset, m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
}
template <typename PacketT, int AlignmentType>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT((AlignmentType == Aligned || Alignment == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE);
return loadPacket(i);
}
template <typename Packet>
EIGEN_DEVICE_FUNC bool aligned(Index) const {
return false;
}
private:
const ParentMapper& m_base_mapper;
const Index m_vert_offset;
const Index m_horiz_offset;
};
template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
size_t packet_size,
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper
: public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {
public:
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
typedef SubMapper VectorMapper;
EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor,
const nocontract_t& nocontract_strides,
const nocontract_t& ij_strides,
const contract_t& contract_strides,
const contract_t& k_strides)
: Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
return SubMapper(*this, i, j);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
return VectorMapper(*this, i, j);
}
};
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >