mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-02-23 18:20:47 +08:00
Improved handling of 1d tensors
This commit is contained in:
parent
2dde63499c
commit
b1789c112b
@ -48,7 +48,7 @@ class BaseTensorContractionMapper {
|
||||
m_k_strides(k_strides) { }
|
||||
|
||||
// No-op prefetch hook: the body is empty and the index argument is unnamed.
// NOTE(review): two signatures appear here (int and Index); this looks like the
// pre-change and post-change lines of the same one-line diff hunk rather than
// two overloads — confirm against the real header.
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE void prefetch(int /*i*/) { }
|
||||
EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { }
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Scalar operator()(Index row) const {
|
||||
@ -142,6 +142,13 @@ class BaseTensorContractionMapper {
|
||||
return IndexPair<Index>(linidx[0], linidx[1]);
|
||||
}
|
||||
|
||||
// Returns `size` unchanged.
// NOTE(review): presumably this signals to the caller that no element in
// [0, size) can be treated as aligned — confirm against the code that calls it.
Index firstAligned(Index size) const {
|
||||
return size;
|
||||
}
|
||||
// Always reports a stride of 1, independent of any mapper state.
Index stride() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
protected:
|
||||
const Tensor m_tensor;
|
||||
const nocontract_t m_nocontract_strides;
|
||||
@ -202,6 +209,18 @@ class TensorContractionSubMapper {
|
||||
return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
|
||||
}
|
||||
|
||||
// Loads a packet for index i by forwarding to loadPacket(i).
// Two compile-time checks are made first: PacketT must be the same type as
// Packet, and (AlignmentType == Aligned || Alignment == Unaligned) must hold.
// NOTE(review): the second check mixes this function's AlignmentType with the
// name Alignment, which is not a parameter of this function (presumably the
// enclosing class's template parameter). Confirm this is intended and not a
// typo for AlignmentType == Unaligned.
template <typename PacketT, int AlignmentType>
|
||||
EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
|
||||
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
EIGEN_STATIC_ASSERT((AlignmentType == Aligned || Alignment == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
return loadPacket(i);
|
||||
}
|
||||
|
||||
// Alignment query: always answers false; the index argument is ignored.
// The Packet template parameter is not used in the body.
template <typename Packet>
|
||||
bool aligned(Index /*i*/) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
const ParentMapper& m_base_mapper;
|
||||
const Index m_vert_offset;
|
||||
@ -220,6 +239,7 @@ class TensorContractionInputMapper
|
||||
public:
|
||||
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> Base;
|
||||
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
|
||||
typedef SubMapper VectorMapper;
|
||||
|
||||
TensorContractionInputMapper(const Tensor& tensor,
|
||||
const nocontract_t& nocontract_strides,
|
||||
@ -233,6 +253,10 @@ class TensorContractionInputMapper
|
||||
return SubMapper(*this, i, j);
|
||||
}
|
||||
|
||||
// Constructs a VectorMapper from this mapper and the two indices (i, j).
// VectorMapper is a typedef of SubMapper in this class, so the object built
// here has the same type as one built with SubMapper(*this, i, j).
EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
|
||||
return VectorMapper(*this, i, j);
|
||||
}
|
||||
|
||||
typedef typename packet_traits<Scalar>::type Packet;
|
||||
typedef typename packet_traits<Scalar>::half HalfPacket;
|
||||
|
||||
@ -306,6 +330,7 @@ class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, co
|
||||
public:
|
||||
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> Base;
|
||||
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
|
||||
typedef SubMapper VectorMapper;
|
||||
|
||||
TensorContractionInputMapper(const Tensor& tensor,
|
||||
const nocontract_t& nocontract_strides,
|
||||
@ -319,6 +344,10 @@ class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, co
|
||||
return SubMapper(*this, i, j);
|
||||
}
|
||||
|
||||
// Constructs a VectorMapper from this mapper and the two indices (i, j).
// In this specialization VectorMapper is likewise a typedef of SubMapper, so
// the result has the same type as SubMapper(*this, i, j).
EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
|
||||
return VectorMapper(*this, i, j);
|
||||
}
|
||||
|
||||
typedef typename packet_traits<Scalar>::type Packet;
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
|
||||
@ -592,41 +621,80 @@ struct TensorContractionEvaluatorBase
|
||||
if (this->m_lhs_inner_dim_contiguous) {
|
||||
if (this->m_rhs_inner_dim_contiguous) {
|
||||
if (this->m_rhs_inner_dim_reordered) {
|
||||
static_cast<const Derived*>(this)->template evalTyped<true, true, true, Unaligned>(buffer);
|
||||
static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer);
|
||||
}
|
||||
else {
|
||||
static_cast<const Derived*>(this)->template evalTyped<true, true, false, Unaligned>(buffer);
|
||||
static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (this->m_rhs_inner_dim_reordered) {
|
||||
static_cast<const Derived*>(this)->template evalTyped<true, false, true, Unaligned>(buffer);
|
||||
static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer);
|
||||
}
|
||||
else {
|
||||
static_cast<const Derived*>(this)->template evalTyped<true, false, false, Unaligned>(buffer);
|
||||
static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (this->m_rhs_inner_dim_contiguous) {
|
||||
if (this->m_rhs_inner_dim_reordered) {
|
||||
static_cast<const Derived*>(this)->template evalTyped<false, true, true, Unaligned>(buffer);
|
||||
static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer);
|
||||
}
|
||||
else {
|
||||
static_cast<const Derived*>(this)->template evalTyped<false, true, false, Unaligned>(buffer);
|
||||
static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (this->m_rhs_inner_dim_reordered) {
|
||||
static_cast<const Derived*>(this)->template evalTyped<false, false, true, Unaligned>(buffer);
|
||||
static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer);
|
||||
}
|
||||
else {
|
||||
static_cast<const Derived*>(this)->template evalTyped<false, false, false, Unaligned>(buffer);
|
||||
static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluates the contraction through a matrix-vector product and writes the
// result into `buffer`.
//
// Steps: wrap the left and right evaluators in TensorContractionInputMapper
// objects, zero the first `rows` entries of `buffer`, then call
// internal::general_matrix_vector_product::run with alpha = 1 and a result
// increment of 1.
//
// buffer: destination; at least rows * sizeof(Scalar) bytes are written.
//
// The Alignment template argument is not referenced in this body; both mapper
// types are declared with Unaligned.
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
|
||||
void evalGemv(Scalar* buffer) const {
|
||||
// Problem size: rows comes from m_i_size, cols from m_k_size.
const Index rows = m_i_size;
|
||||
const Index cols = m_k_size;
|
||||
|
||||
typedef typename internal::remove_const<typename LeftArgType::Scalar>::type LhsScalar;
|
||||
typedef typename internal::remove_const<typename RightArgType::Scalar>::type RhsScalar;
|
||||
typedef TensorEvaluator<LeftArgType, Device> LeftEvaluator;
|
||||
typedef TensorEvaluator<RightArgType, Device> RightEvaluator;
|
||||
const int lhs_packet_size = internal::packet_traits<LhsScalar>::size;
|
||||
const int rhs_packet_size = internal::packet_traits<RhsScalar>::size;
|
||||
// Left-hand mapper type. NOTE(review): the literal `false` presumably fills the
// inner_dim_reordered slot (there is no lhs_inner_dim_reordered parameter) —
// confirm against the TensorContractionInputMapper template parameter list.
typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
|
||||
LeftEvaluator, left_nocontract_t,
|
||||
contract_t, lhs_packet_size,
|
||||
lhs_inner_dim_contiguous,
|
||||
false, Unaligned> LhsMapper;
|
||||
|
||||
// Right-hand mapper type; forwards rhs_inner_dim_reordered from the template
// arguments of this function.
typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
|
||||
RightEvaluator, right_nocontract_t,
|
||||
contract_t, rhs_packet_size,
|
||||
rhs_inner_dim_contiguous,
|
||||
rhs_inner_dim_reordered, Unaligned> RhsMapper;
|
||||
|
||||
LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides,
|
||||
m_left_contracting_strides, m_k_strides);
|
||||
RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides,
|
||||
m_right_contracting_strides, m_k_strides);
|
||||
|
||||
const Scalar alpha(1);
|
||||
const Index resIncr(1);
|
||||
|
||||
// Zero out the result buffer, which must hold at least rows * sizeof(Scalar) bytes.
// NOTE(review): presumably needed because the product accumulates into the
// destination — confirm against general_matrix_vector_product.
|
||||
m_device.memset(buffer, 0, rows * sizeof(Scalar));
|
||||
|
||||
internal::general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,false,RhsScalar,RhsMapper,false>::run(
|
||||
rows, cols, lhs, rhs,
|
||||
buffer, resIncr, alpha);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
|
||||
m_leftImpl.cleanup();
|
||||
m_rightImpl.cleanup();
|
||||
@ -707,7 +775,17 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
Base(op, device) { }
|
||||
|
||||
// Entry point for evaluating the product into `buffer`.
// When m_j_size == 1 the work is handed to evalGemv and the function returns;
// otherwise evalGemm is called. All four template arguments are forwarded
// unchanged to whichever routine is chosen.
// NOTE(review): the `evalTyped` signature line below appears to be the
// pre-change line kept by the diff view, with `evalProduct` as its replacement
// — confirm against the real header.
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
|
||||
EIGEN_DEVICE_FUNC void evalTyped(Scalar* buffer) const {
|
||||
void evalProduct(Scalar* buffer) const {
|
||||
if (this->m_j_size == 1) {
|
||||
this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
|
||||
return;
|
||||
}
|
||||
|
||||
evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
|
||||
}
|
||||
|
||||
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
|
||||
EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const {
|
||||
// columns in left side, rows in right side
|
||||
const Index k = this->m_k_size;
|
||||
|
||||
|
@ -93,7 +93,17 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
Base(op, device) {}
|
||||
|
||||
// Entry point for evaluating the product into `buffer`.
// When m_j_size == 1 the work is handed to evalGemv and the function returns;
// otherwise evalGemm is called. All four template arguments are forwarded
// unchanged to whichever routine is chosen.
// NOTE(review): the `evalTyped` signature line below appears to be the
// pre-change line kept by the diff view, with `evalProduct` as its replacement
// — confirm against the real header.
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
|
||||
void evalTyped(Scalar* buffer) const {
|
||||
void evalProduct(Scalar* buffer) const {
|
||||
if (this->m_j_size == 1) {
|
||||
this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
|
||||
return;
|
||||
}
|
||||
|
||||
evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
|
||||
}
|
||||
|
||||
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
|
||||
void evalGemm(Scalar* buffer) const {
|
||||
// columns in left side, rows in right side
|
||||
const Index k = this->m_k_size;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user