mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Fixed the tensor contraction code.
This commit is contained in:
parent
b6f96cf7dd
commit
4181556907
@ -152,13 +152,7 @@ struct TensorContractionParams {
|
|||||||
// 1. Elementwise Relu transformation following Conv2D.
|
// 1. Elementwise Relu transformation following Conv2D.
|
||||||
// 2. AddBias to the Conv2D output channels dimension.
|
// 2. AddBias to the Conv2D output channels dimension.
|
||||||
//
|
//
|
||||||
// See expected implementation in NoOpOutputKernel.
|
// The NoOpOutputKernel implements an output kernel that does absolutely nothing.
|
||||||
struct OutputKernel {
|
|
||||||
template <typename Index, typename Scalar>
|
|
||||||
typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Output kernel that does absolutely nothing.
|
|
||||||
struct NoOpOutputKernel {
|
struct NoOpOutputKernel {
|
||||||
/**
|
/**
|
||||||
* Tensor contraction evaluator calls this kernel after finishing each block
|
* Tensor contraction evaluator calls this kernel after finishing each block
|
||||||
@ -177,7 +171,7 @@ struct NoOpOutputKernel {
|
|||||||
*/
|
*/
|
||||||
template <typename Index, typename Scalar>
|
template <typename Index, typename Scalar>
|
||||||
EIGEN_ALWAYS_INLINE void operator()(
|
EIGEN_ALWAYS_INLINE void operator()(
|
||||||
const OutputKernel::OutputMapper<Index, Scalar>& /*output_mapper*/,
|
const internal::blas_data_mapper<Scalar, Index, ColMajor>& /*output_mapper*/,
|
||||||
const TensorContractionParams& /*params*/, Index /*i*/,
|
const TensorContractionParams& /*params*/, Index /*i*/,
|
||||||
Index /*j*/, Index /*num_rows*/, Index /*num_cols*/) const {}
|
Index /*j*/, Index /*num_rows*/, Index /*num_cols*/) const {}
|
||||||
};
|
};
|
||||||
|
@ -514,7 +514,7 @@ static void test_const_inputs()
|
|||||||
struct SqrtOutputKernel {
|
struct SqrtOutputKernel {
|
||||||
template <typename Index, typename Scalar>
|
template <typename Index, typename Scalar>
|
||||||
EIGEN_ALWAYS_INLINE void operator()(
|
EIGEN_ALWAYS_INLINE void operator()(
|
||||||
const OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
|
const internal::blas_data_mapper<Scalar, Index, ColMajor>& output_mapper,
|
||||||
const TensorContractionParams&, Index, Index, Index num_rows,
|
const TensorContractionParams&, Index, Index, Index num_rows,
|
||||||
Index num_cols) const {
|
Index num_cols) const {
|
||||||
for (int i = 0; i < num_rows; ++i) {
|
for (int i = 0; i < num_rows; ++i) {
|
||||||
|
@ -255,7 +255,7 @@ void test_multithread_contraction_agrees_with_singlethread() {
|
|||||||
struct SqrtOutputKernel {
|
struct SqrtOutputKernel {
|
||||||
template <typename Index, typename Scalar>
|
template <typename Index, typename Scalar>
|
||||||
EIGEN_ALWAYS_INLINE void operator()(
|
EIGEN_ALWAYS_INLINE void operator()(
|
||||||
const OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
|
const internal::blas_data_mapper<Scalar, Index, ColMajor>& output_mapper,
|
||||||
const TensorContractionParams&, Index, Index, Index num_rows,
|
const TensorContractionParams&, Index, Index, Index num_rows,
|
||||||
Index num_cols) const {
|
Index num_cols) const {
|
||||||
for (int i = 0; i < num_rows; ++i) {
|
for (int i = 0; i < num_rows; ++i) {
|
||||||
|
Loading…
Reference in New Issue
Block a user