Generic loop unrolling with template metaprograms. It seems to be as fast as

manually unrolling.
TODO: decide when to stop unrolling (speed vs. code size).
      maybe only unroll one loop for larger matixes.
This commit is contained in:
Michael Olbrich 2007-09-30 20:38:09 +00:00
parent 1d3743d2c5
commit 2d823d8ef6

View File

@ -28,26 +28,45 @@
#include "Util.h" #include "Util.h"
template<int count, int rows> class Loop
{
enum {
col = (count-1)/rows,
row = (count-1)%rows,
next = count-1
};
public:
template <typename Derived1, typename Derived2> static void copy(Derived1 &dst, const Derived2 &src)
{
Loop<next, rows>::copy(dst, src);
dst.write(row, col) = src.read(row, col);
}
};
template<int rows> class Loop<0, rows>
{
public:
template <typename Derived1, typename Derived2> static void copy(Derived1 &dst, const Derived2 &src)
{
EI_UNUSED(dst);
EI_UNUSED(src);
}
};
template<typename Scalar, typename Derived> class EiObject template<typename Scalar, typename Derived> class EiObject
{ {
static const int RowsAtCompileTime = Derived::RowsAtCompileTime, static const int RowsAtCompileTime = Derived::RowsAtCompileTime,
ColsAtCompileTime = Derived::ColsAtCompileTime; ColsAtCompileTime = Derived::ColsAtCompileTime,
CountAtCompileTime= RowsAtCompileTime*ColsAtCompileTime > 0 ?
RowsAtCompileTime*ColsAtCompileTime : 0;
template<typename OtherDerived> template<typename OtherDerived>
void _copy_helper(const EiObject<Scalar, OtherDerived>& other) void _copy_helper(const EiObject<Scalar, OtherDerived>& other)
{ {
if(RowsAtCompileTime == 3 && ColsAtCompileTime == 3) if ((RowsAtCompileTime != EiDynamic) &&
{ (ColsAtCompileTime != EiDynamic) &&
write(0,0) = other.read(0,0); (CountAtCompileTime <= 25))
write(1,0) = other.read(1,0); Loop<CountAtCompileTime, RowsAtCompileTime>::copy(*this, other);
write(2,0) = other.read(2,0);
write(0,1) = other.read(0,1);
write(1,1) = other.read(1,1);
write(2,1) = other.read(2,1);
write(0,2) = other.read(0,2);
write(1,2) = other.read(1,2);
write(2,2) = other.read(2,2);
}
else else
for(int i = 0; i < rows(); i++) for(int i = 0; i < rows(); i++)
for(int j = 0; j < cols(); j++) for(int j = 0; j < cols(); j++)