mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-02-17 18:09:55 +08:00
optimize setConstant, setZero
This commit is contained in:
parent
5610a13b77
commit
8ad4344ca7
@ -318,6 +318,7 @@ using std::ptrdiff_t;
|
||||
#include "src/Core/PlainObjectBase.h"
|
||||
#include "src/Core/Matrix.h"
|
||||
#include "src/Core/Array.h"
|
||||
#include "src/Core/Fill.h"
|
||||
#include "src/Core/CwiseTernaryOp.h"
|
||||
#include "src/Core/CwiseBinaryOp.h"
|
||||
#include "src/Core/CwiseUnaryOp.h"
|
||||
|
@ -737,20 +737,6 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR void call_dense_assignment
|
||||
dense_assignment_loop<Kernel>::run(kernel);
|
||||
}
|
||||
|
||||
// Specialization for filling the destination with a constant value.
|
||||
#if !EIGEN_COMP_MSVC
|
||||
#ifndef EIGEN_GPU_COMPILE_PHASE
|
||||
template <typename DstXprType>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void call_dense_assignment_loop(
|
||||
DstXprType& dst,
|
||||
const Eigen::CwiseNullaryOp<Eigen::internal::scalar_constant_op<typename DstXprType::Scalar>, DstXprType>& src,
|
||||
const internal::assign_op<typename DstXprType::Scalar, typename DstXprType::Scalar>& func) {
|
||||
resize_if_allowed(dst, src, func);
|
||||
std::fill_n(dst.data(), dst.size(), src.functor()());
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template <typename DstXprType, typename SrcXprType>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void call_dense_assignment_loop(DstXprType& dst, const SrcXprType& src) {
|
||||
call_dense_assignment_loop(dst, src, internal::assign_op<typename DstXprType::Scalar, typename SrcXprType::Scalar>());
|
||||
|
@ -343,7 +343,8 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void DenseBase<Derived>::fill(const Scalar
|
||||
*/
|
||||
template <typename Derived>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase<Derived>::setConstant(const Scalar& val) {
|
||||
return derived() = Constant(rows(), cols(), val);
|
||||
internal::eigen_fill_impl<Derived>::run(derived(), val);
|
||||
return derived();
|
||||
}
|
||||
|
||||
/** Resizes to the given \a size, and sets all coefficients in this expression to the given value \a val.
|
||||
@ -547,7 +548,8 @@ EIGEN_DEVICE_FUNC bool DenseBase<Derived>::isZero(const RealScalar& prec) const
|
||||
*/
|
||||
template <typename Derived>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase<Derived>::setZero() {
|
||||
return setConstant(Scalar(0));
|
||||
internal::eigen_zero_impl<Derived>::run(derived());
|
||||
return derived();
|
||||
}
|
||||
|
||||
/** Resizes to the given \a size, and sets all coefficients in this expression to zero.
|
||||
@ -562,7 +564,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase<Derived>::setZero() {
|
||||
template <typename Derived>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& PlainObjectBase<Derived>::setZero(Index newSize) {
|
||||
resize(newSize);
|
||||
return setConstant(Scalar(0));
|
||||
return setZero();
|
||||
}
|
||||
|
||||
/** Resizes to the given size, and sets all coefficients in this expression to zero.
|
||||
@ -578,7 +580,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& PlainObjectBase<Derived>::setZero
|
||||
template <typename Derived>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& PlainObjectBase<Derived>::setZero(Index rows, Index cols) {
|
||||
resize(rows, cols);
|
||||
return setConstant(Scalar(0));
|
||||
return setZero();
|
||||
}
|
||||
|
||||
/** Resizes to the given size, changing only the number of columns, and sets all
|
||||
|
94
Eigen/src/Core/Fill.h
Normal file
94
Eigen/src/Core/Fill.h
Normal file
@ -0,0 +1,94 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2024 Charles Schlosser <cs.schlosser@gmail.com>
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#ifndef EIGEN_FILL_H
|
||||
#define EIGEN_FILL_H
|
||||
|
||||
// IWYU pragma: private
|
||||
#include "./InternalHeaderCheck.h"
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
template <typename Xpr>
|
||||
struct eigen_fill_helper : std::false_type {};
|
||||
|
||||
template <typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
|
||||
struct eigen_fill_helper<Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols>> : std::true_type {};
|
||||
|
||||
template <typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
|
||||
struct eigen_fill_helper<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols>> : std::true_type {};
|
||||
|
||||
template <typename Xpr, int BlockRows, int BlockCols>
|
||||
struct eigen_fill_helper<Block<Xpr, BlockRows, BlockCols, /*InnerPanel*/ true>> : eigen_fill_helper<Xpr> {};
|
||||
|
||||
template <typename Xpr, int BlockRows, int BlockCols>
|
||||
struct eigen_fill_helper<Block<Xpr, BlockRows, BlockCols, /*InnerPanel*/ false>>
|
||||
: std::integral_constant<bool, eigen_fill_helper<Xpr>::value &&
|
||||
(Xpr::IsRowMajor ? (BlockRows == 1) : (BlockCols == 1))> {};
|
||||
|
||||
template <typename Xpr, int Options, typename StrideType>
|
||||
struct eigen_fill_helper<Map<Xpr, Options, StrideType>>
|
||||
: std::integral_constant<bool, eigen_fill_helper<Xpr>::value &&
|
||||
(evaluator<Map<Xpr, Options, StrideType>>::Flags & LinearAccessBit)> {};
|
||||
|
||||
// Generic constant fill: builds a nullary "constant" expression with the
// destination's runtime dimensions and assigns it to the destination.
// Selected whenever use_fill is false (by default, whenever
// eigen_fill_helper<Xpr>::value is false).
template <typename Xpr, bool use_fill = eigen_fill_helper<Xpr>::value>
struct eigen_fill_impl {
  typedef typename Xpr::Scalar Scalar;
  typedef scalar_constant_op<Scalar> Func;
  typedef typename Xpr::PlainObject PlainObject;
  typedef CwiseNullaryOp<Func, PlainObject> Constant;

  // Sets every coefficient of dst to val.
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Xpr& dst, const Scalar& val) {
    const Constant constant_xpr(dst.rows(), dst.cols(), Func(val));
    dst = constant_xpr;
  }
};
|
||||
|
||||
#if !EIGEN_COMP_MSVC
|
||||
#ifndef EIGEN_GPU_COMPILE_PHASE
|
||||
template <typename Xpr>
|
||||
struct eigen_fill_impl<Xpr, /*use_fill*/ true> {
|
||||
using Scalar = typename Xpr::Scalar;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Xpr& dst, const Scalar& val) {
|
||||
EIGEN_USING_STD(fill_n);
|
||||
fill_n(dst.data(), dst.size(), val);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template <typename Xpr>
|
||||
struct eigen_memset_helper {
|
||||
static constexpr bool value = std::is_trivial<typename Xpr::Scalar>::value && eigen_fill_helper<Xpr>::value;
|
||||
};
|
||||
|
||||
// Generic setZero: used whenever use_memset is false. It explicitly selects
// the use_fill == false flavour of eigen_fill_impl, i.e. it always goes through
// the constant-expression assignment with the value Scalar(0).
template <typename Xpr, bool use_memset = eigen_memset_helper<Xpr>::value>
struct eigen_zero_impl {
  typedef typename Xpr::Scalar Scalar;

  // Sets every coefficient of dst to Scalar(0).
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Xpr& dst) {
    typedef eigen_fill_impl<Xpr, /*use_fill*/ false> GenericFill;
    const Scalar zero = Scalar(0);
    GenericFill::run(dst, zero);
  }
};
|
||||
|
||||
template <typename Xpr>
|
||||
struct eigen_zero_impl<Xpr, /*use_memset*/ true> {
|
||||
using Scalar = typename Xpr::Scalar;
|
||||
static constexpr size_t max_bytes = (std::numeric_limits<std::ptrdiff_t>::max)();
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Xpr& dst) {
|
||||
const size_t num_bytes = dst.size() * sizeof(Scalar);
|
||||
#ifndef EIGEN_NO_DEBUG
|
||||
if (num_bytes > max_bytes) throw_std_bad_alloc();
|
||||
#endif
|
||||
EIGEN_USING_STD(memset);
|
||||
memset(dst.data(), 0, num_bytes);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
} // namespace Eigen
|
||||
|
||||
#endif // EIGEN_FILL_H
|
@ -73,6 +73,21 @@ void block(const MatrixType& m) {
|
||||
|
||||
block_real_only(m1, r1, r2, c1, c1, s1);
|
||||
|
||||
// test fill logic with innerpanel and non-innerpanel blocks
|
||||
m1.row(r1).setConstant(s1);
|
||||
VERIFY_IS_CWISE_EQUAL(m1.row(r1), DynamicVectorType::Constant(cols, s1).transpose());
|
||||
m1 = m1_copy;
|
||||
m1.col(c1).setConstant(s1);
|
||||
VERIFY_IS_CWISE_EQUAL(m1.col(c1), DynamicVectorType::Constant(rows, s1));
|
||||
m1 = m1_copy;
|
||||
// test setZero logic with innerpanel and non-innerpanel blocks
|
||||
m1.row(r1).setZero();
|
||||
VERIFY_IS_CWISE_EQUAL(m1.row(r1), DynamicVectorType::Zero(cols).transpose());
|
||||
m1 = m1_copy;
|
||||
m1.col(c1).setZero();
|
||||
VERIFY_IS_CWISE_EQUAL(m1.col(c1), DynamicVectorType::Zero(rows));
|
||||
m1 = m1_copy;
|
||||
|
||||
// check row() and col()
|
||||
VERIFY_IS_EQUAL(m1.col(c1).transpose(), m1.transpose().row(c1));
|
||||
// check operator(), both constant and non-constant, on row() and col()
|
||||
|
@ -75,8 +75,23 @@ void map_class_matrix(const MatrixType& m) {
|
||||
Map<MatrixType> map4(array4, rows, cols);
|
||||
|
||||
VERIFY_IS_EQUAL(map1, MatrixType::Ones(rows, cols));
|
||||
map1.setConstant(s1);
|
||||
VERIFY_IS_EQUAL(map1, MatrixType::Constant(rows, cols, s1));
|
||||
map1.setZero();
|
||||
VERIFY_IS_EQUAL(map1, MatrixType::Zero(rows, cols));
|
||||
|
||||
VERIFY_IS_EQUAL(map2, MatrixType::Ones(rows, cols));
|
||||
map2.setConstant(s1);
|
||||
VERIFY_IS_EQUAL(map2, MatrixType::Constant(rows, cols, s1));
|
||||
map2.setZero();
|
||||
VERIFY_IS_EQUAL(map2, MatrixType::Zero(rows, cols));
|
||||
|
||||
VERIFY_IS_EQUAL(map3, MatrixType::Ones(rows, cols));
|
||||
map3.setConstant(s1);
|
||||
VERIFY_IS_EQUAL(map3, MatrixType::Constant(rows, cols, s1));
|
||||
map3.setZero();
|
||||
VERIFY_IS_EQUAL(map3, MatrixType::Zero(rows, cols));
|
||||
|
||||
map1 = MatrixType::Random(rows, cols);
|
||||
map2 = map1;
|
||||
map3 = map1;
|
||||
|
@ -94,6 +94,8 @@ void map_class_matrix(const MatrixType& _m) {
|
||||
VERIFY_IS_APPROX(s1 * map, s1 * m);
|
||||
map *= s1;
|
||||
VERIFY_IS_APPROX(map, s1 * m);
|
||||
map.setZero();
|
||||
VERIFY_IS_CWISE_EQUAL(map, MatrixType::Zero(rows, cols));
|
||||
}
|
||||
|
||||
// test no inner stride and an outer stride of +4. This is quite important as for fixed-size matrices,
|
||||
@ -118,6 +120,8 @@ void map_class_matrix(const MatrixType& _m) {
|
||||
VERIFY_IS_APPROX(s1 * map, s1 * m);
|
||||
map *= s1;
|
||||
VERIFY_IS_APPROX(map, s1 * m);
|
||||
map.setZero();
|
||||
VERIFY_IS_CWISE_EQUAL(map, MatrixType::Zero(rows, cols));
|
||||
}
|
||||
|
||||
// test both inner stride and outer stride
|
||||
@ -138,6 +142,8 @@ void map_class_matrix(const MatrixType& _m) {
|
||||
VERIFY_IS_APPROX(s1 * map, s1 * m);
|
||||
map *= s1;
|
||||
VERIFY_IS_APPROX(map, s1 * m);
|
||||
map.setZero();
|
||||
VERIFY_IS_CWISE_EQUAL(map, MatrixType::Zero(rows, cols));
|
||||
}
|
||||
|
||||
// test inner stride and no outer stride
|
||||
@ -156,6 +162,8 @@ void map_class_matrix(const MatrixType& _m) {
|
||||
VERIFY_IS_APPROX(s1 * map, s1 * m);
|
||||
map *= s1;
|
||||
VERIFY_IS_APPROX(map, s1 * m);
|
||||
map.setZero();
|
||||
VERIFY_IS_CWISE_EQUAL(map, MatrixType::Zero(rows, cols));
|
||||
}
|
||||
|
||||
// test negative strides
|
||||
|
Loading…
Reference in New Issue
Block a user