mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-02-23 18:20:47 +08:00
Added support for user-defined custom tensor ops.
This commit is contained in:
parent
dc31fcb9ba
commit
f1f480b116
@ -91,6 +91,7 @@
|
||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h"
|
||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h"
|
||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h"
|
||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h"
|
||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h"
|
||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h"
|
||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h"
|
||||
|
@ -481,6 +481,18 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
||||
return TensorStridingOp<const Strides, const Derived>(derived(), strides);
|
||||
}
|
||||
|
||||
// Added support for custom unary and binary operations
|
||||
template <typename CustomUnaryFunc>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
const TensorCustomUnaryOp<const CustomUnaryFunc, const Derived> customOp(const CustomUnaryFunc& op) const {
|
||||
return TensorCustomUnaryOp<const CustomUnaryFunc, const Derived>(derived(), op);
|
||||
}
|
||||
template <typename OtherDerived, typename CustomBinaryFunc>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
const TensorCustomBinaryOp<const CustomBinaryFunc, const Derived, const OtherDerived> customOp(const OtherDerived& other, const CustomBinaryFunc& op) const {
|
||||
return TensorCustomBinaryOp<const CustomBinaryFunc, const Derived, const OtherDerived>(derived(), other, op);
|
||||
}
|
||||
|
||||
// Force the evaluation of the expression.
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
const TensorForcedEvalOp<const Derived> eval() const {
|
||||
|
310
unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h
Normal file
310
unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h
Normal file
@ -0,0 +1,310 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
|
||||
#define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
/** \class TensorCustomUnaryOp
|
||||
* \ingroup CXX11_Tensor_Module
|
||||
*
|
||||
* \brief Tensor custom class.
|
||||
*
|
||||
*
|
||||
*/
|
||||
namespace internal {

// Expression traits for TensorCustomUnaryOp: every type is forwarded from the
// wrapped expression, since a custom unary op preserves the scalar type,
// index type, rank and layout of its input.
template<typename CustomUnaryFunc, typename XprType>
struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
{
  typedef typename XprType::Scalar Scalar;
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename XprType::StorageKind StorageKind;
  typedef typename XprType::Index Index;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = traits<XprType>::NumDimensions;
  static const int Layout = traits<XprType>::Layout;
};

// Evaluate custom unary ops by reference (no copy of the expression node).
template<typename CustomUnaryFunc, typename XprType>
struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense>
{
  typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType>& type;
};

// Nest custom unary ops by value inside larger expressions.
template<typename CustomUnaryFunc, typename XprType>
struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
{
  typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type;
};

}  // end namespace internal
|
||||
|
||||
|
||||
|
||||
// Expression node for a user-defined unary tensor operation. It only stores
// the input expression and the functor; all computation happens in the
// matching TensorEvaluator specialization, which calls func().eval().
template<typename CustomUnaryFunc, typename XprType>
class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors>
{
  public:
  typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar;
  typedef typename internal::traits<TensorCustomUnaryOp>::Packet Packet;
  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename XprType::PacketReturnType PacketReturnType;
  typedef typename internal::nested<TensorCustomUnaryOp>::type Nested;
  typedef typename internal::traits<TensorCustomUnaryOp>::StorageKind StorageKind;
  typedef typename internal::traits<TensorCustomUnaryOp>::Index Index;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(const XprType& expr, const CustomUnaryFunc& func)
      : m_expr(expr), m_func(func) {}

  // The user functor driving this operation.
  EIGEN_DEVICE_FUNC
  const CustomUnaryFunc& func() const { return m_func; }

  // The wrapped input expression.
  EIGEN_DEVICE_FUNC
  const typename internal::remove_all<typename XprType::Nested>::type&
  expression() const { return m_expr; }

  protected:
  typename XprType::Nested m_expr;
  const CustomUnaryFunc m_func;
};
|
||||
|
||||
|
||||
// Eval as rvalue
|
||||
template<typename CustomUnaryFunc, typename XprType, typename Device>
|
||||
struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Device>
|
||||
{
|
||||
typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> ArgType;
|
||||
typedef typename internal::traits<ArgType>::Index Index;
|
||||
static const int NumDims = internal::traits<ArgType>::NumDimensions;
|
||||
typedef DSizes<Index, NumDims> Dimensions;
|
||||
typedef
|
||||
typename internal::remove_const<typename ArgType::Scalar>::type Scalar;
|
||||
|
||||
enum {
|
||||
IsAligned = false,
|
||||
PacketAccess = (internal::packet_traits<Scalar>::size > 1),
|
||||
BlockAccess = false,
|
||||
Layout = TensorEvaluator<XprType, Device>::Layout,
|
||||
CoordAccess = false, // to be implemented
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const ArgType& op, const Device& device)
|
||||
: m_op(op), m_device(device), m_result(NULL)
|
||||
{
|
||||
m_dimensions = op.func().dimensions(op.expression());
|
||||
}
|
||||
|
||||
typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
|
||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
|
||||
if (data) {
|
||||
evalTo(data);
|
||||
return false;
|
||||
} else {
|
||||
m_result = static_cast<CoeffReturnType*>(
|
||||
m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
|
||||
evalTo(m_result);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
|
||||
if (m_result != NULL) {
|
||||
m_device.deallocate(m_result);
|
||||
m_result = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
|
||||
return m_result[index];
|
||||
}
|
||||
|
||||
template<int LoadMode>
|
||||
EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
|
||||
return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; }
|
||||
|
||||
protected:
|
||||
EIGEN_DEVICE_FUNC void evalTo(Scalar* data) {
|
||||
TensorMap<Tensor<CoeffReturnType, NumDims, Layout, Index> > result(
|
||||
data, m_dimensions);
|
||||
m_op.func().eval(m_op.expression(), result, m_device);
|
||||
}
|
||||
|
||||
Dimensions m_dimensions;
|
||||
const ArgType m_op;
|
||||
const Device& m_device;
|
||||
CoeffReturnType* m_result;
|
||||
};
|
||||
|
||||
|
||||
|
||||
/** \class TensorCustomBinaryOp
|
||||
* \ingroup CXX11_Tensor_Module
|
||||
*
|
||||
* \brief Tensor custom class.
|
||||
*
|
||||
*
|
||||
*/
|
||||
namespace internal {

// Expression traits for TensorCustomBinaryOp: scalar, coeff, packet, storage
// and index types are promoted from the two operands; rank and layout are
// taken from the left-hand operand.
template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
{
  typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
                                                  typename RhsXprType::Scalar>::ret Scalar;
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
                                                  typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
  typedef typename internal::promote_storage_type<typename LhsXprType::PacketReturnType,
                                                  typename RhsXprType::PacketReturnType>::ret PacketReturnType;
  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
                                        typename traits<RhsXprType>::StorageKind>::ret StorageKind;
  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
                                      typename traits<RhsXprType>::Index>::type Index;
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  typedef typename remove_reference<LhsNested>::type _LhsNested;
  typedef typename remove_reference<RhsNested>::type _RhsNested;
  static const int NumDimensions = traits<LhsXprType>::NumDimensions;
  static const int Layout = traits<LhsXprType>::Layout;
};

// Evaluate custom binary ops by reference (no copy of the expression node).
template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense>
{
  typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>& type;
};

// Nest custom binary ops by value inside larger expressions.
template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
{
  typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> type;
};

}  // end namespace internal
|
||||
|
||||
|
||||
|
||||
// Expression node for a user-defined binary tensor operation. It only stores
// the two input expressions and the functor; all computation happens in the
// matching TensorEvaluator specialization, which calls func().eval().
template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors>
{
  public:
  typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar;
  typedef typename internal::traits<TensorCustomBinaryOp>::Packet Packet;
  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
  typedef typename internal::traits<TensorCustomBinaryOp>::CoeffReturnType CoeffReturnType;
  typedef typename internal::traits<TensorCustomBinaryOp>::PacketReturnType PacketReturnType;
  typedef typename internal::nested<TensorCustomBinaryOp>::type Nested;
  typedef typename internal::traits<TensorCustomBinaryOp>::StorageKind StorageKind;
  typedef typename internal::traits<TensorCustomBinaryOp>::Index Index;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const CustomBinaryFunc& func)
      : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {}

  // The user functor driving this operation.
  EIGEN_DEVICE_FUNC
  const CustomBinaryFunc& func() const { return m_func; }

  // The left-hand input expression.
  EIGEN_DEVICE_FUNC
  const typename internal::remove_all<typename LhsXprType::Nested>::type&
  lhsExpression() const { return m_lhs_xpr; }

  // The right-hand input expression.
  EIGEN_DEVICE_FUNC
  const typename internal::remove_all<typename RhsXprType::Nested>::type&
  rhsExpression() const { return m_rhs_xpr; }

  protected:
  typename LhsXprType::Nested m_lhs_xpr;
  typename RhsXprType::Nested m_rhs_xpr;
  const CustomBinaryFunc m_func;
};
|
||||
|
||||
|
||||
// Eval as rvalue
|
||||
template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, typename Device>
|
||||
struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Device>
|
||||
{
|
||||
typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> XprType;
|
||||
typedef typename internal::traits<XprType>::Index Index;
|
||||
static const int NumDims = internal::traits<XprType>::NumDimensions;
|
||||
typedef DSizes<Index, NumDims> Dimensions;
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
|
||||
enum {
|
||||
IsAligned = false,
|
||||
PacketAccess = (internal::packet_traits<Scalar>::size > 1),
|
||||
BlockAccess = false,
|
||||
Layout = TensorEvaluator<LhsXprType, Device>::Layout,
|
||||
CoordAccess = false, // to be implemented
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
|
||||
: m_op(op), m_device(device), m_result(NULL)
|
||||
{
|
||||
m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression());
|
||||
}
|
||||
|
||||
typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
|
||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
|
||||
if (data) {
|
||||
evalTo(data);
|
||||
return false;
|
||||
} else {
|
||||
m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
|
||||
evalTo(m_result);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
|
||||
if (m_result != NULL) {
|
||||
m_device.deallocate(m_result);
|
||||
m_result = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
|
||||
return m_result[index];
|
||||
}
|
||||
|
||||
template<int LoadMode>
|
||||
EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
|
||||
return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; }
|
||||
|
||||
protected:
|
||||
EIGEN_DEVICE_FUNC void evalTo(Scalar* data) {
|
||||
TensorMap<Tensor<Scalar, NumDims, Layout> > result(data, m_dimensions);
|
||||
m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device);
|
||||
}
|
||||
|
||||
Dimensions m_dimensions;
|
||||
const XprType m_op;
|
||||
const Device& m_device;
|
||||
CoeffReturnType* m_result;
|
||||
};
|
||||
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
|
@ -42,6 +42,9 @@ template<typename Strides, typename XprType> class TensorStridingOp;
|
||||
template<typename Generator, typename XprType> class TensorGeneratorOp;
|
||||
template<typename LeftXprType, typename RightXprType> class TensorAssignOp;
|
||||
|
||||
template<typename CustomUnaryFunc, typename XprType> class TensorCustomUnaryOp;
|
||||
template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> class TensorCustomBinaryOp;
|
||||
|
||||
template<typename XprType> class TensorEvalToOp;
|
||||
template<typename XprType> class TensorForcedEvalOp;
|
||||
|
||||
|
@ -156,9 +156,9 @@ struct eval<const TensorRef<PlainObjectType>, Eigen::Dense>
|
||||
};
|
||||
|
||||
// TODO nested<> does not exist anymore in Eigen/Core, and it thus has to be removed in favor of ref_selector.
|
||||
template<typename T, int n=1, typename PlainObject = void> struct nested
|
||||
{
|
||||
typedef typename ref_selector<T>::type type;
|
||||
template<typename T, int n=1, typename PlainObject = void> struct nested
|
||||
{
|
||||
typedef typename ref_selector<T>::type type;
|
||||
};
|
||||
|
||||
template <typename Scalar_, std::size_t NumIndices_, int Options_, typename IndexType_>
|
||||
|
@ -137,6 +137,7 @@ if(EIGEN_TEST_CXX11)
|
||||
ei_add_test(cxx11_tensor_layout_swap "-std=c++0x")
|
||||
ei_add_test(cxx11_tensor_io "-std=c++0x")
|
||||
ei_add_test(cxx11_tensor_generator "-std=c++0x")
|
||||
ei_add_test(cxx11_tensor_custom_op "-std=c++0x")
|
||||
|
||||
# These tests needs nvcc
|
||||
# ei_add_test(cxx11_tensor_device "-std=c++0x")
|
||||
|
107
unsupported/test/cxx11_tensor_custom_op.cpp
Normal file
107
unsupported/test/cxx11_tensor_custom_op.cpp
Normal file
@ -0,0 +1,107 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#include "main.h"
|
||||
|
||||
#include <Eigen/CXX11/Tensor>
|
||||
|
||||
using Eigen::Tensor;
|
||||
|
||||
|
||||
struct InsertZeros {
|
||||
DSizes<DenseIndex, 2> dimensions(const Tensor<float, 2>& input) const {
|
||||
DSizes<DenseIndex, 2> result;
|
||||
result[0] = input.dimension(0) * 2;
|
||||
result[1] = input.dimension(1) * 2;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename Output, typename Device>
|
||||
void eval(const Tensor<float, 2>& input, Output& output, const Device& device) const
|
||||
{
|
||||
array<DenseIndex, 2> strides{{2, 2}};
|
||||
output.stride(strides).device(device) = input;
|
||||
|
||||
Eigen::DSizes<DenseIndex, 2> offsets(1,1);
|
||||
Eigen::DSizes<DenseIndex, 2> extents(output.dimension(0)-1, output.dimension(1)-1);
|
||||
output.slice(offsets, extents).stride(strides).device(device) = input.constant(0.0f);
|
||||
}
|
||||
};
|
||||
|
||||
static void test_custom_unary_op()
|
||||
{
|
||||
Tensor<float, 2> tensor(3,5);
|
||||
tensor.setRandom();
|
||||
|
||||
Tensor<float, 2> result = tensor.customOp(InsertZeros());
|
||||
VERIFY_IS_EQUAL(result.dimension(0), 6);
|
||||
VERIFY_IS_EQUAL(result.dimension(1), 10);
|
||||
|
||||
for (int i = 0; i < 6; i+=2) {
|
||||
for (int j = 0; j < 10; j+=2) {
|
||||
VERIFY_IS_EQUAL(result(i, j), tensor(i/2, j/2));
|
||||
}
|
||||
}
|
||||
for (int i = 1; i < 6; i+=2) {
|
||||
for (int j = 1; j < 10; j+=2) {
|
||||
VERIFY_IS_EQUAL(result(i, j), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
struct BatchMatMul {
|
||||
DSizes<DenseIndex, 3> dimensions(const Tensor<float, 3>& input1, const Tensor<float, 3>& input2) const {
|
||||
DSizes<DenseIndex, 3> result;
|
||||
result[0] = input1.dimension(0);
|
||||
result[1] = input2.dimension(1);
|
||||
result[2] = input2.dimension(2);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename Output, typename Device>
|
||||
void eval(const Tensor<float, 3>& input1, const Tensor<float, 3>& input2,
|
||||
Output& output, const Device& device) const
|
||||
{
|
||||
typedef Tensor<float, 3>::DimensionPair DimPair;
|
||||
array<DimPair, 1> dims({{DimPair(1, 0)}});
|
||||
for (int i = 0; i < output.dimension(2); ++i) {
|
||||
output.template chip<2>(i).device(device) = input1.chip<2>(i).contract(input2.chip<2>(i), dims);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
static void test_custom_binary_op()
|
||||
{
|
||||
Tensor<float, 3> tensor1(2,3,5);
|
||||
tensor1.setRandom();
|
||||
Tensor<float, 3> tensor2(3,7,5);
|
||||
tensor2.setRandom();
|
||||
|
||||
Tensor<float, 3> result = tensor1.customOp(tensor2, BatchMatMul());
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
typedef Tensor<float, 3>::DimensionPair DimPair;
|
||||
array<DimPair, 1> dims({{DimPair(1, 0)}});
|
||||
Tensor<float, 2> reference = tensor1.chip<2>(i).contract(tensor2.chip<2>(i), dims);
|
||||
TensorRef<Tensor<float, 2>> val = result.chip<2>(i);
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
VERIFY_IS_APPROX(val(j, k), reference(j, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Entry point registered with the Eigen test harness; runs both the unary
// and binary custom-op subtests.
void test_cxx11_tensor_custom_op()
{
  CALL_SUBTEST(test_custom_unary_op());
  CALL_SUBTEST(test_custom_binary_op());
}
|
Loading…
Reference in New Issue
Block a user