Optimize the product of a Householder sequence with the identity, and optimize the evaluation of a HouseholderSequence to a dense matrix by using a faster blocked product.

This commit is contained in:
Gael Guennebaud 2018-07-11 17:16:50 +02:00
parent d193cc87f4
commit 8a5955a052
3 changed files with 35 additions and 15 deletions

View File

@ -11,7 +11,7 @@ set(CTEST_DROP_METHOD "http")
set(CTEST_DROP_SITE "manao.inria.fr") set(CTEST_DROP_SITE "manao.inria.fr")
set(CTEST_DROP_LOCATION "/CDash/submit.php?project=Eigen") set(CTEST_DROP_LOCATION "/CDash/submit.php?project=Eigen")
set(CTEST_DROP_SITE_CDASH TRUE) set(CTEST_DROP_SITE_CDASH TRUE)
set(CTEST_PROJECT_SUBPROJECTS #set(CTEST_PROJECT_SUBPROJECTS
Official #Official
Unsupported #Unsupported
) #)

View File

@ -295,6 +295,14 @@ template<typename VectorsType, typename CoeffsType, int Side> class HouseholderS
for(Index k = 0; k<cols()-vecs ; ++k) for(Index k = 0; k<cols()-vecs ; ++k)
dst.col(k).tail(rows()-k-1).setZero(); dst.col(k).tail(rows()-k-1).setZero();
} }
else if(m_length>BlockSize)
{
dst.setIdentity(rows(), rows());
if(m_reverse)
applyThisOnTheLeft(dst,workspace,true);
else
applyThisOnTheLeft(dst,workspace,true);
}
else else
{ {
dst.setIdentity(rows(), rows()); dst.setIdentity(rows(), rows());
@ -332,24 +340,27 @@ template<typename VectorsType, typename CoeffsType, int Side> class HouseholderS
} }
/** \internal */ /** \internal */
template<typename Dest> inline void applyThisOnTheLeft(Dest& dst) const template<typename Dest> inline void applyThisOnTheLeft(Dest& dst, bool inputIsIdentity = false) const
{ {
Matrix<Scalar,1,Dest::ColsAtCompileTime,RowMajor,1,Dest::MaxColsAtCompileTime> workspace; Matrix<Scalar,1,Dest::ColsAtCompileTime,RowMajor,1,Dest::MaxColsAtCompileTime> workspace;
applyThisOnTheLeft(dst, workspace); applyThisOnTheLeft(dst, workspace, inputIsIdentity);
} }
/** \internal */ /** \internal */
template<typename Dest, typename Workspace> template<typename Dest, typename Workspace>
inline void applyThisOnTheLeft(Dest& dst, Workspace& workspace) const inline void applyThisOnTheLeft(Dest& dst, Workspace& workspace, bool inputIsIdentity = false) const
{ {
const Index BlockSize = 48; if(inputIsIdentity && m_reverse)
inputIsIdentity = false;
// if the entries are large enough, then apply the reflectors by block // if the entries are large enough, then apply the reflectors by block
if(m_length>=BlockSize && dst.cols()>1) if(m_length>=BlockSize && dst.cols()>1)
{ {
for(Index i = 0; i < m_length; i+=BlockSize) // Make sure we have at least 2 useful blocks, otherwise it is point-less:
Index blockSize = m_length<2*BlockSize ? (m_length+1)/2 : BlockSize;
for(Index i = 0; i < m_length; i+=blockSize)
{ {
Index end = m_reverse ? (std::min)(m_length,i+BlockSize) : m_length-i; Index end = m_reverse ? (std::min)(m_length,i+blockSize) : m_length-i;
Index k = m_reverse ? i : (std::max)(Index(0),end-BlockSize); Index k = m_reverse ? i : (std::max)(Index(0),end-blockSize);
Index bs = end-k; Index bs = end-k;
Index start = k + m_shift; Index start = k + m_shift;
@ -359,7 +370,14 @@ template<typename VectorsType, typename CoeffsType, int Side> class HouseholderS
Side==OnTheRight ? bs : m_vectors.rows()-start, Side==OnTheRight ? bs : m_vectors.rows()-start,
Side==OnTheRight ? m_vectors.cols()-start : bs); Side==OnTheRight ? m_vectors.cols()-start : bs);
typename internal::conditional<Side==OnTheRight, Transpose<SubVectorsType>, SubVectorsType&>::type sub_vecs(sub_vecs1); typename internal::conditional<Side==OnTheRight, Transpose<SubVectorsType>, SubVectorsType&>::type sub_vecs(sub_vecs1);
Block<Dest,Dynamic,Dynamic> sub_dst(dst,dst.rows()-rows()+m_shift+k,0, rows()-m_shift-k,dst.cols());
Index dstStart = dst.rows()-rows()+m_shift+k;
Index dstRows = rows()-m_shift-k;
Block<Dest,Dynamic,Dynamic> sub_dst(dst,
dstStart,
inputIsIdentity ? dstStart : 0,
dstRows,
inputIsIdentity ? dstRows : dst.cols());
apply_block_householder_on_the_left(sub_dst, sub_vecs, m_coeffs.segment(k, bs), !m_reverse); apply_block_householder_on_the_left(sub_dst, sub_vecs, m_coeffs.segment(k, bs), !m_reverse);
} }
} }
@ -369,7 +387,8 @@ template<typename VectorsType, typename CoeffsType, int Side> class HouseholderS
for(Index k = 0; k < m_length; ++k) for(Index k = 0; k < m_length; ++k)
{ {
Index actual_k = m_reverse ? k : m_length-k-1; Index actual_k = m_reverse ? k : m_length-k-1;
dst.bottomRows(rows()-m_shift-actual_k) Index dstStart = rows()-m_shift-actual_k;
dst.bottomRightCorner(dstStart, inputIsIdentity ? dstStart : dst.cols())
.applyHouseholderOnTheLeft(essentialVector(actual_k), m_coeffs.coeff(actual_k), workspace.data()); .applyHouseholderOnTheLeft(essentialVector(actual_k), m_coeffs.coeff(actual_k), workspace.data());
} }
} }
@ -387,7 +406,7 @@ template<typename VectorsType, typename CoeffsType, int Side> class HouseholderS
{ {
typename internal::matrix_type_times_scalar_type<Scalar, OtherDerived>::Type typename internal::matrix_type_times_scalar_type<Scalar, OtherDerived>::Type
res(other.template cast<typename internal::matrix_type_times_scalar_type<Scalar,OtherDerived>::ResultScalar>()); res(other.template cast<typename internal::matrix_type_times_scalar_type<Scalar,OtherDerived>::ResultScalar>());
applyThisOnTheLeft(res); applyThisOnTheLeft(res, internal::is_identity<OtherDerived>::value && res.rows()==res.cols());
return res; return res;
} }
@ -461,6 +480,7 @@ template<typename VectorsType, typename CoeffsType, int Side> class HouseholderS
bool m_reverse; bool m_reverse;
Index m_length; Index m_length;
Index m_shift; Index m_shift;
enum { BlockSize = 48 };
}; };
/** \brief Computes the product of a matrix with a Householder sequence. /** \brief Computes the product of a matrix with a Householder sequence.

View File

@ -129,7 +129,7 @@ void matlab_cplx_real(const M& ar, const M& ai, const M& b, M& cr, M& ci)
template<typename A, typename B, typename C> template<typename A, typename B, typename C>
EIGEN_DONT_INLINE void gemm(const A& a, const B& b, C& c) EIGEN_DONT_INLINE void gemm(const A& a, const B& b, C& c)
{ {
c.noalias() += a * b; c.noalias() += a * b;
} }
int main(int argc, char ** argv) int main(int argc, char ** argv)