Relax Ref such that Ref<MatrixXf> accepts a RowVectorXf which can be seen as a degenerate MatrixXf(1,N)

This commit is contained in:
Gael Guennebaud 2014-03-13 18:04:19 +01:00
parent 2db792852f
commit bb4b67cf39
2 changed files with 69 additions and 45 deletions

View File

@ -101,7 +101,7 @@ struct traits<Ref<_PlainObjectType, _Options, _StrideType> >
template<typename Derived> struct match {
enum {
HasDirectAccess = internal::has_direct_access<Derived>::ret,
StorageOrderMatch = PlainObjectType::IsVectorAtCompileTime || ((PlainObjectType::Flags&RowMajorBit)==(Derived::Flags&RowMajorBit)),
StorageOrderMatch = PlainObjectType::IsVectorAtCompileTime || Derived::IsVectorAtCompileTime || ((PlainObjectType::Flags&RowMajorBit)==(Derived::Flags&RowMajorBit)),
InnerStrideMatch = int(StrideType::InnerStrideAtCompileTime)==int(Dynamic)
|| int(StrideType::InnerStrideAtCompileTime)==int(Derived::InnerStrideAtCompileTime)
|| (int(StrideType::InnerStrideAtCompileTime)==0 && int(Derived::InnerStrideAtCompileTime)==1),
@ -172,8 +172,12 @@ protected:
}
else
::new (static_cast<Base*>(this)) Base(expr.data(), expr.rows(), expr.cols());
::new (&m_stride) StrideBase(StrideType::OuterStrideAtCompileTime==0?0:expr.outerStride(),
StrideType::InnerStrideAtCompileTime==0?0:expr.innerStride());
if(Expression::IsVectorAtCompileTime && (!PlainObjectType::IsVectorAtCompileTime) && ((Expression::Flags&RowMajorBit)!=(PlainObjectType::Flags&RowMajorBit)))
::new (&m_stride) StrideBase(expr.innerStride(), StrideType::InnerStrideAtCompileTime==0?0:1);
else
::new (&m_stride) StrideBase(StrideType::OuterStrideAtCompileTime==0?0:expr.outerStride(),
StrideType::InnerStrideAtCompileTime==0?0:expr.innerStride());
}
StrideBase m_stride;

View File

@ -154,59 +154,79 @@ template<typename PlainObjectType> void check_const_correctness(const PlainObjec
VERIFY( !(Ref<ConstPlainObjectType, Aligned>::Flags & LvalueBit) );
}
EIGEN_DONT_INLINE void call_ref_1(Ref<VectorXf> ) { }
EIGEN_DONT_INLINE void call_ref_2(const Ref<const VectorXf>& ) { }
EIGEN_DONT_INLINE void call_ref_3(Ref<VectorXf,0,InnerStride<> > ) { }
EIGEN_DONT_INLINE void call_ref_4(const Ref<const VectorXf,0,InnerStride<> >& ) { }
EIGEN_DONT_INLINE void call_ref_5(Ref<MatrixXf,0,OuterStride<> > ) { }
EIGEN_DONT_INLINE void call_ref_6(const Ref<const MatrixXf,0,OuterStride<> >& ) { }
template<typename B>
EIGEN_DONT_INLINE void call_ref_1(Ref<VectorXf> a, const B &b) { VERIFY_IS_EQUAL(a,b); }
template<typename B>
EIGEN_DONT_INLINE void call_ref_2(const Ref<const VectorXf>& a, const B &b) { VERIFY_IS_EQUAL(a,b); }
template<typename B>
EIGEN_DONT_INLINE void call_ref_3(Ref<VectorXf,0,InnerStride<> > a, const B &b) { VERIFY_IS_EQUAL(a,b); }
template<typename B>
EIGEN_DONT_INLINE void call_ref_4(const Ref<const VectorXf,0,InnerStride<> >& a, const B &b) { VERIFY_IS_EQUAL(a,b); }
template<typename B>
EIGEN_DONT_INLINE void call_ref_5(Ref<MatrixXf,0,OuterStride<> > a, const B &b) { VERIFY_IS_EQUAL(a,b); }
template<typename B>
EIGEN_DONT_INLINE void call_ref_6(const Ref<const MatrixXf,0,OuterStride<> >& a, const B &b) { VERIFY_IS_EQUAL(a,b); }
template<typename B>
EIGEN_DONT_INLINE void call_ref_7(Ref<Matrix<float,Dynamic,3> > a, const B &b) { VERIFY_IS_EQUAL(a,b); }
void call_ref()
{
VectorXcf ca(10);
VectorXf a(10);
VectorXcf ca = VectorXcf::Random(10);
VectorXf a = VectorXf::Random(10);
RowVectorXf b = RowVectorXf::Random(10);
MatrixXf A = MatrixXf::Random(10,10);
RowVector3f c = RowVector3f::Random();
const VectorXf& ac(a);
VectorBlock<VectorXf> ab(a,0,3);
MatrixXf A(10,10);
const VectorBlock<VectorXf> abc(a,0,3);
VERIFY_EVALUATION_COUNT( call_ref_1(a), 0);
//call_ref_1(ac); // does not compile because ac is const
VERIFY_EVALUATION_COUNT( call_ref_1(ab), 0);
VERIFY_EVALUATION_COUNT( call_ref_1(a.head(4)), 0);
VERIFY_EVALUATION_COUNT( call_ref_1(abc), 0);
VERIFY_EVALUATION_COUNT( call_ref_1(A.col(3)), 0);
// call_ref_1(A.row(3)); // does not compile because innerstride!=1
VERIFY_EVALUATION_COUNT( call_ref_3(A.row(3)), 0);
VERIFY_EVALUATION_COUNT( call_ref_4(A.row(3)), 0);
//call_ref_1(a+a); // does not compile for obvious reason
VERIFY_EVALUATION_COUNT( call_ref_1(a,a), 0);
VERIFY_EVALUATION_COUNT( call_ref_1(b,b.transpose()), 0);
// call_ref_1(ac); // does not compile because ac is const
VERIFY_EVALUATION_COUNT( call_ref_1(ab,ab), 0);
VERIFY_EVALUATION_COUNT( call_ref_1(a.head(4),a.head(4)), 0);
VERIFY_EVALUATION_COUNT( call_ref_1(abc,abc), 0);
VERIFY_EVALUATION_COUNT( call_ref_1(A.col(3),A.col(3)), 0);
// call_ref_1(A.row(3)); // does not compile because innerstride!=1
VERIFY_EVALUATION_COUNT( call_ref_3(A.row(3),A.row(3).transpose()), 0);
VERIFY_EVALUATION_COUNT( call_ref_4(A.row(3),A.row(3).transpose()), 0);
// call_ref_1(a+a); // does not compile for obvious reason
VERIFY_EVALUATION_COUNT( call_ref_2(A*A.col(1)), 1); // evaluated into a temp
VERIFY_EVALUATION_COUNT( call_ref_2(ac.head(5)), 0);
VERIFY_EVALUATION_COUNT( call_ref_2(ac), 0);
VERIFY_EVALUATION_COUNT( call_ref_2(a), 0);
VERIFY_EVALUATION_COUNT( call_ref_2(ab), 0);
VERIFY_EVALUATION_COUNT( call_ref_2(a.head(4)), 0);
VERIFY_EVALUATION_COUNT( call_ref_2(a+a), 1); // evaluated into a temp
VERIFY_EVALUATION_COUNT( call_ref_2(ca.imag()), 1); // evaluated into a temp
MatrixXf tmp = A*A.col(1);
VERIFY_EVALUATION_COUNT( call_ref_2(A*A.col(1), tmp), 1); // evaluated into a temp
VERIFY_EVALUATION_COUNT( call_ref_2(ac.head(5),ac.head(5)), 0);
VERIFY_EVALUATION_COUNT( call_ref_2(ac,ac), 0);
VERIFY_EVALUATION_COUNT( call_ref_2(a,a), 0);
VERIFY_EVALUATION_COUNT( call_ref_2(ab,ab), 0);
VERIFY_EVALUATION_COUNT( call_ref_2(a.head(4),a.head(4)), 0);
tmp = a+a;
VERIFY_EVALUATION_COUNT( call_ref_2(a+a,tmp), 1); // evaluated into a temp
VERIFY_EVALUATION_COUNT( call_ref_2(ca.imag(),ca.imag()), 1); // evaluated into a temp
VERIFY_EVALUATION_COUNT( call_ref_4(ac.head(5)), 0);
VERIFY_EVALUATION_COUNT( call_ref_4(a+a), 1); // evaluated into a temp
VERIFY_EVALUATION_COUNT( call_ref_4(ca.imag()), 0);
VERIFY_EVALUATION_COUNT( call_ref_4(ac.head(5),ac.head(5)), 0);
tmp = a+a;
VERIFY_EVALUATION_COUNT( call_ref_4(a+a,tmp), 1); // evaluated into a temp
VERIFY_EVALUATION_COUNT( call_ref_4(ca.imag(),ca.imag()), 0);
VERIFY_EVALUATION_COUNT( call_ref_5(a), 0);
VERIFY_EVALUATION_COUNT( call_ref_5(a.head(3)), 0);
VERIFY_EVALUATION_COUNT( call_ref_5(A), 0);
// call_ref_5(A.transpose()); // does not compile
VERIFY_EVALUATION_COUNT( call_ref_5(A.block(1,1,2,2)), 0);
VERIFY_EVALUATION_COUNT( call_ref_5(a,a), 0);
VERIFY_EVALUATION_COUNT( call_ref_5(a.head(3),a.head(3)), 0);
VERIFY_EVALUATION_COUNT( call_ref_5(A,A), 0);
// call_ref_5(A.transpose()); // does not compile
VERIFY_EVALUATION_COUNT( call_ref_5(A.block(1,1,2,2),A.block(1,1,2,2)), 0);
VERIFY_EVALUATION_COUNT( call_ref_5(b,b), 0); // storage order do not match, but this is a degenerate case that should work
VERIFY_EVALUATION_COUNT( call_ref_5(a.row(3),a.row(3)), 0);
VERIFY_EVALUATION_COUNT( call_ref_6(a), 0);
VERIFY_EVALUATION_COUNT( call_ref_6(a.head(3)), 0);
VERIFY_EVALUATION_COUNT( call_ref_6(A.row(3)), 1); // evaluated into a temp thouth it could be avoided by viewing it as a 1xn matrix
VERIFY_EVALUATION_COUNT( call_ref_6(A+A), 1); // evaluated into a temp
VERIFY_EVALUATION_COUNT( call_ref_6(A), 0);
VERIFY_EVALUATION_COUNT( call_ref_6(A.transpose()), 1); // evaluated into a temp because the storage orders do not match
VERIFY_EVALUATION_COUNT( call_ref_6(A.block(1,1,2,2)), 0);
VERIFY_EVALUATION_COUNT( call_ref_6(a,a), 0);
VERIFY_EVALUATION_COUNT( call_ref_6(a.head(3),a.head(3)), 0);
VERIFY_EVALUATION_COUNT( call_ref_6(A.row(3),A.row(3)), 1); // evaluated into a temp thouth it could be avoided by viewing it as a 1xn matrix
tmp = A+A;
VERIFY_EVALUATION_COUNT( call_ref_6(A+A,tmp), 1); // evaluated into a temp
VERIFY_EVALUATION_COUNT( call_ref_6(A,A), 0);
VERIFY_EVALUATION_COUNT( call_ref_6(A.transpose(),A.transpose()), 1); // evaluated into a temp because the storage orders do not match
VERIFY_EVALUATION_COUNT( call_ref_6(A.block(1,1,2,2),A.block(1,1,2,2)), 0);
VERIFY_EVALUATION_COUNT( call_ref_7(c,c), 0);
}
void test_ref()