mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-15 07:10:37 +08:00
specialize for Size==0 in order to catch user bugs and not clutter
the compiler output with an infinite recursion. Also add a #define switch for loop unrolling.
This commit is contained in:
parent
9d51572cbe
commit
fc924bc7d4
@ -40,6 +40,17 @@ template<int UnrollCount, int Rows> struct CopyHelperUnroller
|
||||
}
|
||||
};
|
||||
|
||||
// prevent buggy user code from causing an infinite recursion
|
||||
template<int UnrollCount> struct CopyHelperUnroller<UnrollCount, 0>
|
||||
{
|
||||
template <typename Derived1, typename Derived2>
|
||||
static void run(Derived1 &dst, const Derived2 &src)
|
||||
{
|
||||
EIGEN_UNUSED(dst);
|
||||
EIGEN_UNUSED(src);
|
||||
}
|
||||
};
|
||||
|
||||
template<int Rows> struct CopyHelperUnroller<1, Rows>
|
||||
{
|
||||
template <typename Derived1, typename Derived2>
|
||||
@ -63,7 +74,7 @@ template<typename Scalar, typename Derived>
|
||||
template<typename OtherDerived>
|
||||
void MatrixBase<Scalar, Derived>::_copy_helper(const MatrixBase<Scalar, OtherDerived>& other)
|
||||
{
|
||||
if(SizeAtCompileTime != Dynamic && SizeAtCompileTime <= 25)
|
||||
if(EIGEN_UNROLLED_LOOPS && SizeAtCompileTime != Dynamic && SizeAtCompileTime <= 25)
|
||||
CopyHelperUnroller<SizeAtCompileTime, RowsAtCompileTime>::run(*this, other);
|
||||
else
|
||||
for(int i = 0; i < rows(); i++)
|
||||
|
@ -32,7 +32,7 @@ struct DotUnroller
|
||||
static void run(const Derived1 &v1, const Derived2& v2, typename Derived1::Scalar &dot)
|
||||
{
|
||||
DotUnroller<Index-1, Size, Derived1, Derived2>::run(v1, v2, dot);
|
||||
dot += v1[Index] * conj(v2[Index]);
|
||||
dot += v1.read(Index) * conj(v2.read(Index));
|
||||
}
|
||||
};
|
||||
|
||||
@ -41,7 +41,7 @@ struct DotUnroller<0, Size, Derived1, Derived2>
|
||||
{
|
||||
static void run(const Derived1 &v1, const Derived2& v2, typename Derived1::Scalar &dot)
|
||||
{
|
||||
dot = v1[0] * conj(v2[0]);
|
||||
dot = v1.read(0) * conj(v2.read(0));
|
||||
}
|
||||
};
|
||||
|
||||
@ -56,20 +56,32 @@ struct DotUnroller<Index, Dynamic, Derived1, Derived2>
|
||||
}
|
||||
};
|
||||
|
||||
// prevent buggy user code from causing an infinite recursion
|
||||
template<int Index, typename Derived1, typename Derived2>
|
||||
struct DotUnroller<Index, 0, Derived1, Derived2>
|
||||
{
|
||||
static void run(const Derived1 &v1, const Derived2& v2, typename Derived1::Scalar &dot)
|
||||
{
|
||||
EIGEN_UNUSED(v1);
|
||||
EIGEN_UNUSED(v2);
|
||||
EIGEN_UNUSED(dot);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Scalar, typename Derived>
|
||||
template<typename OtherDerived>
|
||||
Scalar MatrixBase<Scalar, Derived>::dot(const OtherDerived& other) const
|
||||
{
|
||||
assert(IsVector && OtherDerived::IsVector && size() == other.size());
|
||||
Scalar res;
|
||||
if(SizeAtCompileTime != Dynamic && SizeAtCompileTime <= 16)
|
||||
if(EIGEN_UNROLLED_LOOPS && SizeAtCompileTime != Dynamic && SizeAtCompileTime <= 16)
|
||||
DotUnroller<SizeAtCompileTime-1, SizeAtCompileTime, Derived, OtherDerived>
|
||||
::run(*static_cast<const Derived*>(this), other, res);
|
||||
else
|
||||
{
|
||||
res = (*this)[0] * conj(other[0]);
|
||||
res = (*this).read(0) * conj(other.read(0));
|
||||
for(int i = 1; i < size(); i++)
|
||||
res += (*this)[i]* conj(other[i]);
|
||||
res += (*this).read(i)* conj(other.read(i));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
@ -61,6 +61,21 @@ struct ProductUnroller<Index, Dynamic, Lhs, Rhs>
|
||||
}
|
||||
};
|
||||
|
||||
// prevent buggy user code from causing an infinite recursion
|
||||
template<int Index, typename Lhs, typename Rhs>
|
||||
struct ProductUnroller<Index, 0, Lhs, Rhs>
|
||||
{
|
||||
static void run(int row, int col, const Lhs& lhs, const Rhs& rhs,
|
||||
typename Lhs::Scalar &res)
|
||||
{
|
||||
EIGEN_UNUSED(row);
|
||||
EIGEN_UNUSED(col);
|
||||
EIGEN_UNUSED(lhs);
|
||||
EIGEN_UNUSED(rhs);
|
||||
EIGEN_UNUSED(res);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs> class Product
|
||||
: public MatrixBase<typename Lhs::Scalar, Product<Lhs, Rhs> >
|
||||
{
|
||||
@ -93,14 +108,15 @@ template<typename Lhs, typename Rhs> class Product
|
||||
Scalar _read(int row, int col) const
|
||||
{
|
||||
Scalar res;
|
||||
if(Lhs::ColsAtCompileTime != Dynamic && Lhs::ColsAtCompileTime <= 16)
|
||||
if(EIGEN_UNROLLED_LOOPS
|
||||
&& Lhs::ColsAtCompileTime != Dynamic && Lhs::ColsAtCompileTime <= 16)
|
||||
ProductUnroller<Lhs::ColsAtCompileTime-1, Lhs::ColsAtCompileTime, LhsRef, RhsRef>
|
||||
::run(row, col, m_lhs, m_rhs, res);
|
||||
else
|
||||
{
|
||||
res = m_lhs(row, 0) * m_rhs(0, col);
|
||||
res = m_lhs.read(row, 0) * m_rhs.read(0, col);
|
||||
for(int i = 1; i < m_lhs.cols(); i++)
|
||||
res += m_lhs(row, i) * m_rhs(i, col);
|
||||
res += m_lhs.read(row, i) * m_rhs.read(i, col);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@ -112,7 +128,7 @@ template<typename Lhs, typename Rhs> class Product
|
||||
|
||||
template<typename Scalar, typename Derived>
|
||||
template<typename OtherDerived>
|
||||
Product<Derived, OtherDerived>
|
||||
const Product<Derived, OtherDerived>
|
||||
MatrixBase<Scalar, Derived>::lazyProduct(const MatrixBase<Scalar, OtherDerived> &other) const
|
||||
{
|
||||
return Product<Derived, OtherDerived>(ref(), other.ref());
|
||||
|
@ -31,7 +31,7 @@ template<int Index, int Rows, typename Derived> struct TraceUnroller
|
||||
static void run(const Derived &mat, typename Derived::Scalar &trace)
|
||||
{
|
||||
TraceUnroller<Index-1, Rows, Derived>::run(mat, trace);
|
||||
trace += mat(Index, Index);
|
||||
trace += mat.read(Index, Index);
|
||||
}
|
||||
};
|
||||
|
||||
@ -39,7 +39,7 @@ template<int Rows, typename Derived> struct TraceUnroller<0, Rows, Derived>
|
||||
{
|
||||
static void run(const Derived &mat, typename Derived::Scalar &trace)
|
||||
{
|
||||
trace = mat(0, 0);
|
||||
trace = mat.read(0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
@ -52,12 +52,22 @@ template<int Index, typename Derived> struct TraceUnroller<Index, Dynamic, Deriv
|
||||
}
|
||||
};
|
||||
|
||||
// prevent buggy user code from causing an infinite recursion
|
||||
template<int Index, typename Derived> struct TraceUnroller<Index, 0, Derived>
|
||||
{
|
||||
static void run(const Derived &mat, typename Derived::Scalar &trace)
|
||||
{
|
||||
EIGEN_UNUSED(mat);
|
||||
EIGEN_UNUSED(trace);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Scalar, typename Derived>
|
||||
Scalar MatrixBase<Scalar, Derived>::trace() const
|
||||
{
|
||||
assert(rows() == cols());
|
||||
Scalar res;
|
||||
if(RowsAtCompileTime != Dynamic && RowsAtCompileTime <= 16)
|
||||
if(EIGEN_UNROLLED_LOOPS && RowsAtCompileTime != Dynamic && RowsAtCompileTime <= 16)
|
||||
TraceUnroller<RowsAtCompileTime-1, RowsAtCompileTime, Derived>
|
||||
::run(*static_cast<const Derived*>(this), res);
|
||||
else
|
||||
|
Loading…
Reference in New Issue
Block a user