bug #1567: add optimized path for tensor broadcasting and 'Channel First' shape

Gael Guennebaud 2018-07-09 11:23:16 +02:00
parent ec323b7e66
commit 6190aa5632
2 changed files with 123 additions and 4 deletions
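In a channel-first (NCHW) layout, a per-channel quantity such as a convolution bias naturally has size 1 on its leading and trailing dimensions, so broadcasting it produces exactly the '[1, N..., 1]' input / '[N, 1..., N]' broadcast pattern this commit adds a fast path for. A minimal usage sketch of that pattern (shapes and names below are illustrative, not taken from the commit):

#include <unsupported/Eigen/CXX11/Tensor>
using Eigen::Tensor;

int main() {
  // Hypothetical channel-first example: a per-channel bias of shape
  // [1, C, 1] broadcast to a full [N, C, H*W] activation shape.
  Tensor<float, 3> bias(1, 16, 1);      // C = 16 channels
  bias.setRandom();

  Eigen::array<ptrdiff_t, 3> bcast;     // broadcast factors [N, 1, H*W]
  bcast[0] = 8;                         // N = 8
  bcast[1] = 1;                         // channel dim is not broadcast
  bcast[2] = 32 * 32;                   // H*W = 1024

  Tensor<float, 3> out;
  out = bias.broadcast(bcast);          // shape [8, 16, 1024]

  // Every (n, c, p) entry equals that channel's bias value.
  return out(3, 5, 100) == bias(0, 5, 0) ? 0 : 1;
}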

unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h

@@ -161,6 +161,22 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
        }
      }
    }
    // Handle special formats like NCHW: the input shape is '[1, N..., 1]'
    // and the broadcast shape is '[N, 1..., N]'.
    if (!oneByN && !nByOne) {
      if (input_dims[0] == 1 && input_dims[NumDims-1] == 1 && NumDims > 2) {
        nByOne = true;
        oneByN = true;
        for (int i = 1; i < NumDims-1; ++i) {
          if (broadcast[i] != 1) {
            nByOne = false;
            oneByN = false;
            break;
          }
        }
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
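The new branch fires only when neither simple pattern was already detected: the input must have size 1 on its first and last dimension, and every broadcast factor in between must be 1; both flags are then set together to mark the combined case. A standalone restatement of that test (hypothetical helper, not part of the commit):

#include <array>
#include <cstddef>

// Returns true for '[1, N..., 1]' inputs whose middle dimensions are
// not broadcast, i.e. the shape handled by the new fast path.
template <std::size_t NumDims>
bool isOneByNByOne(const std::array<std::ptrdiff_t, NumDims>& dims,
                   const std::array<std::ptrdiff_t, NumDims>& bcast) {
  if (NumDims <= 2 || dims[0] != 1 || dims[NumDims - 1] != 1) return false;
  for (std::size_t i = 1; i + 1 < NumDims; ++i)
    if (bcast[i] != 1) return false;
  return true;
}

// isOneByNByOne<3>({1, 7, 1}, {5, 1, 13})  -> true
// isOneByNByOne<3>({1, 7, 1}, {5, 2, 13})  -> false (middle dim broadcast)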
@@ -256,24 +272,70 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
    }
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
-     if (oneByN) {
+     if (oneByN && !nByOne) {
        return packetNByOne<LoadMode>(index);
-     } else if (nByOne) {
+     } else if (!oneByN && nByOne) {
        return packetOneByN<LoadMode>(index);
+     } else if (oneByN && nByOne) {
+       return packetOneByNByOne<LoadMode>(index);
      } else {
        return packetColMajor<LoadMode>(index);
      }
    } else {
-     if (oneByN) {
+     if (oneByN && !nByOne) {
        return packetOneByN<LoadMode>(index);
-     } else if (nByOne) {
+     } else if (!oneByN && nByOne) {
        return packetNByOne<LoadMode>(index);
+     } else if (oneByN && nByOne) {
+       return packetOneByNByOne<LoadMode>(index);
      } else {
        return packetRowMajor<LoadMode>(index);
      }
    }
  }
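Since both flags may now be set at once, the guards had to be tightened from plain 'oneByN' / 'nByOne' to the exclusive forms above; otherwise the combined case would keep falling into one of the single-axis kernels. A compact mirror of the dispatch decision (hypothetical function; kernel names are from the diff):

#include <cstdio>

// Which packet kernel runs for a given flag combination and layout.
const char* pickKernel(bool oneByN, bool nByOne, bool colMajor) {
  if (oneByN && !nByOne) return colMajor ? "packetNByOne" : "packetOneByN";
  if (!oneByN && nByOne) return colMajor ? "packetOneByN" : "packetNByOne";
  if (oneByN && nByOne)  return "packetOneByNByOne";       // the new path
  return colMajor ? "packetColMajor" : "packetRowMajor";   // generic fallback
}

int main() {
  std::printf("%s\n", pickKernel(true, true, true));       // packetOneByNByOne
}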
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByNByOne
  (Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
    Index startDim, endDim;
    Index inputIndex, outputOffset, batchedIndex;

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      startDim = NumDims - 1;
      endDim = 1;
    } else {
      startDim = 0;
      endDim = NumDims - 2;
    }

    batchedIndex = index % m_outputStrides[startDim];
    inputIndex = batchedIndex / m_outputStrides[endDim];
    outputOffset = batchedIndex % m_outputStrides[endDim];

    if (outputOffset + PacketSize <= m_outputStrides[endDim]) {
      values[0] = m_impl.coeff(inputIndex);
      return internal::pload1<PacketReturnType>(values);
    } else {
      for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) {
        if (outputOffset + cur < m_outputStrides[endDim]) {
          values[i] = m_impl.coeff(inputIndex);
        } else {
          ++inputIndex;
          inputIndex = (inputIndex == m_inputStrides[startDim] ? 0 : inputIndex);
          values[i] = m_impl.coeff(inputIndex);
          outputOffset = 0;
          cur = 0;
        }
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
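For a column-major broadcast of [1, 7, 1] to [5, 7, 13] the output strides are [1, 5, 35], so startDim = 2 and endDim = 1: batchedIndex locates the position inside one k-slice, inputIndex recovers the single non-degenerate input coordinate j, and outputOffset is the position along the broadcast dimension i. A worked example of that arithmetic (shape and stride values assumed for illustration):

#include <cstdio>

int main() {
  // Output [5, 7, 13] from input [1, 7, 1], column-major:
  // m_outputStrides = {1, 5, 35}; startDim = 2, endDim = 1.
  const int strideStart = 35, strideEnd = 5;
  for (int index : {0, 4, 7, 36}) {
    int batched = index % strideStart;  // position within one k-slice
    int input   = batched / strideEnd;  // j: the only non-degenerate coord
    int offset  = batched % strideEnd;  // i: offset along broadcast dim 0
    std::printf("index=%2d -> j=%d, i=%d\n", index, input, offset);
  }
  // Whenever offset + PacketSize fits inside a run of strideEnd equal
  // values, the whole packet is one splatted coefficient (pload1);
  // otherwise the loop refills element by element, wrapping inputIndex
  // back to 0 once it reaches the input's coefficient count.
}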
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByN(Index index) const
  {

unsupported/test/cxx11_tensor_broadcasting.cpp

@@ -238,6 +238,59 @@ static void test_simple_broadcasting_n_by_one()
  }
}

template <int DataLayout>
static void test_simple_broadcasting_one_by_n_by_one_1d()
{
  Tensor<float, 3, DataLayout> tensor(1,7,1);
  tensor.setRandom();
  array<ptrdiff_t, 3> broadcasts;
  broadcasts[0] = 5;
  broadcasts[1] = 1;
  broadcasts[2] = 13;

  Tensor<float, 3, DataLayout> broadcasted;
  broadcasted = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcasted.dimension(0), 5);
  VERIFY_IS_EQUAL(broadcasted.dimension(1), 7);
  VERIFY_IS_EQUAL(broadcasted.dimension(2), 13);

  for (int i = 0; i < 5; ++i) {
    for (int j = 0; j < 7; ++j) {
      for (int k = 0; k < 13; ++k) {
        VERIFY_IS_EQUAL(tensor(0,j%7,0), broadcasted(i,j,k));
      }
    }
  }
}
template <int DataLayout>
static void test_simple_broadcasting_one_by_n_by_one_2d()
{
  Tensor<float, 4, DataLayout> tensor(1,7,13,1);
  tensor.setRandom();
  array<ptrdiff_t, 4> broadcasts;
  broadcasts[0] = 5;
  broadcasts[1] = 1;
  broadcasts[2] = 1;
  broadcasts[3] = 19;

  Tensor<float, 4, DataLayout> broadcast;
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 5);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 7);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 13);
  VERIFY_IS_EQUAL(broadcast.dimension(3), 19);

  for (int i = 0; i < 5; ++i) {
    for (int j = 0; j < 7; ++j) {
      for (int k = 0; k < 13; ++k) {
        for (int l = 0; l < 19; ++l) {
          VERIFY_IS_EQUAL(tensor(0,j%7,k%13,0), broadcast(i,j,k,l));
        }
      }
    }
  }
}
void test_cxx11_tensor_broadcasting()
{
@@ -253,4 +306,8 @@ void test_cxx11_tensor_broadcasting()
  CALL_SUBTEST(test_simple_broadcasting_n_by_one<RowMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting_n_by_one<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<RowMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<RowMajor>());
}