Update the PADDING_SAME padding computation to be consistent with TensorFlow: padding amounts that come out negative are now clamped to zero.

commit 8f55956a57
Author: Benoit Steiner
Date: 2018-01-30 20:22:12 +00:00
2 changed files with 56 additions and 0 deletions
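The commit message above refers to TensorFlow's SAME padding rule, which (as I understand it) computes the total padding per spatial dimension as max(0, (ceil(input / stride) - 1) * stride + kernel - input) and applies the smaller half before the data. The following minimal standalone sketch is not part of the patch; the helper name and use of plain long are illustrative only. The Eigen change below reaches the same result by clamping m_rowPaddingTop / m_colPaddingLeft to zero after the division by two.

    // A minimal sketch of TensorFlow-style SAME padding for one spatial
    // dimension (illustrative only, not part of this commit).
    #include <algorithm>
    #include <cstdio>

    long same_pad_before(long input, long kernel, long stride) {
      long output = (input + stride - 1) / stride;              // ceil(input / stride)
      long pad_total = (output - 1) * stride + kernel - input;  // can be negative
      pad_total = std::max(0L, pad_total);                      // never pad by a negative amount
      return pad_total / 2;                                     // amount applied before (top/left)
    }

    int main() {
      std::printf("pad_before = %ld\n", same_pad_before(10, 3, 1));  // prints 1
      return 0;
    }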


@@ -265,6 +265,10 @@ struct TensorEvaluator<const TensorImagePatchOp<Rows, Cols, ArgType>, Device>
        // Calculate the padding
        m_rowPaddingTop = ((m_outputRows - 1) * m_row_strides + m_patch_rows_eff - m_input_rows_eff) / 2;
        m_colPaddingLeft = ((m_outputCols - 1) * m_col_strides + m_patch_cols_eff - m_input_cols_eff) / 2;
        // The padding size calculation for PADDING_SAME has been updated to
        // be consistent with how TensorFlow extracts its paddings.
        m_rowPaddingTop = numext::maxi<Index>(0, m_rowPaddingTop);
        m_colPaddingLeft = numext::maxi<Index>(0, m_colPaddingLeft);
        break;
      default:
        eigen_assert(false && "unexpected padding");
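To make the effect concrete, using the dimensions of the test added below: with input_rows = 15, an effective patch size of 1, and a row stride of 5, m_outputRows is ceil(15 / 5) = 3, so the first expression yields ((3 - 1) * 5 + 1 - 15) / 2 = -2. The new numext::maxi<Index> call clips this to 0, which matches TensorFlow, where SAME padding is never negative.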


@@ -405,6 +405,57 @@ void test_patch_padding_same()
}
}
// Verifies that SAME padding, when computed as negative values, will be clipped
// to zero.
void test_patch_padding_same_negative_padding_clip_to_zero() {
  int input_depth = 1;
  int input_rows = 15;
  int input_cols = 1;
  int input_batches = 1;
  int ksize = 1;  // Corresponds to the Rows and Cols for
                  // tensor.extract_image_patches<>.
  int row_stride = 5;
  int col_stride = 1;

  // ColMajor
  Tensor<float, 4> tensor(input_depth, input_rows, input_cols, input_batches);
  // Initializes tensor with incrementing numbers.
  for (int i = 0; i < tensor.size(); ++i) {
    tensor.data()[i] = i + 1;
  }
  Tensor<float, 5> result = tensor.extract_image_patches(
      ksize, ksize, row_stride, col_stride, 1, 1, PADDING_SAME);
  // row padding will be computed as -2 originally and then be clipped to 0.
  VERIFY_IS_EQUAL(result.coeff(0), 1.0f);
  VERIFY_IS_EQUAL(result.coeff(1), 6.0f);
  VERIFY_IS_EQUAL(result.coeff(2), 11.0f);
  VERIFY_IS_EQUAL(result.dimension(0), input_depth);    // depth
  VERIFY_IS_EQUAL(result.dimension(1), ksize);          // kernel rows
  VERIFY_IS_EQUAL(result.dimension(2), ksize);          // kernel cols
  VERIFY_IS_EQUAL(result.dimension(3), 3);              // number of patches
  VERIFY_IS_EQUAL(result.dimension(4), input_batches);  // number of batches

  // RowMajor
  Tensor<float, 4, RowMajor> tensor_row_major = tensor.swap_layout();
  VERIFY_IS_EQUAL(tensor.dimension(0), tensor_row_major.dimension(3));
  VERIFY_IS_EQUAL(tensor.dimension(1), tensor_row_major.dimension(2));
  VERIFY_IS_EQUAL(tensor.dimension(2), tensor_row_major.dimension(1));
  VERIFY_IS_EQUAL(tensor.dimension(3), tensor_row_major.dimension(0));

  Tensor<float, 5, RowMajor> result_row_major =
      tensor_row_major.extract_image_patches(ksize, ksize, row_stride,
                                             col_stride, 1, 1, PADDING_SAME);
  VERIFY_IS_EQUAL(result_row_major.coeff(0), 1.0f);
  VERIFY_IS_EQUAL(result_row_major.coeff(1), 6.0f);
  VERIFY_IS_EQUAL(result_row_major.coeff(2), 11.0f);

  VERIFY_IS_EQUAL(result.dimension(0), result_row_major.dimension(4));
  VERIFY_IS_EQUAL(result.dimension(1), result_row_major.dimension(3));
  VERIFY_IS_EQUAL(result.dimension(2), result_row_major.dimension(2));
  VERIFY_IS_EQUAL(result.dimension(3), result_row_major.dimension(1));
  VERIFY_IS_EQUAL(result.dimension(4), result_row_major.dimension(0));
}
void test_patch_no_extra_dim()
{
  Tensor<float, 3> tensor(2,3,5);
@@ -754,4 +805,5 @@ void test_cxx11_tensor_image_patch()
  CALL_SUBTEST_4(test_patch_padding_valid_same_value());
  CALL_SUBTEST_5(test_patch_padding_same());
  CALL_SUBTEST_6(test_imagenet_patches());
  CALL_SUBTEST_7(test_patch_padding_same_negative_padding_clip_to_zero());
}
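For reference, the expected coefficients in the new test follow directly from the clipped padding: with a 1x1 kernel, a row stride of 5, and zero top padding, the three patches read input rows 0, 5, and 10, whose values are 1, 6, and 11. The small loop below is illustrative only and not part of the test file; it just makes that mapping explicit.

    // Illustrative only: which input row each 1x1 patch of the new test
    // samples once the SAME padding has been clipped to zero.
    #include <cstdio>

    int main() {
      const int row_stride = 5;
      const int pad_top = 0;      // would be -2 before the fix, 0 after it
      const int num_patches = 3;  // ceil(15 / 5)
      for (int p = 0; p < num_patches; ++p) {
        int row = p * row_stride - pad_top;  // first row covered by patch p
        std::printf("patch %d -> input row %d, value %d\n", p, row, row + 1);
      }
      return 0;  // prints values 1, 6, 11, matching the VERIFY_IS_EQUAL checks
    }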