mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Fixed the tensor contraction code.
This commit is contained in:
parent
b6f96cf7dd
commit
4181556907
@ -152,13 +152,7 @@ struct TensorContractionParams {
|
||||
// 1. Elementwise Relu transformation following Conv2D.
|
||||
// 2. AddBias to the Conv2D output channels dimension.
|
||||
//
|
||||
// See expected implementation in NoOpOutputKernel.
|
||||
struct OutputKernel {
|
||||
template <typename Index, typename Scalar>
|
||||
typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
|
||||
};
|
||||
|
||||
// Output kernel that does absolutely nothing.
|
||||
// The NoOpOutputKernel implements an output kernel that does absolutely nothing.
|
||||
struct NoOpOutputKernel {
|
||||
/**
|
||||
* Tensor contraction evaluator calls this kernel after finishing each block
|
||||
@ -177,7 +171,7 @@ struct NoOpOutputKernel {
|
||||
*/
|
||||
template <typename Index, typename Scalar>
|
||||
EIGEN_ALWAYS_INLINE void operator()(
|
||||
const OutputKernel::OutputMapper<Index, Scalar>& /*output_mapper*/,
|
||||
const internal::blas_data_mapper<Scalar, Index, ColMajor>& /*output_mapper*/,
|
||||
const TensorContractionParams& /*params*/, Index /*i*/,
|
||||
Index /*j*/, Index /*num_rows*/, Index /*num_cols*/) const {}
|
||||
};
|
||||
|
@ -514,7 +514,7 @@ static void test_const_inputs()
|
||||
struct SqrtOutputKernel {
|
||||
template <typename Index, typename Scalar>
|
||||
EIGEN_ALWAYS_INLINE void operator()(
|
||||
const OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
|
||||
const internal::blas_data_mapper<Scalar, Index, ColMajor>& output_mapper,
|
||||
const TensorContractionParams&, Index, Index, Index num_rows,
|
||||
Index num_cols) const {
|
||||
for (int i = 0; i < num_rows; ++i) {
|
||||
|
@ -255,7 +255,7 @@ void test_multithread_contraction_agrees_with_singlethread() {
|
||||
struct SqrtOutputKernel {
|
||||
template <typename Index, typename Scalar>
|
||||
EIGEN_ALWAYS_INLINE void operator()(
|
||||
const OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
|
||||
const internal::blas_data_mapper<Scalar, Index, ColMajor>& output_mapper,
|
||||
const TensorContractionParams&, Index, Index, Index num_rows,
|
||||
Index num_cols) const {
|
||||
for (int i = 0; i < num_rows; ++i) {
|
||||
|
Loading…
Reference in New Issue
Block a user