Make cross product uses nested/nested_eval

This commit is contained in:
Gael Guennebaud 2014-08-01 14:47:33 +02:00
parent 26d2cdefd4
commit fc13b37c55
2 changed files with 33 additions and 6 deletions

View File

@ -30,8 +30,13 @@ MatrixBase<Derived>::cross(const MatrixBase<OtherDerived>& other) const
// Note that there is no need for an expression here since the compiler
// optimize such a small temporary very well (even within a complex expression)
#ifndef EIGEN_TEST_EVALUATORS
typename internal::nested<Derived,2>::type lhs(derived());
typename internal::nested<OtherDerived,2>::type rhs(other.derived());
#else
typename internal::nested_eval<Derived,2>::type lhs(derived());
typename internal::nested_eval<OtherDerived,2>::type rhs(other.derived());
#endif
return typename cross_product_return_type<OtherDerived>::type(
numext::conj(lhs.coeff(1) * rhs.coeff(2) - lhs.coeff(2) * rhs.coeff(1)),
numext::conj(lhs.coeff(2) * rhs.coeff(0) - lhs.coeff(0) * rhs.coeff(2)),
@ -76,8 +81,13 @@ MatrixBase<Derived>::cross3(const MatrixBase<OtherDerived>& other) const
EIGEN_STATIC_ASSERT_VECTOR_SPECIFIC_SIZE(Derived,4)
EIGEN_STATIC_ASSERT_VECTOR_SPECIFIC_SIZE(OtherDerived,4)
#ifndef EIGEN_TEST_EVALUATORS
typedef typename internal::nested<Derived,2>::type DerivedNested;
typedef typename internal::nested<OtherDerived,2>::type OtherDerivedNested;
#else
typedef typename internal::nested_eval<Derived,2>::type DerivedNested;
typedef typename internal::nested_eval<OtherDerived,2>::type OtherDerivedNested;
#endif
DerivedNested lhs(derived());
OtherDerivedNested rhs(other.derived());
@ -103,21 +113,29 @@ VectorwiseOp<ExpressionType,Direction>::cross(const MatrixBase<OtherDerived>& ot
EIGEN_STATIC_ASSERT_VECTOR_SPECIFIC_SIZE(OtherDerived,3)
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, typename OtherDerived::Scalar>::value),
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
#ifndef EIGEN_TEST_EVALUATORS
typename internal::nested<ExpressionType,2>::type mat(_expression());
typename internal::nested<OtherDerived,2>::type vec(other.derived());
#else
typename internal::nested_eval<ExpressionType,2>::type mat(_expression());
typename internal::nested_eval<OtherDerived,2>::type vec(other.derived());
#endif
CrossReturnType res(_expression().rows(),_expression().cols());
if(Direction==Vertical)
{
eigen_assert(CrossReturnType::RowsAtCompileTime==3 && "the matrix must have exactly 3 rows");
res.row(0) = (_expression().row(1) * other.coeff(2) - _expression().row(2) * other.coeff(1)).conjugate();
res.row(1) = (_expression().row(2) * other.coeff(0) - _expression().row(0) * other.coeff(2)).conjugate();
res.row(2) = (_expression().row(0) * other.coeff(1) - _expression().row(1) * other.coeff(0)).conjugate();
res.row(0) = (mat.row(1) * vec.coeff(2) - mat.row(2) * vec.coeff(1)).conjugate();
res.row(1) = (mat.row(2) * vec.coeff(0) - mat.row(0) * vec.coeff(2)).conjugate();
res.row(2) = (mat.row(0) * vec.coeff(1) - mat.row(1) * vec.coeff(0)).conjugate();
}
else
{
eigen_assert(CrossReturnType::ColsAtCompileTime==3 && "the matrix must have exactly 3 columns");
res.col(0) = (_expression().col(1) * other.coeff(2) - _expression().col(2) * other.coeff(1)).conjugate();
res.col(1) = (_expression().col(2) * other.coeff(0) - _expression().col(0) * other.coeff(2)).conjugate();
res.col(2) = (_expression().col(0) * other.coeff(1) - _expression().col(1) * other.coeff(0)).conjugate();
res.col(0) = (mat.col(1) * vec.coeff(2) - mat.col(2) * vec.coeff(1)).conjugate();
res.col(1) = (mat.col(2) * vec.coeff(0) - mat.col(0) * vec.coeff(2)).conjugate();
res.col(2) = (mat.col(0) * vec.coeff(1) - mat.col(1) * vec.coeff(0)).conjugate();
}
return res;
}

View File

@ -33,6 +33,7 @@ template<typename Scalar> void orthomethods_3()
VERIFY_IS_MUCH_SMALLER_THAN(v1.dot(v1.cross(v2)), Scalar(1));
VERIFY_IS_MUCH_SMALLER_THAN(v1.cross(v2).dot(v2), Scalar(1));
VERIFY_IS_MUCH_SMALLER_THAN(v2.dot(v1.cross(v2)), Scalar(1));
VERIFY_IS_MUCH_SMALLER_THAN(v1.cross(Vector3::Random()).dot(v1), Scalar(1));
Matrix3 mat3;
mat3 << v0.normalized(),
(v0.cross(v1)).normalized(),
@ -47,6 +48,13 @@ template<typename Scalar> void orthomethods_3()
int i = internal::random<int>(0,2);
mcross = mat3.colwise().cross(vec3);
VERIFY_IS_APPROX(mcross.col(i), mat3.col(i).cross(vec3));
VERIFY_IS_MUCH_SMALLER_THAN((mat3.transpose() * mat3.colwise().cross(vec3)).diagonal().cwiseAbs().sum(), Scalar(1));
VERIFY_IS_MUCH_SMALLER_THAN((mat3.transpose() * mat3.colwise().cross(Vector3::Random())).diagonal().cwiseAbs().sum(), Scalar(1));
VERIFY_IS_MUCH_SMALLER_THAN((vec3.transpose() * mat3.colwise().cross(vec3)).cwiseAbs().sum(), Scalar(1));
VERIFY_IS_MUCH_SMALLER_THAN((vec3.transpose() * Matrix3::Random().colwise().cross(vec3)).cwiseAbs().sum(), Scalar(1));
mcross = mat3.rowwise().cross(vec3);
VERIFY_IS_APPROX(mcross.row(i), mat3.row(i).cross(vec3));
@ -57,6 +65,7 @@ template<typename Scalar> void orthomethods_3()
v40.w() = v41.w() = v42.w() = 0;
v42.template head<3>() = v40.template head<3>().cross(v41.template head<3>());
VERIFY_IS_APPROX(v40.cross3(v41), v42);
VERIFY_IS_MUCH_SMALLER_THAN(v40.cross3(Vector4::Random()).dot(v40), Scalar(1));
// check mixed product
typedef Matrix<RealScalar, 3, 1> RealVector3;