* added cwise comparisons

* added "all" and "any" special redux operators
 * added support bool matrices
 * added support for cost model of STL functors via ei_functor_traits
  (By default ei_functor_traits query the functor member Cost)
This commit is contained in:
Gael Guennebaud 2008-04-03 18:13:27 +00:00
parent 249dc4f482
commit 048910caae
8 changed files with 282 additions and 11 deletions

View File

@ -65,4 +65,42 @@ template<typename Scalar> struct ei_scalar_max_op EIGEN_EMPTY_STRUCT {
enum { Cost = ConditionalJumpCost + NumTraits<Scalar>::AddCost };
};
// default ei_functor_traits for STL functors:
template<typename Result, typename Arg0, typename Arg1>
struct ei_functor_traits<std::binary_function<Result,Arg0,Arg1> >
{ enum { Cost = 10 }; };
template<typename Result, typename Arg0>
struct ei_functor_traits<std::unary_function<Result,Arg0> >
{ enum { Cost = 5 }; };
template<typename T>
struct ei_functor_traits<std::binder2nd<T> >
{ enum { Cost = 5 }; };
template<typename T>
struct ei_functor_traits<std::binder1st<T> >
{ enum { Cost = 5 }; };
template<typename T>
struct ei_functor_traits<std::greater<T> >
{ enum { Cost = 1 }; };
template<typename T>
struct ei_functor_traits<std::less<T> >
{ enum { Cost = 1 }; };
template<typename T>
struct ei_functor_traits<std::greater_equal<T> >
{ enum { Cost = 1 }; };
template<typename T>
struct ei_functor_traits<std::less_equal<T> >
{ enum { Cost = 1 }; };
template<typename T>
struct ei_functor_traits<std::equal_to<T> >
{ enum { Cost = 1 }; };
#endif // EIGEN_ASSOCIATIVE_FUNCTORS_H

View File

@ -61,7 +61,7 @@ struct ei_traits<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
MaxRowsAtCompileTime = Lhs::MaxRowsAtCompileTime,
MaxColsAtCompileTime = Lhs::MaxColsAtCompileTime,
Flags = Lhs::Flags | Rhs::Flags,
CoeffReadCost = Lhs::CoeffReadCost + Rhs::CoeffReadCost + BinaryOp::Cost
CoeffReadCost = Lhs::CoeffReadCost + Rhs::CoeffReadCost + ei_functor_traits<BinaryOp>::Cost
};
};
@ -230,4 +230,76 @@ MatrixBase<Derived>::cwise(const MatrixBase<OtherDerived> &other, const CustomBi
return CwiseBinaryOp<CustomBinaryOp, Derived, OtherDerived>(derived(), other.derived(), func);
}
/** \returns an expression of the coefficient-wise \< operator of *this and \a other
*
* \sa class CwiseBinaryOp
*/
template<typename Derived>
template<typename OtherDerived>
const CwiseBinaryOp<std::less<typename ei_traits<Derived>::Scalar>, Derived, OtherDerived>
MatrixBase<Derived>::cwiseLessThan(const MatrixBase<OtherDerived> &other) const
{
return cwise(other, std::less<Scalar>());
}
/** \returns an expression of the coefficient-wise \<= operator of *this and \a other
*
* \sa class CwiseBinaryOp
*/
template<typename Derived>
template<typename OtherDerived>
const CwiseBinaryOp<std::less_equal<typename ei_traits<Derived>::Scalar>, Derived, OtherDerived>
MatrixBase<Derived>::cwiseLessEqual(const MatrixBase<OtherDerived> &other) const
{
return cwise(other, std::less_equal<Scalar>());
}
/** \returns an expression of the coefficient-wise \> operator of *this and \a other
*
* \sa class CwiseBinaryOp
*/
template<typename Derived>
template<typename OtherDerived>
const CwiseBinaryOp<std::greater<typename ei_traits<Derived>::Scalar>, Derived, OtherDerived>
MatrixBase<Derived>::cwiseGreaterThan(const MatrixBase<OtherDerived> &other) const
{
return cwise(other, std::greater<Scalar>());
}
/** \returns an expression of the coefficient-wise \>= operator of *this and \a other
*
* \sa class CwiseBinaryOp
*/
template<typename Derived>
template<typename OtherDerived>
const CwiseBinaryOp<std::greater_equal<typename ei_traits<Derived>::Scalar>, Derived, OtherDerived>
MatrixBase<Derived>::cwiseGreaterEqual(const MatrixBase<OtherDerived> &other) const
{
return cwise(other, std::greater_equal<Scalar>());
}
/** \returns an expression of the coefficient-wise == operator of *this and \a other
*
* \sa class CwiseBinaryOp
*/
template<typename Derived>
template<typename OtherDerived>
const CwiseBinaryOp<std::equal_to<typename ei_traits<Derived>::Scalar>, Derived, OtherDerived>
MatrixBase<Derived>::cwiseEqualTo(const MatrixBase<OtherDerived> &other) const
{
return cwise(other, std::equal_to<Scalar>());
}
/** \returns an expression of the coefficient-wise != operator of *this and \a other
*
* \sa class CwiseBinaryOp
*/
template<typename Derived>
template<typename OtherDerived>
const CwiseBinaryOp<std::not_equal_to<typename ei_traits<Derived>::Scalar>, Derived, OtherDerived>
MatrixBase<Derived>::cwiseNotEqualTo(const MatrixBase<OtherDerived> &other) const
{
return cwise(other, std::not_equal_to<Scalar>());
}
#endif // EIGEN_CWISE_BINARY_OP_H

View File

@ -51,7 +51,7 @@ struct ei_traits<CwiseUnaryOp<UnaryOp, MatrixType> >
MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime,
MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime,
Flags = MatrixType::Flags,
CoeffReadCost = MatrixType::CoeffReadCost + UnaryOp::Cost
CoeffReadCost = MatrixType::CoeffReadCost + ei_functor_traits<UnaryOp>::Cost
};
};

View File

@ -82,7 +82,7 @@ struct ei_copy_unless_matrix<Matrix<_Scalar, _Rows, _Cols, _Flags, _MaxRows, _Ma
template<typename T> struct ei_xpr_copy
{
typedef typename ei_meta_if<T::Flags & TemporaryBit,
typedef typename ei_meta_if<T::Flags & TemporaryBit,
T,
typename ei_copy_unless_matrix<T>::type
>::ret type;
@ -115,4 +115,11 @@ template<typename T, int n=1> struct ei_eval_if_needed_before_nesting
typedef typename ei_meta_if<eval, typename ei_eval_temporary<T>::type, T>::ret type;
};
template<typename T> struct ei_functor_traits
{
enum { Cost = T::Cost };
};
#endif // EIGEN_FORWARDDECLARATIONS_H

View File

@ -252,7 +252,7 @@ template<typename Derived> class MatrixBase
*/
//@{
template<typename OtherDerived>
const Product<typename ei_eval_if_needed_before_nesting<Derived, OtherDerived::ColsAtCompileTime>::type,
const Product<typename ei_eval_if_needed_before_nesting<Derived, OtherDerived::ColsAtCompileTime>::type,
typename ei_eval_if_needed_before_nesting<OtherDerived, ei_traits<Derived>::ColsAtCompileTime>::type>
operator*(const MatrixBase<OtherDerived> &other) const;
@ -354,6 +354,15 @@ template<typename Derived> class MatrixBase
RealScalar prec = precision<Scalar>()) const;
bool isOrtho(RealScalar prec = precision<Scalar>()) const;
template<typename OtherDerived>
bool operator==(const MatrixBase<OtherDerived>& other) const
{ return derived().cwiseEqualTo(other.derived()).all(); }
template<typename OtherDerived>
bool operator!=(const MatrixBase<OtherDerived>& other) const
{ return derived().cwiseNotEqualTo(other.derived()).all(); }
//@}
/// \name Special functions
//@{
template<typename NewType>
@ -390,6 +399,30 @@ template<typename Derived> class MatrixBase
const CwiseBinaryOp<ei_scalar_max_op<typename ei_traits<Derived>::Scalar>, Derived, OtherDerived>
cwiseMax(const MatrixBase<OtherDerived> &other) const;
template<typename OtherDerived>
const CwiseBinaryOp<std::less<typename ei_traits<Derived>::Scalar>, Derived, OtherDerived>
cwiseLessThan(const MatrixBase<OtherDerived> &other) const;
template<typename OtherDerived>
const CwiseBinaryOp<std::less_equal<typename ei_traits<Derived>::Scalar>, Derived, OtherDerived>
cwiseLessEqual(const MatrixBase<OtherDerived> &other) const;
template<typename OtherDerived>
const CwiseBinaryOp<std::greater<typename ei_traits<Derived>::Scalar>, Derived, OtherDerived>
cwiseGreaterThan(const MatrixBase<OtherDerived> &other) const;
template<typename OtherDerived>
const CwiseBinaryOp<std::greater_equal<typename ei_traits<Derived>::Scalar>, Derived, OtherDerived>
cwiseGreaterEqual(const MatrixBase<OtherDerived> &other) const;
template<typename OtherDerived>
const CwiseBinaryOp<std::equal_to<typename ei_traits<Derived>::Scalar>, Derived, OtherDerived>
cwiseEqualTo(const MatrixBase<OtherDerived> &other) const;
template<typename OtherDerived>
const CwiseBinaryOp<std::not_equal_to<typename ei_traits<Derived>::Scalar>, Derived, OtherDerived>
cwiseNotEqualTo(const MatrixBase<OtherDerived> &other) const;
const CwiseUnaryOp<ei_scalar_abs_op<typename ei_traits<Derived>::Scalar>, Derived> cwiseAbs() const;
const CwiseUnaryOp<ei_scalar_abs2_op<typename ei_traits<Derived>::Scalar>, Derived> cwiseAbs2() const;
const CwiseUnaryOp<ei_scalar_sqrt_op<typename ei_traits<Derived>::Scalar>, Derived> cwiseSqrt() const;
@ -419,6 +452,9 @@ template<typename Derived> class MatrixBase
typename ei_traits<Derived>::Scalar minCoeff(int* row, int* col = 0) const;
typename ei_traits<Derived>::Scalar maxCoeff(int* row, int* col = 0) const;
bool all(void) const;
bool any(void) const;
template<typename BinaryOp>
const PartialRedux<Vertical, BinaryOp, Derived>
verticalRedux(const BinaryOp& func) const;

View File

@ -5,12 +5,12 @@
//
// Eigen is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// License as published by the Free Software Foundation; either
// version 3 of the License, or (at your option) any later version.
//
// Alternatively, you can redistribute it and/or
// modify it under the terms of the GNU General Public License as
// published by the Free Software Foundation; either version 2 of
// published by the Free Software Foundation; either version 2 of
// the License, or (at your option) any later version.
//
// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
@ -18,7 +18,7 @@
// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// You should have received a copy of the GNU Lesser General Public
// License and a copy of the GNU General Public License along with
// Eigen. If not, see <http://www.gnu.org/licenses/>.
@ -125,4 +125,17 @@ template<> struct NumTraits<long double>
};
};
template<> struct NumTraits<bool>
{
typedef bool Real;
typedef float FloatingPoint;
enum {
IsComplex = 0,
HasFloatingPoint = 0,
ReadCost = 1,
AddCost = 1,
MulCost = 1
};
};
#endif // EIGEN_NUMTRAITS_H

View File

@ -232,4 +232,102 @@ MatrixBase<Derived>::maxCoeff() const
return this->redux(Eigen::ei_scalar_max_op<Scalar>());
}
template<typename Derived, int UnrollCount>
struct ei_all_unroller
{
enum {
col = (UnrollCount-1) / Derived::RowsAtCompileTime,
row = (UnrollCount-1) % Derived::RowsAtCompileTime
};
static bool run(const Derived &mat)
{
return ei_all_unroller<Derived, UnrollCount-1>::run(mat) && mat.coeff(row, col);
}
};
template<typename Derived>
struct ei_all_unroller<Derived, 1>
{
static bool run(const Derived &mat) { return mat.coeff(0, 0); }
};
template<typename Derived>
struct ei_all_unroller<Derived, Dynamic>
{
static bool run(const Derived &) { return false; }
};
template<typename Derived, int UnrollCount>
struct ei_any_unroller
{
enum {
col = (UnrollCount-1) / Derived::RowsAtCompileTime,
row = (UnrollCount-1) % Derived::RowsAtCompileTime
};
static bool run(const Derived &mat)
{
return ei_any_unroller<Derived, UnrollCount-1>::run(mat) || mat.coeff(row, col);
}
};
template<typename Derived>
struct ei_any_unroller<Derived, 1>
{
static bool run(const Derived &mat) { return mat.coeff(0, 0); }
};
template<typename Derived>
struct ei_any_unroller<Derived, Dynamic>
{
static bool run(const Derived &) { return false; }
};
/** \returns true if all coefficients are true
*
* \sa MatrixBase::any()
*/
template<typename Derived>
bool MatrixBase<Derived>::all(void) const
{
if(EIGEN_UNROLLED_LOOPS
&& SizeAtCompileTime != Dynamic
&& SizeAtCompileTime <= EIGEN_UNROLLING_LIMIT)
return ei_all_unroller<Derived,
(SizeAtCompileTime>0 && SizeAtCompileTime <= EIGEN_UNROLLING_LIMIT) ?
SizeAtCompileTime : Dynamic>::run(derived());
else
{
for(int j = 0; j < cols(); j++)
for(int i = 0; i < rows(); i++)
if (!coeff(i, j)) return false;
return true;
}
}
/** \returns true if at least one coefficient is true
*
* \sa MatrixBase::any()
*/
template<typename Derived>
bool MatrixBase<Derived>::any(void) const
{
if(EIGEN_UNROLLED_LOOPS
&& SizeAtCompileTime != Dynamic
&& SizeAtCompileTime <= EIGEN_UNROLLING_LIMIT)
return ei_any_unroller<Derived,
(SizeAtCompileTime>0 && SizeAtCompileTime <= EIGEN_UNROLLING_LIMIT) ?
SizeAtCompileTime : Dynamic>::run(derived());
else
{
for(int j = 0; j < cols(); j++)
for(int i = 0; i < rows(); i++)
if (coeff(i, j)) return true;
return false;
}
}
#endif // EIGEN_REDUX_H

View File

@ -27,6 +27,9 @@
#include <iostream>
#include <cmath>
#include <cstdlib>
#include <functional>
using namespace std;
namespace Eigen {
@ -39,10 +42,10 @@ template<typename MatrixType> void cwiseops(const MatrixType& m)
{
typedef typename MatrixType::Scalar Scalar;
typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> VectorType;
int rows = m.rows();
int cols = m.cols();
MatrixType m1 = MatrixType::random(rows, cols),
m2 = MatrixType::random(rows, cols),
m3(rows, cols),
@ -57,13 +60,17 @@ template<typename MatrixType> void cwiseops(const MatrixType& m)
vzero = VectorType::zero(rows);
m2 = m2.template cwise<AddIfNull<Scalar> >(mones);
VERIFY_IS_APPROX( mzero, m1-m1);
VERIFY_IS_APPROX( m2, m1+m2-m1);
VERIFY_IS_APPROX( mones, m2.cwiseQuotient(m2));
VERIFY_IS_APPROX( m1.cwiseProduct(m2), m2.cwiseProduct(m1));
VERIFY( m1.cwiseLessThan(m1.cwise(bind2nd(plus<Scalar>(), Scalar(1)))).all() );
VERIFY( !m1.cwiseLessThan(m1.cwise(bind2nd(minus<Scalar>(), Scalar(1)))).all() );
VERIFY( !m1.cwiseGreaterThan(m1.cwise(bind2nd(plus<Scalar>(), Scalar(1)))).any() );
//VERIFY_IS_APPROX( m1, m2.cwiseProduct(m1).cwiseQuotient(m2));
// VERIFY_IS_APPROX( cwiseMin(m1,m2), cwiseMin(m2,m1) );
// VERIFY_IS_APPROX( cwiseMin(m1,m1+mones), m1 );
// VERIFY_IS_APPROX( cwiseMin(m1,m1-mones), m1-mones );