From d26e19714fca2faf544619c3604f88d980e5a207 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20S=C3=A1nchez?= Date: Thu, 28 Mar 2024 04:26:55 +0000 Subject: [PATCH] Add missing cwiseSquare, tests for cwise matrix ops. --- Eigen/src/plugins/MatrixCwiseUnaryOps.inc | 11 +- test/CMakeLists.txt | 1 + test/matrix_cwise.cpp | 302 ++++++++++++++++++++++ 3 files changed, 313 insertions(+), 1 deletion(-) create mode 100644 test/matrix_cwise.cpp diff --git a/Eigen/src/plugins/MatrixCwiseUnaryOps.inc b/Eigen/src/plugins/MatrixCwiseUnaryOps.inc index b23f4a5ae..325b0fbe0 100644 --- a/Eigen/src/plugins/MatrixCwiseUnaryOps.inc +++ b/Eigen/src/plugins/MatrixCwiseUnaryOps.inc @@ -17,6 +17,7 @@ typedef CwiseUnaryOp, const Derived> CwiseArgRet typedef CwiseUnaryOp, const Derived> CwiseCArgReturnType; typedef CwiseUnaryOp, const Derived> CwiseSqrtReturnType; typedef CwiseUnaryOp, const Derived> CwiseCbrtReturnType; +typedef CwiseUnaryOp, const Derived> CwiseSquareReturnType; typedef CwiseUnaryOp, const Derived> CwiseSignReturnType; typedef CwiseUnaryOp, const Derived> CwiseInverseReturnType; @@ -66,7 +67,15 @@ EIGEN_DOC_UNARY_ADDONS(cwiseCbrt, cube - root) /// /// \sa cwiseSqrt(), cwiseSquare(), cwisePow() /// -EIGEN_DEVICE_FUNC inline const CwiseCbrtReturnType cwiseCbrt() const { return CwiseSCbrtReturnType(derived()); } +EIGEN_DEVICE_FUNC inline const CwiseCbrtReturnType cwiseCbrt() const { return CwiseCbrtReturnType(derived()); } + +/// \returns an expression of the coefficient-wise square of *this. +/// +EIGEN_DOC_UNARY_ADDONS(cwiseSquare, square) +/// +/// \sa cwisePow(), cwiseSqrt(), cwiseCbrt() +/// +EIGEN_DEVICE_FUNC inline const CwiseSquareReturnType cwiseSquare() const { return CwiseSquareReturnType(derived()); } /// \returns an expression of the coefficient-wise signum of *this. /// diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a778d1e75..469258435 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -221,6 +221,7 @@ ei_add_test(mapstride) ei_add_test(unaryview) ei_add_test(mapstaticmethods) ei_add_test(array_cwise) +ei_add_test(matrix_cwise) ei_add_test(array_for_matrix) ei_add_test(array_replicate) ei_add_test(array_reverse) diff --git a/test/matrix_cwise.cpp b/test/matrix_cwise.cpp new file mode 100644 index 000000000..56cd2d6f0 --- /dev/null +++ b/test/matrix_cwise.cpp @@ -0,0 +1,302 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2008-2009 Gael Guennebaud +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include +#include "main.h" + +template +struct matrix_of { + using type = MatrixType; +}; + +template +struct matrix_of, NewScalar> { + using type = Eigen::Matrix; +}; + +// Unary function reference. +template ::type>::type> +OutMatrixType cwise_ref(const MatrixType& m, Func f = Func()) { + OutMatrixType out(m.rows(), m.cols()); + for (Eigen::Index r = 0; r < m.rows(); ++r) { + for (Eigen::Index c = 0; c < m.cols(); ++c) { + out(r, c) = f(m(r, c)); + } + } + return out; +} + +// Binary function reference. +template ::type>::type> +OutMatrixType cwise_ref(const MatrixType& m1, const MatrixType& m2, Func f = Func()) { + OutMatrixType out(m1.rows(), m1.cols()); + for (Eigen::Index r = 0; r < m1.rows(); ++r) { + for (Eigen::Index c = 0; c < m1.cols(); ++c) { + out(r, c) = f(m1(r, c), m2(r, c)); + } + } + return out; +} + +template +void test_cwise_real(const MatrixType& m) { + using Scalar = typename MatrixType::Scalar; + Index rows = m.rows(); + Index cols = m.cols(); + MatrixType m1 = MatrixType::Random(rows, cols); + MatrixType m2, m3, m4; + + // Supported unary ops. + VERIFY_IS_CWISE_APPROX(m1.cwiseAbs(), cwise_ref(m1, [](const Scalar& x) { return Eigen::numext::abs(x); })); + VERIFY_IS_CWISE_APPROX(m1.cwiseSign(), cwise_ref(m1, [](const Scalar& x) { return Eigen::numext::sign(x); })); + VERIFY_IS_CWISE_APPROX(m1.cwiseCbrt(), cwise_ref(m1, [](const Scalar& x) { return Eigen::numext::cbrt(x); })); + // For integers, avoid division by zero. + m2 = m1; + if (Eigen::NumTraits::IsInteger) { + m2 = m1.unaryExpr([](const Scalar& x) { return Eigen::numext::equal_strict(x, Scalar(0)) ? Scalar(1) : x; }); + } + VERIFY_IS_CWISE_APPROX(m2.cwiseInverse(), cwise_ref(m2, [](const Scalar& x) { return Scalar(Scalar(1) / x); })); + VERIFY_IS_CWISE_APPROX(m1.cwiseArg(), cwise_ref(m1, [](const Scalar& x) { return Eigen::numext::arg(x); })); + // Only take sqrt of positive values. + m2 = m1.cwiseAbs(); + VERIFY_IS_CWISE_APPROX(m2.cwiseSqrt(), cwise_ref(m2, [](const Scalar& x) { return Eigen::numext::sqrt(x); })); + // Only find Square/Abs2 of +/- sqrt values so we don't overflow. + m2 = m2.cwiseSqrt().array() * m1.cwiseSign().array(); + VERIFY_IS_CWISE_APPROX(m2.cwiseAbs2(), cwise_ref(m2, [](const Scalar& x) { return Eigen::numext::abs2(x); })); + VERIFY_IS_CWISE_APPROX(m2.cwiseSquare(), cwise_ref(m2, [](const Scalar& x) { return Scalar(x * x); })); + VERIFY_IS_CWISE_APPROX(m2.cwisePow(Scalar(2)), + cwise_ref(m2, [](const Scalar& x) { return Eigen::numext::pow(x, Scalar(2)); })); + + // Supported binary ops. + m1.setRandom(rows, cols); + m2.setRandom(rows, cols); + VERIFY_IS_CWISE_EQUAL(m1.cwiseMin(m2), + cwise_ref(m1, m2, [](const Scalar& x, const Scalar& y) { return Eigen::numext::mini(x, y); })); + VERIFY_IS_CWISE_EQUAL(m1.template cwiseMin(m2), + cwise_ref(m1, m2, [](const Scalar& x, const Scalar& y) { return Eigen::numext::mini(x, y); })); + VERIFY_IS_CWISE_EQUAL(m1.template cwiseMin(m2), + cwise_ref(m1, m2, [](const Scalar& x, const Scalar& y) { return Eigen::numext::mini(x, y); })); + VERIFY_IS_CWISE_EQUAL(m1.template cwiseMin(m2), + cwise_ref(m1, m2, [](const Scalar& x, const Scalar& y) { return Eigen::numext::mini(x, y); })); + VERIFY_IS_CWISE_EQUAL(m1.cwiseMax(m2), + cwise_ref(m1, m2, [](const Scalar& x, const Scalar& y) { return Eigen::numext::maxi(x, y); })); + VERIFY_IS_CWISE_EQUAL(m1.template cwiseMax(m2), + cwise_ref(m1, m2, [](const Scalar& x, const Scalar& y) { return Eigen::numext::maxi(x, y); })); + VERIFY_IS_CWISE_EQUAL(m1.template cwiseMax(m2), + cwise_ref(m1, m2, [](const Scalar& x, const Scalar& y) { return Eigen::numext::maxi(x, y); })); + VERIFY_IS_CWISE_EQUAL(m1.template cwiseMax(m2), + cwise_ref(m1, m2, [](const Scalar& x, const Scalar& y) { return Eigen::numext::maxi(x, y); })); + // Scalar comparison. + Scalar mean = Eigen::NumTraits::highest() / Scalar(2) + Eigen::NumTraits::lowest() / Scalar(2); + m4.setConstant(rows, cols, mean); + VERIFY_IS_CWISE_EQUAL(m1.cwiseMin(mean), + cwise_ref(m1, m4, [](const Scalar& x, const Scalar& y) { return Eigen::numext::mini(x, y); })); + VERIFY_IS_CWISE_EQUAL(m1.template cwiseMin(mean), + cwise_ref(m1, m4, [](const Scalar& x, const Scalar& y) { return Eigen::numext::mini(x, y); })); + VERIFY_IS_CWISE_EQUAL(m1.template cwiseMin(mean), + cwise_ref(m1, m4, [](const Scalar& x, const Scalar& y) { return Eigen::numext::mini(x, y); })); + VERIFY_IS_CWISE_EQUAL(m1.template cwiseMin(mean), + cwise_ref(m1, m4, [](const Scalar& x, const Scalar& y) { return Eigen::numext::mini(x, y); })); + VERIFY_IS_CWISE_EQUAL(m1.cwiseMax(mean), + cwise_ref(m1, m4, [](const Scalar& x, const Scalar& y) { return Eigen::numext::maxi(x, y); })); + VERIFY_IS_CWISE_EQUAL(m1.template cwiseMax(mean), + cwise_ref(m1, m4, [](const Scalar& x, const Scalar& y) { return Eigen::numext::maxi(x, y); })); + VERIFY_IS_CWISE_EQUAL(m1.template cwiseMax(mean), + cwise_ref(m1, m4, [](const Scalar& x, const Scalar& y) { return Eigen::numext::maxi(x, y); })); + VERIFY_IS_CWISE_EQUAL(m1.template cwiseMax(mean), + cwise_ref(m1, m4, [](const Scalar& x, const Scalar& y) { return Eigen::numext::maxi(x, y); })); + // For products, avoid integer overflow by limiting the input < sqrt(max). + m3 = m1; + m4 = m2; + if (Eigen::NumTraits::IsInteger) { + const Scalar kMax = Eigen::numext::sqrt(Eigen::NumTraits::highest()); + m3 = m1 - ((m1 / kMax) * kMax); + m4 = m2 - ((m2 / kMax) * kMax); + } + VERIFY_IS_CWISE_APPROX(m3.cwiseProduct(m4), + cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { return Scalar(x * y); })); + // For quotients involving integers, avoid division by zero. + m4 = m2; + if (Eigen::NumTraits::IsInteger) { + m4 = m2.unaryExpr([](const Scalar& x) { return Eigen::numext::equal_strict(x, Scalar(0)) ? Scalar(1) : x; }); + } + VERIFY_IS_CWISE_APPROX(m1.cwiseQuotient(m4), + cwise_ref(m1, m4, [](const Scalar& x, const Scalar& y) { return Scalar(x / y); })); + // For equality comparisons, limit range to increase number of equalities. + if (Eigen::NumTraits::IsInteger) { + const Scalar kMax = Scalar(10); + m3 = m1 - ((m1 / kMax) * kMax); + m4 = m2 - ((m2 / kMax) * kMax); + mean = Eigen::NumTraits::IsSigned ? Scalar(0) : kMax / Scalar(2); + } else { + const Scalar kShift = Scalar(10); + m3 = (m1 * kShift).array().floor() / kShift; + m4 = (m2 * kShift).array().floor() / kShift; + mean = Scalar(0); + } + VERIFY_IS_CWISE_EQUAL(m3.cwiseEqual(m4), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return Eigen::numext::equal_strict(x, y); + })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseNotEqual(m4), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return !Eigen::numext::equal_strict(x, y); + })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseLess(m4), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { return x < y; })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseGreater(m4), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { return x > y; })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseLessOrEqual(m4), + cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { return x <= y; })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseGreaterOrEqual(m4), + cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { return x >= y; })); + // Typed-Equality. + VERIFY_IS_CWISE_EQUAL(m3.cwiseTypedEqual(m4), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return Eigen::numext::equal_strict(x, y) ? Scalar(1) : Scalar(0); + })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseTypedNotEqual(m4), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return !Eigen::numext::equal_strict(x, y) ? Scalar(1) : Scalar(0); + })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseTypedLess(m4), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return x < y ? Scalar(1) : Scalar(0); + })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseTypedGreater(m4), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return x > y ? Scalar(1) : Scalar(0); + })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseTypedLessOrEqual(m4), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return x <= y ? Scalar(1) : Scalar(0); + })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseTypedGreaterOrEqual(m4), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return x >= y ? Scalar(1) : Scalar(0); + })); + // Scalar. + m4.setConstant(rows, cols, mean); + VERIFY_IS_CWISE_EQUAL(m3.cwiseEqual(mean), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return Eigen::numext::equal_strict(x, y); + })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseNotEqual(mean), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return !Eigen::numext::equal_strict(x, y); + })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseLess(mean), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { return x < y; })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseGreater(mean), + cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { return x > y; })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseLessOrEqual(mean), + cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { return x <= y; })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseGreaterOrEqual(mean), + cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { return x >= y; })); + // Typed. + VERIFY_IS_CWISE_EQUAL(m3.cwiseTypedEqual(mean), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return Eigen::numext::equal_strict(x, y) ? Scalar(1) : Scalar(0); + })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseTypedNotEqual(mean), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return !Eigen::numext::equal_strict(x, y) ? Scalar(1) : Scalar(0); + })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseTypedLess(mean), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return x < y ? Scalar(1) : Scalar(0); + })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseTypedGreater(mean), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return x > y ? Scalar(1) : Scalar(0); + })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseTypedLessOrEqual(mean), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return x <= y ? Scalar(1) : Scalar(0); + })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseTypedGreaterOrEqual(mean), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return x >= y ? Scalar(1) : Scalar(0); + })); +} + +template +void test_cwise_complex(const MatrixType& m) { + using Scalar = typename MatrixType::Scalar; + using RealScalar = typename NumTraits::Real; + Index rows = m.rows(); + Index cols = m.cols(); + MatrixType m1 = MatrixType::Random(rows, cols); + MatrixType m2, m3, m4; + + // Supported unary ops. + VERIFY_IS_CWISE_APPROX(m1.cwiseAbs(), cwise_ref(m1, [](const Scalar& x) { return Eigen::numext::abs(x); })); + VERIFY_IS_CWISE_APPROX(m1.cwiseSqrt(), cwise_ref(m1, [](const Scalar& x) { return Eigen::numext::sqrt(x); })); + VERIFY_IS_CWISE_APPROX(m1.cwiseInverse(), cwise_ref(m1, [](const Scalar& x) { return Scalar(Scalar(1) / x); })); + VERIFY_IS_CWISE_APPROX(m1.cwiseArg(), cwise_ref(m1, [](const Scalar& x) { return Eigen::numext::arg(x); })); + VERIFY_IS_CWISE_APPROX(m1.cwiseCArg(), cwise_ref(m1, [](const Scalar& x) { return Scalar(Eigen::numext::arg(x)); })); + // Only find Square/Abs2 of +/- sqrt values so we don't overflow. + m2 = m1.cwiseSqrt().array() * m1.cwiseSign().array(); + VERIFY_IS_CWISE_APPROX(m2.cwiseAbs2(), cwise_ref(m2, [](const Scalar& x) { return Eigen::numext::abs2(x); })); + VERIFY_IS_CWISE_APPROX(m2.cwiseSquare(), cwise_ref(m2, [](const Scalar& x) { return Scalar(x * x); })); + VERIFY_IS_CWISE_APPROX(m2.cwisePow(Scalar(2)), + cwise_ref(m2, [](const Scalar& x) { return Eigen::numext::pow(x, Scalar(2)); })); + + // Supported binary ops. + m1.setRandom(rows, cols); + m2.setRandom(rows, cols); + VERIFY_IS_CWISE_APPROX(m1.cwiseProduct(m2), + cwise_ref(m1, m2, [](const Scalar& x, const Scalar& y) { return Scalar(x * y); })); + VERIFY_IS_CWISE_APPROX(m1.cwiseQuotient(m2), + cwise_ref(m1, m2, [](const Scalar& x, const Scalar& y) { return Scalar(x / y); })); + // For equality comparisons, limit range to increase number of equalities. + { + const RealScalar kShift = RealScalar(10); + m3 = m1; + m4 = m2; + m3.real() = (m1.real() * kShift).array().floor() / kShift; + m3.imag() = (m1.imag() * kShift).array().floor() / kShift; + m4.real() = (m2.real() * kShift).array().floor() / kShift; + m4.imag() = (m2.imag() * kShift).array().floor() / kShift; + } + VERIFY_IS_CWISE_EQUAL(m3.cwiseEqual(m4), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return Eigen::numext::equal_strict(x, y); + })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseNotEqual(m4), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return !Eigen::numext::equal_strict(x, y); + })); + // Typed-Equality. + VERIFY_IS_CWISE_EQUAL(m3.cwiseTypedEqual(m4), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return Eigen::numext::equal_strict(x, y) ? Scalar(1) : Scalar(0); + })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseTypedNotEqual(m4), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return !Eigen::numext::equal_strict(x, y) ? Scalar(1) : Scalar(0); + })); + // Scalar. + Scalar mean = Scalar(0); + m4.setConstant(rows, cols, mean); + VERIFY_IS_CWISE_EQUAL(m3.cwiseEqual(mean), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return Eigen::numext::equal_strict(x, y); + })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseNotEqual(mean), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return !Eigen::numext::equal_strict(x, y); + })); + // Typed. + VERIFY_IS_CWISE_EQUAL(m3.cwiseTypedEqual(mean), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return Eigen::numext::equal_strict(x, y) ? Scalar(1) : Scalar(0); + })); + VERIFY_IS_CWISE_EQUAL(m3.cwiseTypedNotEqual(mean), cwise_ref(m3, m4, [](const Scalar& x, const Scalar& y) { + return !Eigen::numext::equal_strict(x, y) ? Scalar(1) : Scalar(0); + })); +} + +EIGEN_DECLARE_TEST(matrix_cwise) { + for (int i = 0; i < g_repeat; i++) { + CALL_SUBTEST_1(test_cwise_real(Eigen::Matrix(20, 20))); + CALL_SUBTEST_1(test_cwise_real(Eigen::Matrix(20, 20))); + CALL_SUBTEST_1(test_cwise_real(Eigen::Matrix(20, 20))); + CALL_SUBTEST_1(test_cwise_real(Eigen::Matrix(20, 20))); + CALL_SUBTEST_2(test_cwise_complex(Eigen::Matrix, Eigen::Dynamic, Eigen::Dynamic>(20, 20))); + CALL_SUBTEST_2(test_cwise_complex(Eigen::Matrix, Eigen::Dynamic, Eigen::Dynamic>(20, 20))); + CALL_SUBTEST_3(test_cwise_real(Eigen::Matrix(20, 20))); + CALL_SUBTEST_3(test_cwise_real(Eigen::Matrix(20, 20))); + CALL_SUBTEST_3(test_cwise_real(Eigen::Matrix(20, 20))); + CALL_SUBTEST_3(test_cwise_real(Eigen::Matrix(20, 20))); + CALL_SUBTEST_4(test_cwise_real(Eigen::Matrix(20, 20))); + CALL_SUBTEST_4(test_cwise_real(Eigen::Matrix(20, 20))); + CALL_SUBTEST_4(test_cwise_real(Eigen::Matrix(20, 20))); + CALL_SUBTEST_4(test_cwise_real(Eigen::Matrix(20, 20))); + } +}