Improved handling of 1d tensors

This commit is contained in:
Benoit Steiner 2014-11-03 08:51:33 -08:00
parent 2dde63499c
commit b1789c112b
2 changed files with 99 additions and 11 deletions

View File

@ -48,7 +48,7 @@ class BaseTensorContractionMapper {
m_k_strides(k_strides) { }
// Prefetch hook with an empty body: the index argument is ignored and nothing
// is done.
// NOTE(review): the two signatures below differ only in the parameter type
// (`int` vs `Index`); in this diff excerpt they are presumably the old and the
// new version of the same line rather than two coexisting overloads — confirm
// against the real file.
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE void prefetch(int /*i*/) { }
EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar operator()(Index row) const {
@ -142,6 +142,13 @@ class BaseTensorContractionMapper {
return IndexPair<Index>(linidx[0], linidx[1]);
}
// Returns `size` unchanged, whatever its value.
// NOTE(review): presumably this tells the matrix-vector kernel that no element
// in [0, size) can be assumed aligned — confirm against the kernel's use of
// firstAligned().
Index firstAligned(Index size) const {
return size;
}
// Always reports a stride of 1.
// NOTE(review): presumably part of the mapper interface expected by the
// matrix-vector kernel; the unit of the stride is not visible here — confirm.
Index stride() const {
return 1;
}
protected:
const Tensor m_tensor;
const nocontract_t m_nocontract_strides;
@ -202,6 +209,18 @@ class TensorContractionSubMapper {
return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
}
// Loads the packet at index `i` by forwarding to loadPacket(i).
//
// PacketT must be exactly the mapper's `Packet` type; this is enforced at
// compile time. AlignmentType is only checked by the second assertion and does
// not influence which load is performed.
//
// NOTE(review): the second assertion compares `AlignmentType` against Aligned
// but a different identifier, `Alignment`, against Unaligned. `Alignment` is
// presumably the enclosing class's alignment template parameter; confirm that
// this mix is intended and is not a typo for `AlignmentType`.
template <typename PacketT, int AlignmentType>
EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT((AlignmentType == Aligned || Alignment == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE);
return loadPacket(i);
}
// Alignment query: always answers false, independent of the index passed in
// and of the `Packet` template argument (neither is used in the body).
template <typename Packet>
bool aligned(Index /*i*/) const {
return false;
}
private:
const ParentMapper& m_base_mapper;
const Index m_vert_offset;
@ -220,6 +239,7 @@ class TensorContractionInputMapper
public:
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> Base;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
typedef SubMapper VectorMapper;
TensorContractionInputMapper(const Tensor& tensor,
const nocontract_t& nocontract_strides,
@ -233,6 +253,10 @@ class TensorContractionInputMapper
return SubMapper(*this, i, j);
}
// Builds a VectorMapper that wraps this mapper, passing `i` and `j` through to
// the VectorMapper constructor.
// NOTE(review): VectorMapper is a typedef for SubMapper in this class, so `i`
// and `j` are presumably the vertical and horizontal offsets of the sub-view —
// confirm against the SubMapper constructor.
EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
return VectorMapper(*this, i, j);
}
typedef typename packet_traits<Scalar>::type Packet;
typedef typename packet_traits<Scalar>::half HalfPacket;
@ -306,6 +330,7 @@ class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, co
public:
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> Base;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
typedef SubMapper VectorMapper;
TensorContractionInputMapper(const Tensor& tensor,
const nocontract_t& nocontract_strides,
@ -319,6 +344,10 @@ class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, co
return SubMapper(*this, i, j);
}
// Builds a VectorMapper that wraps this mapper, passing `i` and `j` through to
// the VectorMapper constructor.
// NOTE(review): VectorMapper is a typedef for SubMapper in this specialization,
// so `i` and `j` are presumably the vertical and horizontal offsets of the
// sub-view — confirm against the SubMapper constructor.
EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
return VectorMapper(*this, i, j);
}
typedef typename packet_traits<Scalar>::type Packet;
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
@ -592,41 +621,80 @@ struct TensorContractionEvaluatorBase
if (this->m_lhs_inner_dim_contiguous) {
if (this->m_rhs_inner_dim_contiguous) {
if (this->m_rhs_inner_dim_reordered) {
static_cast<const Derived*>(this)->template evalTyped<true, true, true, Unaligned>(buffer);
static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer);
}
else {
static_cast<const Derived*>(this)->template evalTyped<true, true, false, Unaligned>(buffer);
static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer);
}
}
else {
if (this->m_rhs_inner_dim_reordered) {
static_cast<const Derived*>(this)->template evalTyped<true, false, true, Unaligned>(buffer);
static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer);
}
else {
static_cast<const Derived*>(this)->template evalTyped<true, false, false, Unaligned>(buffer);
static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer);
}
}
}
else {
if (this->m_rhs_inner_dim_contiguous) {
if (this->m_rhs_inner_dim_reordered) {
static_cast<const Derived*>(this)->template evalTyped<false, true, true, Unaligned>(buffer);
static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer);
}
else {
static_cast<const Derived*>(this)->template evalTyped<false, true, false, Unaligned>(buffer);
static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer);
}
}
else {
if (this->m_rhs_inner_dim_reordered) {
static_cast<const Derived*>(this)->template evalTyped<false, false, true, Unaligned>(buffer);
static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer);
}
else {
static_cast<const Derived*>(this)->template evalTyped<false, false, false, Unaligned>(buffer);
static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer);
}
}
}
}
// Evaluates the contraction into `buffer` through Eigen's
// general_matrix_vector_product kernel.
//
// The kernel is run with m_i_size rows and m_k_size columns, with the left
// input mapper passed as its first operand and the right input mapper as its
// second. `buffer` must hold at least m_i_size Scalar values; it is cleared
// before the kernel runs, and the result is written with an increment of 1 and
// a scaling factor of 1.
//
// The three boolean template arguments are forwarded to the input mappers,
// except that the left mapper is always built with inner_dim_reordered set to
// false. The `Alignment` argument is not consulted in this function: both
// mappers are always instantiated with Unaligned.
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
void evalGemv(Scalar* buffer) const {
  // Scalar and evaluator types of the two operands.
  typedef typename internal::remove_const<typename LeftArgType::Scalar>::type LhsScalar;
  typedef typename internal::remove_const<typename RightArgType::Scalar>::type RhsScalar;
  typedef TensorEvaluator<LeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<RightArgType, Device> RightEvaluator;

  // Packet sizes used to select the input mapper implementations.
  const int lhs_packet_size = internal::packet_traits<LhsScalar>::size;
  const int rhs_packet_size = internal::packet_traits<RhsScalar>::size;

  typedef internal::TensorContractionInputMapper<
      LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
      contract_t, lhs_packet_size, lhs_inner_dim_contiguous,
      false, Unaligned> LeftMapper;
  typedef internal::TensorContractionInputMapper<
      RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
      contract_t, rhs_packet_size, rhs_inner_dim_contiguous,
      rhs_inner_dim_reordered, Unaligned> RightMapper;
  typedef internal::general_matrix_vector_product<
      Index, LhsScalar, LeftMapper, ColMajor, false,
      RhsScalar, RightMapper, false> GemvKernel;

  LeftMapper lhs_mapper(m_leftImpl, m_left_nocontract_strides, m_i_strides,
                        m_left_contracting_strides, m_k_strides);
  RightMapper rhs_mapper(m_rightImpl, m_right_nocontract_strides, m_j_strides,
                         m_right_contracting_strides, m_k_strides);

  const Index num_rows = m_i_size;
  const Index num_cols = m_k_size;
  const Scalar unit_scale(1);
  const Index result_stride(1);

  // Clear the result buffer, which must hold at least num_rows Scalar values.
  m_device.memset(buffer, 0, num_rows * sizeof(Scalar));

  GemvKernel::run(num_rows, num_cols, lhs_mapper, rhs_mapper,
                  buffer, result_stride, unit_scale);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
m_leftImpl.cleanup();
m_rightImpl.cleanup();
@ -707,7 +775,17 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
Base(op, device) { }
// Dispatches the contraction to one of two kernels and writes the result into
// `buffer`: evalGemv when m_j_size == 1, evalGemm otherwise. All four template
// arguments are forwarded unchanged to whichever kernel is selected.
// NOTE(review): the `evalTyped` and `evalProduct` signature lines below are
// presumably the old and the new version of the same line in this diff
// excerpt, not two nested declarations — confirm against the real file.
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
EIGEN_DEVICE_FUNC void evalTyped(Scalar* buffer) const {
void evalProduct(Scalar* buffer) const {
// m_j_size == 1: take the matrix-vector path and skip the general product.
if (this->m_j_size == 1) {
this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
return;
}
evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
}
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const {
// columns in left side, rows in right side
const Index k = this->m_k_size;

View File

@ -93,7 +93,17 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
Base(op, device) {}
// Dispatches the contraction to one of two kernels and writes the result into
// `buffer`: evalGemv when m_j_size == 1, evalGemm otherwise. All four template
// arguments are forwarded unchanged to whichever kernel is selected.
// NOTE(review): the `evalTyped` and `evalProduct` signature lines below are
// presumably the old and the new version of the same line in this diff
// excerpt, not two nested declarations — confirm against the real file.
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
void evalTyped(Scalar* buffer) const {
void evalProduct(Scalar* buffer) const {
// m_j_size == 1: take the matrix-vector path and skip the general product.
if (this->m_j_size == 1) {
this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
return;
}
evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
}
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
void evalGemm(Scalar* buffer) const {
// columns in left side, rows in right side
const Index k = this->m_k_size;