Updated the cxx11_tensor_custom_op to not require cxx11.

This commit is contained in:
Benoit Steiner 2015-12-10 20:53:44 -08:00
parent 4e324ca6ae
commit 9db8316c93

View File

@ -25,7 +25,9 @@ struct InsertZeros {
template <typename Output, typename Device>
void eval(const Tensor<float, 2>& input, Output& output, const Device& device) const
{
array<DenseIndex, 2> strides{{2, 2}};
array<DenseIndex, 2> strides;
strides[0] = 2;
strides[1] = 2;
output.stride(strides).device(device) = input;
Eigen::DSizes<DenseIndex, 2> offsets(1,1);
@ -70,7 +72,8 @@ struct BatchMatMul {
Output& output, const Device& device) const
{
typedef Tensor<float, 3>::DimensionPair DimPair;
array<DimPair, 1> dims({{DimPair(1, 0)}});
array<DimPair, 1> dims;
dims[0] = DimPair(1, 0);
for (int i = 0; i < output.dimension(2); ++i) {
output.template chip<2>(i).device(device) = input1.chip<2>(i).contract(input2.chip<2>(i), dims);
}
@ -88,9 +91,10 @@ static void test_custom_binary_op()
Tensor<float, 3> result = tensor1.customOp(tensor2, BatchMatMul());
for (int i = 0; i < 5; ++i) {
typedef Tensor<float, 3>::DimensionPair DimPair;
array<DimPair, 1> dims({{DimPair(1, 0)}});
array<DimPair, 1> dims;
dims[0] = DimPair(1, 0);
Tensor<float, 2> reference = tensor1.chip<2>(i).contract(tensor2.chip<2>(i), dims);
TensorRef<Tensor<float, 2>> val = result.chip<2>(i);
TensorRef<Tensor<float, 2> > val = result.chip<2>(i);
for (int j = 0; j < 2; ++j) {
for (int k = 0; k < 7; ++k) {
VERIFY_IS_APPROX(val(j, k), reference(j, k));