bug #1009, part 1/2: make sure vector expressions expose LinearAccessBit flag.

This commit is contained in:
Gael Guennebaud 2015-11-27 10:06:07 +01:00
parent 7ddcf97da7
commit 91a7059459
2 changed files with 18 additions and 7 deletions

View File

@ -907,8 +907,8 @@ struct unary_evaluator<Replicate<ArgType, RowFactor, ColFactor> >
enum { enum {
CoeffReadCost = evaluator<ArgTypeNestedCleaned>::CoeffReadCost, CoeffReadCost = evaluator<ArgTypeNestedCleaned>::CoeffReadCost,
LinearAccessMask = XprType::IsVectorAtCompileTime ? LinearAccessBit : 0,
Flags = (evaluator<ArgTypeNestedCleaned>::Flags & HereditaryBits & ~RowMajorBit) | (traits<XprType>::Flags & RowMajorBit), Flags = (evaluator<ArgTypeNestedCleaned>::Flags & (HereditaryBits|LinearAccessMask) & ~RowMajorBit) | (traits<XprType>::Flags & RowMajorBit),
Alignment = evaluator<ArgTypeNestedCleaned>::Alignment Alignment = evaluator<ArgTypeNestedCleaned>::Alignment
}; };
@ -1149,6 +1149,7 @@ struct unary_evaluator<Reverse<ArgType, Direction> >
// FIXME enable DirectAccess with negative strides? // FIXME enable DirectAccess with negative strides?
Flags0 = evaluator<ArgType>::Flags, Flags0 = evaluator<ArgType>::Flags,
LinearAccess = ( (Direction==BothDirections) && (int(Flags0)&PacketAccessBit) ) LinearAccess = ( (Direction==BothDirections) && (int(Flags0)&PacketAccessBit) )
|| ((ReverseRow && XprType::ColsAtCompileTime==1) || (ReverseCol && XprType::RowsAtCompileTime==1))
? LinearAccessBit : 0, ? LinearAccessBit : 0,
Flags = int(Flags0) & (HereditaryBits | PacketAccessBit | LinearAccess), Flags = int(Flags0) & (HereditaryBits | PacketAccessBit | LinearAccess),
@ -1158,8 +1159,8 @@ struct unary_evaluator<Reverse<ArgType, Direction> >
EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& reverse) EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& reverse)
: m_argImpl(reverse.nestedExpression()), : m_argImpl(reverse.nestedExpression()),
m_rows(ReverseRow ? reverse.nestedExpression().rows() : 0), m_rows(ReverseRow ? reverse.nestedExpression().rows() : 1),
m_cols(ReverseCol ? reverse.nestedExpression().cols() : 0) m_cols(ReverseCol ? reverse.nestedExpression().cols() : 1)
{ } { }
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index row, Index col) const EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index row, Index col) const
@ -1233,8 +1234,9 @@ protected:
evaluator<ArgType> m_argImpl; evaluator<ArgType> m_argImpl;
// If we do not reverse rows, then we do not need to know the number of rows; same for columns // If we do not reverse rows, then we do not need to know the number of rows; same for columns
const variable_if_dynamic<Index, ReverseRow ? ArgType::RowsAtCompileTime : 0> m_rows; // Nonetheless, in this case it is important to set to 1 such that the coeff(index) method works fine for vectors.
const variable_if_dynamic<Index, ReverseCol ? ArgType::ColsAtCompileTime : 0> m_cols; const variable_if_dynamic<Index, ReverseRow ? ArgType::RowsAtCompileTime : 1> m_rows;
const variable_if_dynamic<Index, ReverseCol ? ArgType::ColsAtCompileTime : 1> m_cols;
}; };

View File

@ -484,7 +484,8 @@ struct product_evaluator<Product<Lhs, Rhs, LazyProduct>, ProductTag, DenseShape,
Flags = ((unsigned int)(LhsFlags | RhsFlags) & HereditaryBits & ~RowMajorBit) Flags = ((unsigned int)(LhsFlags | RhsFlags) & HereditaryBits & ~RowMajorBit)
| (EvalToRowMajor ? RowMajorBit : 0) | (EvalToRowMajor ? RowMajorBit : 0)
// TODO enable vectorization for mixed types // TODO enable vectorization for mixed types
| (SameType && (CanVectorizeLhs || CanVectorizeRhs) ? PacketAccessBit : 0), | (SameType && (CanVectorizeLhs || CanVectorizeRhs) ? PacketAccessBit : 0)
| (XprType::IsVectorAtCompileTime ? LinearAccessBit : 0),
Alignment = CanVectorizeLhs ? LhsAlignment Alignment = CanVectorizeLhs ? LhsAlignment
: CanVectorizeRhs ? RhsAlignment : CanVectorizeRhs ? RhsAlignment
@ -531,6 +532,14 @@ struct product_evaluator<Product<Lhs, Rhs, LazyProduct>, ProductTag, DenseShape,
return res; return res;
} }
template<int LoadMode, typename PacketType>
const PacketType packet(Index index) const
{
const Index row = RowsAtCompileTime == 1 ? 0 : index;
const Index col = RowsAtCompileTime == 1 ? index : 0;
return packet<LoadMode,PacketType>(row,col);
}
protected: protected:
const LhsNested m_lhs; const LhsNested m_lhs;
const RhsNested m_rhs; const RhsNested m_rhs;