mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-18 14:34:17 +08:00
bug #1567: add optimized path for tensor broadcasting and 'Channel First' shape
This commit is contained in:
parent
ec323b7e66
commit
6190aa5632
@ -161,6 +161,22 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle special format like NCHW, its input shape is '[1, N..., 1]' and
|
||||
// broadcast shape is '[N, 1..., N]'
|
||||
if (!oneByN && !nByOne) {
|
||||
if (input_dims[0] == 1 && input_dims[NumDims-1] == 1 && NumDims > 2) {
|
||||
nByOne = true;
|
||||
oneByN = true;
|
||||
for (int i = 1; i < NumDims-1; ++i) {
|
||||
if (broadcast[i] != 1) {
|
||||
nByOne = false;
|
||||
oneByN = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
|
||||
@ -256,24 +272,70 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
|
||||
}
|
||||
|
||||
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
|
||||
if (oneByN) {
|
||||
if (oneByN && !nByOne) {
|
||||
return packetNByOne<LoadMode>(index);
|
||||
} else if (nByOne) {
|
||||
} else if (!oneByN && nByOne) {
|
||||
return packetOneByN<LoadMode>(index);
|
||||
} else if (oneByN && nByOne) {
|
||||
return packetOneByNByOne<LoadMode>(index);
|
||||
} else {
|
||||
return packetColMajor<LoadMode>(index);
|
||||
}
|
||||
} else {
|
||||
if (oneByN) {
|
||||
if (oneByN && !nByOne) {
|
||||
return packetOneByN<LoadMode>(index);
|
||||
} else if (nByOne) {
|
||||
} else if (!oneByN && nByOne) {
|
||||
return packetNByOne<LoadMode>(index);
|
||||
} else if (oneByN && nByOne) {
|
||||
return packetOneByNByOne<LoadMode>(index);
|
||||
} else {
|
||||
return packetRowMajor<LoadMode>(index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<int LoadMode>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByNByOne
|
||||
(Index index) const
|
||||
{
|
||||
EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
|
||||
eigen_assert(index+PacketSize-1 < dimensions().TotalSize());
|
||||
|
||||
EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
|
||||
Index startDim, endDim;
|
||||
Index inputIndex, outputOffset, batchedIndex;
|
||||
|
||||
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
|
||||
startDim = NumDims - 1;
|
||||
endDim = 1;
|
||||
} else {
|
||||
startDim = 0;
|
||||
endDim = NumDims - 2;
|
||||
}
|
||||
|
||||
batchedIndex = index % m_outputStrides[startDim];
|
||||
inputIndex = batchedIndex / m_outputStrides[endDim];
|
||||
outputOffset = batchedIndex % m_outputStrides[endDim];
|
||||
|
||||
if (outputOffset + PacketSize <= m_outputStrides[endDim]) {
|
||||
values[0] = m_impl.coeff(inputIndex);
|
||||
return internal::pload1<PacketReturnType>(values);
|
||||
} else {
|
||||
for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) {
|
||||
if (outputOffset + cur < m_outputStrides[endDim]) {
|
||||
values[i] = m_impl.coeff(inputIndex);
|
||||
} else {
|
||||
++inputIndex;
|
||||
inputIndex = (inputIndex == m_inputStrides[startDim] ? 0 : inputIndex);
|
||||
values[i] = m_impl.coeff(inputIndex);
|
||||
outputOffset = 0;
|
||||
cur = 0;
|
||||
}
|
||||
}
|
||||
return internal::pload<PacketReturnType>(values);
|
||||
}
|
||||
}
|
||||
|
||||
template<int LoadMode>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByN(Index index) const
|
||||
{
|
||||
|
@ -238,6 +238,59 @@ static void test_simple_broadcasting_n_by_one()
|
||||
}
|
||||
}
|
||||
|
||||
template <int DataLayout>
|
||||
static void test_simple_broadcasting_one_by_n_by_one_1d()
|
||||
{
|
||||
Tensor<float, 3, DataLayout> tensor(1,7,1);
|
||||
tensor.setRandom();
|
||||
array<ptrdiff_t, 3> broadcasts;
|
||||
broadcasts[0] = 5;
|
||||
broadcasts[1] = 1;
|
||||
broadcasts[2] = 13;
|
||||
Tensor<float, 3, DataLayout> broadcasted;
|
||||
broadcasted = tensor.broadcast(broadcasts);
|
||||
|
||||
VERIFY_IS_EQUAL(broadcasted.dimension(0), 5);
|
||||
VERIFY_IS_EQUAL(broadcasted.dimension(1), 7);
|
||||
VERIFY_IS_EQUAL(broadcasted.dimension(2), 13);
|
||||
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
for (int j = 0; j < 7; ++j) {
|
||||
for (int k = 0; k < 13; ++k) {
|
||||
VERIFY_IS_EQUAL(tensor(0,j%7,0), broadcasted(i,j,k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int DataLayout>
|
||||
static void test_simple_broadcasting_one_by_n_by_one_2d()
|
||||
{
|
||||
Tensor<float, 4, DataLayout> tensor(1,7,13,1);
|
||||
tensor.setRandom();
|
||||
array<ptrdiff_t, 4> broadcasts;
|
||||
broadcasts[0] = 5;
|
||||
broadcasts[1] = 1;
|
||||
broadcasts[2] = 1;
|
||||
broadcasts[3] = 19;
|
||||
Tensor<float, 4, DataLayout> broadcast;
|
||||
broadcast = tensor.broadcast(broadcasts);
|
||||
|
||||
VERIFY_IS_EQUAL(broadcast.dimension(0), 5);
|
||||
VERIFY_IS_EQUAL(broadcast.dimension(1), 7);
|
||||
VERIFY_IS_EQUAL(broadcast.dimension(2), 13);
|
||||
VERIFY_IS_EQUAL(broadcast.dimension(3), 19);
|
||||
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
for (int j = 0; j < 7; ++j) {
|
||||
for (int k = 0; k < 13; ++k) {
|
||||
for (int l = 0; l < 19; ++l) {
|
||||
VERIFY_IS_EQUAL(tensor(0,j%7,k%13,0), broadcast(i,j,k,l));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void test_cxx11_tensor_broadcasting()
|
||||
{
|
||||
@ -253,4 +306,8 @@ void test_cxx11_tensor_broadcasting()
|
||||
CALL_SUBTEST(test_simple_broadcasting_n_by_one<RowMajor>());
|
||||
CALL_SUBTEST(test_simple_broadcasting_one_by_n<ColMajor>());
|
||||
CALL_SUBTEST(test_simple_broadcasting_n_by_one<ColMajor>());
|
||||
CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<ColMajor>());
|
||||
CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<ColMajor>());
|
||||
CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<RowMajor>());
|
||||
CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<RowMajor>());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user