Specialize the unrollers for Size==0 in order to catch user bugs and not clutter
the compiler output with an infinite recursion. Also add a #define switch
for loop unrolling.
This commit is contained in:
Benoit Jacob 2007-12-11 10:04:39 +00:00
parent 9d51572cbe
commit fc924bc7d4
4 changed files with 62 additions and 13 deletions

View File

@ -40,6 +40,17 @@ template<int UnrollCount, int Rows> struct CopyHelperUnroller
}
};
// prevent buggy user code from causing an infinite recursion
//
// Partial specialization of CopyHelperUnroller selected when the second
// template argument (named Rows in the primary template) is 0.
// run() deliberately copies nothing: its only statements hand the two
// parameters to EIGEN_UNUSED (presumably a warning-suppression macro; its
// definition is not visible here -- TODO confirm).
// NOTE(review): the recursive primary template is outside this view; the
// assumption is that without this terminal case a zero-sized instantiation
// would never reach one of the other specializations -- confirm.
template<int UnrollCount> struct CopyHelperUnroller<UnrollCount, 0>
{
// dst: destination expression, left untouched.
// src: source expression, never read.
template <typename Derived1, typename Derived2>
static void run(Derived1 &dst, const Derived2 &src)
{
EIGEN_UNUSED(dst);
EIGEN_UNUSED(src);
}
};
template<int Rows> struct CopyHelperUnroller<1, Rows>
{
template <typename Derived1, typename Derived2>
@ -63,7 +74,7 @@ template<typename Scalar, typename Derived>
template<typename OtherDerived>
void MatrixBase<Scalar, Derived>::_copy_helper(const MatrixBase<Scalar, OtherDerived>& other)
{
if(SizeAtCompileTime != Dynamic && SizeAtCompileTime <= 25)
if(EIGEN_UNROLLED_LOOPS && SizeAtCompileTime != Dynamic && SizeAtCompileTime <= 25)
CopyHelperUnroller<SizeAtCompileTime, RowsAtCompileTime>::run(*this, other);
else
for(int i = 0; i < rows(); i++)

View File

@ -32,7 +32,7 @@ struct DotUnroller
static void run(const Derived1 &v1, const Derived2& v2, typename Derived1::Scalar &dot)
{
DotUnroller<Index-1, Size, Derived1, Derived2>::run(v1, v2, dot);
dot += v1[Index] * conj(v2[Index]);
dot += v1.read(Index) * conj(v2.read(Index));
}
};
@ -41,7 +41,7 @@ struct DotUnroller<0, Size, Derived1, Derived2>
{
static void run(const Derived1 &v1, const Derived2& v2, typename Derived1::Scalar &dot)
{
dot = v1[0] * conj(v2[0]);
dot = v1.read(0) * conj(v2.read(0));
}
};
@ -56,20 +56,32 @@ struct DotUnroller<Index, Dynamic, Derived1, Derived2>
}
};
// prevent buggy user code from causing an infinite recursion
//
// Partial specialization of DotUnroller selected when the second template
// argument (the compile-time size) is 0, for any Index.
// run() performs no arithmetic: it only hands its parameters to EIGEN_UNUSED
// (presumably a warning-suppression macro; definition not visible here --
// TODO confirm).
// NOTE(review): `dot` is NOT assigned by this specialization. The caller
// visible in this file passes a `Scalar res;` declared without an
// initializer, so for a zero-sized vector the returned value is whatever
// default-initialization of Scalar yields.
template<int Index, typename Derived1, typename Derived2>
struct DotUnroller<Index, 0, Derived1, Derived2>
{
// v1, v2: operands, never read.
// dot:    output reference, left unmodified.
static void run(const Derived1 &v1, const Derived2& v2, typename Derived1::Scalar &dot)
{
EIGEN_UNUSED(v1);
EIGEN_UNUSED(v2);
EIGEN_UNUSED(dot);
}
};
template<typename Scalar, typename Derived>
template<typename OtherDerived>
Scalar MatrixBase<Scalar, Derived>::dot(const OtherDerived& other) const
{
assert(IsVector && OtherDerived::IsVector && size() == other.size());
Scalar res;
if(SizeAtCompileTime != Dynamic && SizeAtCompileTime <= 16)
if(EIGEN_UNROLLED_LOOPS && SizeAtCompileTime != Dynamic && SizeAtCompileTime <= 16)
DotUnroller<SizeAtCompileTime-1, SizeAtCompileTime, Derived, OtherDerived>
::run(*static_cast<const Derived*>(this), other, res);
else
{
res = (*this)[0] * conj(other[0]);
res = (*this).read(0) * conj(other.read(0));
for(int i = 1; i < size(); i++)
res += (*this)[i]* conj(other[i]);
res += (*this).read(i)* conj(other.read(i));
}
return res;
}

View File

@ -61,6 +61,21 @@ struct ProductUnroller<Index, Dynamic, Lhs, Rhs>
}
};
// prevent buggy user code from causing an infinite recursion
//
// Partial specialization of ProductUnroller selected when the second template
// argument is 0, for any Index. At the call site visible in Product::_read
// that argument is Lhs::ColsAtCompileTime, i.e. the length of the inner
// dimension of the product.
// run() computes nothing: every parameter is only handed to EIGEN_UNUSED
// (presumably a warning-suppression macro; definition not visible here --
// TODO confirm).
// NOTE(review): `res` is NOT assigned by this specialization, so a caller
// that passes an uninitialized Scalar gets it back unchanged.
template<int Index, typename Lhs, typename Rhs>
struct ProductUnroller<Index, 0, Lhs, Rhs>
{
// row, col: coefficient position requested by the caller, ignored.
// lhs, rhs: product operands, never read.
// res:      output reference, left unmodified.
static void run(int row, int col, const Lhs& lhs, const Rhs& rhs,
typename Lhs::Scalar &res)
{
EIGEN_UNUSED(row);
EIGEN_UNUSED(col);
EIGEN_UNUSED(lhs);
EIGEN_UNUSED(rhs);
EIGEN_UNUSED(res);
}
};
template<typename Lhs, typename Rhs> class Product
: public MatrixBase<typename Lhs::Scalar, Product<Lhs, Rhs> >
{
@ -93,14 +108,15 @@ template<typename Lhs, typename Rhs> class Product
Scalar _read(int row, int col) const
{
Scalar res;
if(Lhs::ColsAtCompileTime != Dynamic && Lhs::ColsAtCompileTime <= 16)
if(EIGEN_UNROLLED_LOOPS
&& Lhs::ColsAtCompileTime != Dynamic && Lhs::ColsAtCompileTime <= 16)
ProductUnroller<Lhs::ColsAtCompileTime-1, Lhs::ColsAtCompileTime, LhsRef, RhsRef>
::run(row, col, m_lhs, m_rhs, res);
else
{
res = m_lhs(row, 0) * m_rhs(0, col);
res = m_lhs.read(row, 0) * m_rhs.read(0, col);
for(int i = 1; i < m_lhs.cols(); i++)
res += m_lhs(row, i) * m_rhs(i, col);
res += m_lhs.read(row, i) * m_rhs.read(i, col);
}
return res;
}
@ -112,7 +128,7 @@ template<typename Lhs, typename Rhs> class Product
template<typename Scalar, typename Derived>
template<typename OtherDerived>
Product<Derived, OtherDerived>
const Product<Derived, OtherDerived>
MatrixBase<Scalar, Derived>::lazyProduct(const MatrixBase<Scalar, OtherDerived> &other) const
{
return Product<Derived, OtherDerived>(ref(), other.ref());

View File

@ -31,7 +31,7 @@ template<int Index, int Rows, typename Derived> struct TraceUnroller
static void run(const Derived &mat, typename Derived::Scalar &trace)
{
TraceUnroller<Index-1, Rows, Derived>::run(mat, trace);
trace += mat(Index, Index);
trace += mat.read(Index, Index);
}
};
@ -39,7 +39,7 @@ template<int Rows, typename Derived> struct TraceUnroller<0, Rows, Derived>
{
static void run(const Derived &mat, typename Derived::Scalar &trace)
{
trace = mat(0, 0);
trace = mat.read(0, 0);
}
};
@ -52,12 +52,22 @@ template<int Index, typename Derived> struct TraceUnroller<Index, Dynamic, Deriv
}
};
// prevent buggy user code from causing an infinite recursion
//
// Partial specialization of TraceUnroller selected when the second template
// argument (named Rows in the primary template) is 0, for any Index.
// run() reads no coefficient: both parameters are only handed to EIGEN_UNUSED
// (presumably a warning-suppression macro; definition not visible here --
// TODO confirm).
// NOTE(review): `trace` is NOT assigned by this specialization. The caller
// visible in this file passes a `Scalar res;` declared without an
// initializer, so for a matrix with zero rows the result is whatever
// default-initialization of Scalar yields.
template<int Index, typename Derived> struct TraceUnroller<Index, 0, Derived>
{
// mat:   matrix expression, never read.
// trace: output reference, left unmodified.
static void run(const Derived &mat, typename Derived::Scalar &trace)
{
EIGEN_UNUSED(mat);
EIGEN_UNUSED(trace);
}
};
template<typename Scalar, typename Derived>
Scalar MatrixBase<Scalar, Derived>::trace() const
{
assert(rows() == cols());
Scalar res;
if(RowsAtCompileTime != Dynamic && RowsAtCompileTime <= 16)
if(EIGEN_UNROLLED_LOOPS && RowsAtCompileTime != Dynamic && RowsAtCompileTime <= 16)
TraceUnroller<RowsAtCompileTime-1, RowsAtCompileTime, Derived>
::run(*static_cast<const Derived*>(this), res);
else