mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-01 18:26:24 +08:00
New implementation of Swap as discussed, reusing Assign. Makes LU run
10% faster overall.
This commit is contained in:
parent
c94be35bc8
commit
88bb2087c1
@ -25,8 +25,94 @@
|
||||
#ifndef EIGEN_SWAP_H
|
||||
#define EIGEN_SWAP_H
|
||||
|
||||
// Primary template of the element-wise swap helper; it is only declared here.
// IsVector defaults to Derived::IsVectorAtCompileTime and selects between the
// <..., true> (single-index) and <..., false> (row/column) specializations.
template <typename Derived, typename OtherDerived, bool IsVector = Derived::IsVectorAtCompileTime>
|
||||
struct ei_swap_selector;
|
||||
/** \class SwapWrapper
  *
  * \brief Wrapper around an expression whose copy operations swap instead of copy
  *
  * \param ExpressionType the type of the expression being wrapped
  *
  * The copyCoeff() and copyPacket() members of this class exchange the
  * coefficients of the wrapped expression with those of the other expression,
  * rather than merely copying them. MatrixBase::swap() is implemented by
  * calling lazyAssign() on such a wrapper.
  *
  * \sa MatrixBase::swap()
  */
|
||||
template<typename ExpressionType>
|
||||
struct ei_traits<SwapWrapper<ExpressionType> >
|
||||
{
|
||||
typedef typename ExpressionType::Scalar Scalar;
|
||||
enum {
|
||||
RowsAtCompileTime = ExpressionType::RowsAtCompileTime,
|
||||
ColsAtCompileTime = ExpressionType::ColsAtCompileTime,
|
||||
MaxRowsAtCompileTime = ExpressionType::MaxRowsAtCompileTime,
|
||||
MaxColsAtCompileTime = ExpressionType::MaxColsAtCompileTime,
|
||||
Flags = ExpressionType::Flags,
|
||||
CoeffReadCost = ExpressionType::CoeffReadCost
|
||||
};
|
||||
};
|
||||
|
||||
template<typename ExpressionType> class SwapWrapper
|
||||
: public MatrixBase<SwapWrapper<ExpressionType> >
|
||||
{
|
||||
public:
|
||||
|
||||
EIGEN_GENERIC_PUBLIC_INTERFACE(SwapWrapper)
|
||||
typedef typename ei_packet_traits<Scalar>::type Packet;
|
||||
|
||||
inline SwapWrapper(ExpressionType& matrix) : m_expression(matrix) {}
|
||||
|
||||
inline int rows() const { return m_expression.rows(); }
|
||||
inline int cols() const { return m_expression.cols(); }
|
||||
inline int stride() const { return m_expression.stride(); }
|
||||
|
||||
template<typename OtherDerived>
|
||||
void copyCoeff(int row, int col, const MatrixBase<OtherDerived>& other)
|
||||
{
|
||||
OtherDerived& _other = other.const_cast_derived();
|
||||
ei_internal_assert(row >= 0 && row < rows()
|
||||
&& col >= 0 && col < cols());
|
||||
Scalar tmp = m_expression.coeff(row, col);
|
||||
m_expression.coeffRef(row, col) = _other.coeff(row, col);
|
||||
_other.coeffRef(row, col) = tmp;
|
||||
}
|
||||
|
||||
template<typename OtherDerived>
|
||||
void copyCoeff(int index, const MatrixBase<OtherDerived>& other)
|
||||
{
|
||||
OtherDerived& _other = other.const_cast_derived();
|
||||
ei_internal_assert(index >= 0 && index < m_expression.size());
|
||||
Scalar tmp = m_expression.coeff(index);
|
||||
m_expression.coeffRef(index) = _other.coeff(index);
|
||||
_other.coeffRef(index) = tmp;
|
||||
}
|
||||
|
||||
template<typename OtherDerived, int LoadStoreMode>
|
||||
void copyPacket(int row, int col, const MatrixBase<OtherDerived>& other)
|
||||
{
|
||||
OtherDerived& _other = other.const_cast_derived();
|
||||
ei_internal_assert(row >= 0 && row < rows()
|
||||
&& col >= 0 && col < cols());
|
||||
Packet tmp = m_expression.template packet<LoadStoreMode>(row, col);
|
||||
m_expression.template writePacket<LoadStoreMode>(row, col,
|
||||
_other.template packet<LoadStoreMode>(row, col)
|
||||
);
|
||||
_other.template writePacket<LoadStoreMode>(row, col, tmp);
|
||||
}
|
||||
|
||||
template<typename OtherDerived, int LoadStoreMode>
|
||||
void copyPacket(int index, const MatrixBase<OtherDerived>& other)
|
||||
{
|
||||
OtherDerived& _other = other.const_cast_derived();
|
||||
ei_internal_assert(index >= 0 && index < m_expression.size());
|
||||
Packet tmp = m_expression.template packet<LoadStoreMode>(index);
|
||||
m_expression.template writePacket<LoadStoreMode>(index,
|
||||
_other.template packet<LoadStoreMode>(index)
|
||||
);
|
||||
_other.template writePacket<LoadStoreMode>(index, tmp);
|
||||
}
|
||||
|
||||
protected:
|
||||
ExpressionType m_expression;
|
||||
};
|
||||
|
||||
/** swaps *this with the expression \a other.
|
||||
*
|
||||
@ -41,51 +127,7 @@ template<typename Derived>
|
||||
template<typename OtherDerived>
|
||||
void MatrixBase<Derived>::swap(const MatrixBase<OtherDerived>& other)
|
||||
{
|
||||
MatrixBase<OtherDerived> *_other = const_cast<MatrixBase<OtherDerived>*>(&other);
|
||||
|
||||
// disable that path: it makes LU decomposition fail ! I can't see the bug though.
|
||||
if(false /*SizeAtCompileTime == Dynamic*/)
|
||||
{
|
||||
ei_swap_selector<Derived,OtherDerived>::run(derived(),other.const_cast_derived());
|
||||
}
|
||||
else // SizeAtCompileTime != Dynamic
|
||||
{
|
||||
typename Derived::Eval buf(*this);
|
||||
*this = other;
|
||||
*_other = buf;
|
||||
}
|
||||
SwapWrapper<Derived>(derived()).lazyAssign(other);
|
||||
}
|
||||
|
||||
template<typename Derived, typename OtherDerived>
|
||||
struct ei_swap_selector<Derived,OtherDerived,true>
|
||||
{
|
||||
inline static void run(Derived& src, OtherDerived& other)
|
||||
{
|
||||
typename Derived::Scalar tmp;
|
||||
ei_assert(OtherDerived::IsVectorAtCompileTime && src.size() == other.size());
|
||||
for(int i = 0; i < src.size(); i++)
|
||||
{
|
||||
tmp = src.coeff(i);
|
||||
src.coeffRef(i) = other.coeff(i);
|
||||
other.coeffRef(i) = tmp;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Derived, typename OtherDerived>
|
||||
struct ei_swap_selector<Derived,OtherDerived,false>
|
||||
{
|
||||
inline static void run(Derived& src, OtherDerived& other)
|
||||
{
|
||||
typename Derived::Scalar tmp;
|
||||
for(int j = 0; j < src.cols(); j++)
|
||||
for(int i = 0; i < src.rows(); i++)
|
||||
{
|
||||
tmp = src.coeff(i, j);
|
||||
src.coeffRef(i, j) = other.coeff(i, j);
|
||||
other.coeffRef(i, j) = tmp;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#endif // EIGEN_SWAP_H
|
||||
|
@ -41,6 +41,7 @@ class Matrix;
|
||||
|
||||
// Forward declarations of expression class templates.
template<typename ExpressionType, unsigned int Added, unsigned int Removed> class Flagged;
|
||||
template<typename ExpressionType> class NestByValue;
|
||||
// Wrapper that MatrixBase::swap() instantiates around its own expression.
template<typename ExpressionType> class SwapWrapper;
|
||||
template<typename MatrixType> class Minor;
|
||||
// BlockRows and BlockCols default to Dynamic; DirectAccessStatus defaults to
// the DirectAccessBit portion of MatrixType's flags.
template<typename MatrixType, int BlockRows=Dynamic, int BlockCols=Dynamic,
|
||||
int DirectAccessStatus = ei_traits<MatrixType>::Flags&DirectAccessBit> class Block;
|
||||
|
Loading…
Reference in New Issue
Block a user