mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Reworked the tensor contraction mapper code to make it compile on Android
This commit is contained in:
parent
29c3b7513e
commit
a586fdaa91
@ -33,14 +33,14 @@ template<typename Scalar, typename Index, int side,
|
||||
typename Tensor,
|
||||
typename nocontract_t, typename contract_t,
|
||||
int packet_size, bool inner_dim_contiguous>
|
||||
class BaseTensorContractionMapper {
|
||||
class SimpleTensorContractionMapper {
|
||||
public:
|
||||
EIGEN_DEVICE_FUNC
|
||||
BaseTensorContractionMapper(const Tensor& tensor,
|
||||
const nocontract_t& nocontract_strides,
|
||||
const nocontract_t& ij_strides,
|
||||
const contract_t& contract_strides,
|
||||
const contract_t& k_strides) :
|
||||
SimpleTensorContractionMapper(const Tensor& tensor,
|
||||
const nocontract_t& nocontract_strides,
|
||||
const nocontract_t& ij_strides,
|
||||
const contract_t& contract_strides,
|
||||
const contract_t& k_strides) :
|
||||
m_tensor(tensor),
|
||||
m_nocontract_strides(nocontract_strides),
|
||||
m_ij_strides(ij_strides),
|
||||
@ -160,104 +160,23 @@ class BaseTensorContractionMapper {
|
||||
};
|
||||
|
||||
|
||||
|
||||
template<typename Scalar, typename Index, int side,
|
||||
typename Tensor,
|
||||
typename nocontract_t, typename contract_t,
|
||||
int packet_size,
|
||||
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
|
||||
class TensorContractionInputMapper;
|
||||
|
||||
template<typename Scalar, typename Index, int side,
|
||||
typename Tensor,
|
||||
typename nocontract_t, typename contract_t,
|
||||
int packet_size,
|
||||
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
|
||||
class TensorContractionSubMapper {
|
||||
size_t packet_size, bool inner_dim_contiguous,
|
||||
bool inner_dim_reordered, int Alignment>
|
||||
class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous>
|
||||
{
|
||||
public:
|
||||
typedef typename packet_traits<Scalar>::type Packet;
|
||||
typedef typename packet_traits<Scalar>::half HalfPacket;
|
||||
|
||||
typedef TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
|
||||
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
|
||||
typedef Self LinearMapper;
|
||||
|
||||
EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
|
||||
: m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) { }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
|
||||
return m_base_mapper(i + m_vert_offset, m_horiz_offset);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
|
||||
return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
|
||||
return m_base_mapper.loadPacket(i + m_vert_offset, m_horiz_offset);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
|
||||
return m_base_mapper.loadPacket(i + m_vert_offset, j + m_horiz_offset);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
|
||||
return m_base_mapper.loadHalfPacket(i + m_vert_offset, m_horiz_offset);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
|
||||
m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
|
||||
return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
|
||||
}
|
||||
|
||||
template <typename PacketT, int AlignmentType>
|
||||
EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
|
||||
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
EIGEN_STATIC_ASSERT((AlignmentType == Aligned || Alignment == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
return loadPacket(i);
|
||||
}
|
||||
|
||||
template <typename Packet>
|
||||
bool aligned(Index /*i*/) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
const ParentMapper& m_base_mapper;
|
||||
const Index m_vert_offset;
|
||||
const Index m_horiz_offset;
|
||||
};
|
||||
|
||||
|
||||
template<typename Scalar, typename Index, int side,
|
||||
typename Tensor,
|
||||
typename nocontract_t, typename contract_t,
|
||||
int packet_size = (Tensor::PacketAccess ? packet_traits<Scalar>::size : 1),
|
||||
bool inner_dim_contiguous = false, bool inner_dim_reordered = (side != Lhs), int Alignment=Unaligned>
|
||||
class TensorContractionInputMapper
|
||||
: public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> {
|
||||
|
||||
public:
|
||||
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> Base;
|
||||
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
|
||||
typedef SubMapper VectorMapper;
|
||||
|
||||
TensorContractionInputMapper(const Tensor& tensor,
|
||||
const nocontract_t& nocontract_strides,
|
||||
const nocontract_t& ij_strides,
|
||||
const contract_t& contract_strides,
|
||||
const contract_t& k_strides)
|
||||
: Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
|
||||
typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> ParentMapper;
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
|
||||
return SubMapper(*this, i, j);
|
||||
}
|
||||
|
||||
EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
|
||||
return VectorMapper(*this, i, j);
|
||||
}
|
||||
BaseTensorContractionMapper(const Tensor& tensor,
|
||||
const nocontract_t& nocontract_strides,
|
||||
const nocontract_t& ij_strides,
|
||||
const contract_t& contract_strides,
|
||||
const contract_t& k_strides) :
|
||||
ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
|
||||
|
||||
typedef typename packet_traits<Scalar>::type Packet;
|
||||
typedef typename packet_traits<Scalar>::half HalfPacket;
|
||||
@ -322,35 +241,23 @@ class TensorContractionInputMapper
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
template<typename Scalar, typename Index, int side,
|
||||
typename Tensor,
|
||||
typename nocontract_t, typename contract_t,
|
||||
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
|
||||
class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment>
|
||||
: public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> {
|
||||
|
||||
bool inner_dim_contiguous,
|
||||
bool inner_dim_reordered, int Alignment>
|
||||
class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous>
|
||||
{
|
||||
public:
|
||||
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> Base;
|
||||
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
|
||||
typedef SubMapper VectorMapper;
|
||||
|
||||
TensorContractionInputMapper(const Tensor& tensor,
|
||||
const nocontract_t& nocontract_strides,
|
||||
const nocontract_t& ij_strides,
|
||||
const contract_t& contract_strides,
|
||||
const contract_t& k_strides)
|
||||
: Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
|
||||
typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> ParentMapper;
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
|
||||
return SubMapper(*this, i, j);
|
||||
}
|
||||
|
||||
EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
|
||||
return VectorMapper(*this, i, j);
|
||||
}
|
||||
BaseTensorContractionMapper(const Tensor& tensor,
|
||||
const nocontract_t& nocontract_strides,
|
||||
const nocontract_t& ij_strides,
|
||||
const contract_t& contract_strides,
|
||||
const contract_t& k_strides) :
|
||||
ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
|
||||
|
||||
typedef typename packet_traits<Scalar>::type Packet;
|
||||
EIGEN_DEVICE_FUNC
|
||||
@ -365,6 +272,106 @@ class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, co
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Scalar, typename Index, int side,
|
||||
typename Tensor,
|
||||
typename nocontract_t, typename contract_t,
|
||||
size_t packet_size,
|
||||
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
|
||||
class TensorContractionInputMapper;
|
||||
|
||||
template<typename Scalar, typename Index, int side,
|
||||
typename Tensor,
|
||||
typename nocontract_t, typename contract_t,
|
||||
size_t packet_size,
|
||||
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
|
||||
class TensorContractionSubMapper {
|
||||
public:
|
||||
typedef typename packet_traits<Scalar>::type Packet;
|
||||
typedef typename packet_traits<Scalar>::half HalfPacket;
|
||||
|
||||
typedef TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
|
||||
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
|
||||
typedef Self LinearMapper;
|
||||
|
||||
EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
|
||||
: m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) { }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
|
||||
return m_base_mapper(i + m_vert_offset, m_horiz_offset);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
|
||||
return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
|
||||
return m_base_mapper.loadPacket(i + m_vert_offset, m_horiz_offset);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
|
||||
return m_base_mapper.loadPacket(i + m_vert_offset, j + m_horiz_offset);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
|
||||
return m_base_mapper.loadHalfPacket(i + m_vert_offset, m_horiz_offset);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
|
||||
m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
|
||||
return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
|
||||
}
|
||||
|
||||
template <typename PacketT, int AlignmentType>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
|
||||
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
EIGEN_STATIC_ASSERT((AlignmentType == Aligned || Alignment == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
return loadPacket(i);
|
||||
}
|
||||
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC bool aligned(Index) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
const ParentMapper& m_base_mapper;
|
||||
const Index m_vert_offset;
|
||||
const Index m_horiz_offset;
|
||||
};
|
||||
|
||||
|
||||
template<typename Scalar, typename Index, int side,
|
||||
typename Tensor,
|
||||
typename nocontract_t, typename contract_t,
|
||||
size_t packet_size,
|
||||
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
|
||||
class TensorContractionInputMapper
|
||||
: public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {
|
||||
|
||||
public:
|
||||
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base;
|
||||
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
|
||||
typedef SubMapper VectorMapper;
|
||||
|
||||
EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor,
|
||||
const nocontract_t& nocontract_strides,
|
||||
const nocontract_t& ij_strides,
|
||||
const contract_t& contract_strides,
|
||||
const contract_t& k_strides)
|
||||
: Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
|
||||
return SubMapper(*this, i, j);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
|
||||
return VectorMapper(*this, i, j);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
|
||||
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
|
||||
|
Loading…
Reference in New Issue
Block a user