bug #986: add support for coefficient-based product with 0 depth.

This commit is contained in:
Gael Guennebaud 2015-04-01 13:15:23 +02:00
parent 79b4e6acaf
commit 8481dc21ea
2 changed files with 64 additions and 12 deletions

View File

@ -409,7 +409,8 @@ struct product_evaluator<Product<Lhs, Rhs, LazyProduct>, ProductTag, DenseShape,
LhsCoeffReadCost = LhsEtorType::CoeffReadCost, LhsCoeffReadCost = LhsEtorType::CoeffReadCost,
RhsCoeffReadCost = RhsEtorType::CoeffReadCost, RhsCoeffReadCost = RhsEtorType::CoeffReadCost,
CoeffReadCost = (InnerSize == Dynamic || LhsCoeffReadCost==Dynamic || RhsCoeffReadCost==Dynamic || NumTraits<Scalar>::AddCost==Dynamic || NumTraits<Scalar>::MulCost==Dynamic) ? Dynamic CoeffReadCost = InnerSize==0 ? NumTraits<Scalar>::ReadCost
: (InnerSize == Dynamic || LhsCoeffReadCost==Dynamic || RhsCoeffReadCost==Dynamic || NumTraits<Scalar>::AddCost==Dynamic || NumTraits<Scalar>::MulCost==Dynamic) ? Dynamic
: InnerSize * (NumTraits<Scalar>::MulCost + LhsCoeffReadCost + RhsCoeffReadCost) : InnerSize * (NumTraits<Scalar>::MulCost + LhsCoeffReadCost + RhsCoeffReadCost)
+ (InnerSize - 1) * NumTraits<Scalar>::AddCost, + (InnerSize - 1) * NumTraits<Scalar>::AddCost,
@ -484,7 +485,7 @@ struct product_evaluator<Product<Lhs, Rhs, LazyProduct>, ProductTag, DenseShape,
{ {
PacketScalar res; PacketScalar res;
typedef etor_product_packet_impl<Flags&RowMajorBit ? RowMajor : ColMajor, typedef etor_product_packet_impl<Flags&RowMajorBit ? RowMajor : ColMajor,
Unroll ? InnerSize-1 : Dynamic, Unroll ? InnerSize : Dynamic,
LhsEtorType, RhsEtorType, PacketScalar, LoadMode> PacketImpl; LhsEtorType, RhsEtorType, PacketScalar, LoadMode> PacketImpl;
PacketImpl::run(row, col, m_lhsImpl, m_rhsImpl, m_innerDim, res); PacketImpl::run(row, col, m_lhsImpl, m_rhsImpl, m_innerDim, res);
@ -527,7 +528,7 @@ struct etor_product_packet_impl<RowMajor, UnrollingIndex, Lhs, Rhs, Packet, Load
static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index innerDim, Packet &res) static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index innerDim, Packet &res)
{ {
etor_product_packet_impl<RowMajor, UnrollingIndex-1, Lhs, Rhs, Packet, LoadMode>::run(row, col, lhs, rhs, innerDim, res); etor_product_packet_impl<RowMajor, UnrollingIndex-1, Lhs, Rhs, Packet, LoadMode>::run(row, col, lhs, rhs, innerDim, res);
res = pmadd(pset1<Packet>(lhs.coeff(row, UnrollingIndex)), rhs.template packet<LoadMode>(UnrollingIndex, col), res); res = pmadd(pset1<Packet>(lhs.coeff(row, UnrollingIndex-1)), rhs.template packet<LoadMode>(UnrollingIndex-1, col), res);
} }
}; };
@ -537,12 +538,12 @@ struct etor_product_packet_impl<ColMajor, UnrollingIndex, Lhs, Rhs, Packet, Load
static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index innerDim, Packet &res) static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index innerDim, Packet &res)
{ {
etor_product_packet_impl<ColMajor, UnrollingIndex-1, Lhs, Rhs, Packet, LoadMode>::run(row, col, lhs, rhs, innerDim, res); etor_product_packet_impl<ColMajor, UnrollingIndex-1, Lhs, Rhs, Packet, LoadMode>::run(row, col, lhs, rhs, innerDim, res);
res = pmadd(lhs.template packet<LoadMode>(row, UnrollingIndex), pset1<Packet>(rhs.coeff(UnrollingIndex, col)), res); res = pmadd(lhs.template packet<LoadMode>(row, UnrollingIndex-1), pset1<Packet>(rhs.coeff(UnrollingIndex-1, col)), res);
} }
}; };
template<typename Lhs, typename Rhs, typename Packet, int LoadMode> template<typename Lhs, typename Rhs, typename Packet, int LoadMode>
struct etor_product_packet_impl<RowMajor, 0, Lhs, Rhs, Packet, LoadMode> struct etor_product_packet_impl<RowMajor, 1, Lhs, Rhs, Packet, LoadMode>
{ {
static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index /*innerDim*/, Packet &res) static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index /*innerDim*/, Packet &res)
{ {
@ -551,7 +552,7 @@ struct etor_product_packet_impl<RowMajor, 0, Lhs, Rhs, Packet, LoadMode>
}; };
template<typename Lhs, typename Rhs, typename Packet, int LoadMode> template<typename Lhs, typename Rhs, typename Packet, int LoadMode>
struct etor_product_packet_impl<ColMajor, 0, Lhs, Rhs, Packet, LoadMode> struct etor_product_packet_impl<ColMajor, 1, Lhs, Rhs, Packet, LoadMode>
{ {
static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index /*innerDim*/, Packet &res) static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index /*innerDim*/, Packet &res)
{ {
@ -559,14 +560,31 @@ struct etor_product_packet_impl<ColMajor, 0, Lhs, Rhs, Packet, LoadMode>
} }
}; };
template<typename Lhs, typename Rhs, typename Packet, int LoadMode>
struct etor_product_packet_impl<RowMajor, 0, Lhs, Rhs, Packet, LoadMode>
{
static EIGEN_STRONG_INLINE void run(Index /*row*/, Index /*col*/, const Lhs& /*lhs*/, const Rhs& /*rhs*/, Index /*innerDim*/, Packet &res)
{
res = pset1<Packet>(0);
}
};
template<typename Lhs, typename Rhs, typename Packet, int LoadMode>
struct etor_product_packet_impl<ColMajor, 0, Lhs, Rhs, Packet, LoadMode>
{
static EIGEN_STRONG_INLINE void run(Index /*row*/, Index /*col*/, const Lhs& /*lhs*/, const Rhs& /*rhs*/, Index /*innerDim*/, Packet &res)
{
res = pset1<Packet>(0);
}
};
template<typename Lhs, typename Rhs, typename Packet, int LoadMode> template<typename Lhs, typename Rhs, typename Packet, int LoadMode>
struct etor_product_packet_impl<RowMajor, Dynamic, Lhs, Rhs, Packet, LoadMode> struct etor_product_packet_impl<RowMajor, Dynamic, Lhs, Rhs, Packet, LoadMode>
{ {
static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index innerDim, Packet& res) static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index innerDim, Packet& res)
{ {
eigen_assert(innerDim>0 && "you are using a non initialized matrix"); res = pset1<Packet>(0);
res = pmul(pset1<Packet>(lhs.coeff(row, 0)),rhs.template packet<LoadMode>(0, col)); for(Index i = 0; i < innerDim; ++i)
for(Index i = 1; i < innerDim; ++i)
res = pmadd(pset1<Packet>(lhs.coeff(row, i)), rhs.template packet<LoadMode>(i, col), res); res = pmadd(pset1<Packet>(lhs.coeff(row, i)), rhs.template packet<LoadMode>(i, col), res);
} }
}; };
@ -576,9 +594,8 @@ struct etor_product_packet_impl<ColMajor, Dynamic, Lhs, Rhs, Packet, LoadMode>
{ {
static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index innerDim, Packet& res) static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index innerDim, Packet& res)
{ {
eigen_assert(innerDim>0 && "you are using a non initialized matrix"); res = pset1<Packet>(0);
res = pmul(lhs.template packet<LoadMode>(row, 0), pset1<Packet>(rhs.coeff(0, col))); for(Index i = 0; i < innerDim; ++i)
for(Index i = 1; i < innerDim; ++i)
res = pmadd(lhs.template packet<LoadMode>(row, i), pset1<Packet>(rhs.coeff(i, col)), res); res = pmadd(lhs.template packet<LoadMode>(row, i), pset1<Packet>(rhs.coeff(i, col)), res);
} }
}; };

View File

@ -113,6 +113,9 @@ void mat_mat_scalar_scalar_product()
template <typename MatrixType> template <typename MatrixType>
void zero_sized_objects(const MatrixType& m) void zero_sized_objects(const MatrixType& m)
{ {
typedef typename MatrixType::Scalar Scalar;
const int PacketSize = internal::packet_traits<Scalar>::size;
const int PacketSize1 = PacketSize>1 ? PacketSize-1 : 1;
Index rows = m.rows(); Index rows = m.rows();
Index cols = m.cols(); Index cols = m.cols();
@ -132,6 +135,38 @@ void zero_sized_objects(const MatrixType& m)
res = b*a; res = b*a;
VERIFY(res.rows()==0 && res.cols()==cols); VERIFY(res.rows()==0 && res.cols()==cols);
} }
{
Matrix<Scalar,PacketSize,0> a;
Matrix<Scalar,0,1> b;
Matrix<Scalar,PacketSize,1> res;
VERIFY_IS_APPROX( (res=a*b), MatrixType::Zero(PacketSize,1) );
VERIFY_IS_APPROX( (res=a.lazyProduct(b)), MatrixType::Zero(PacketSize,1) );
}
{
Matrix<Scalar,PacketSize1,0> a;
Matrix<Scalar,0,1> b;
Matrix<Scalar,PacketSize1,1> res;
VERIFY_IS_APPROX( (res=a*b), MatrixType::Zero(PacketSize1,1) );
VERIFY_IS_APPROX( (res=a.lazyProduct(b)), MatrixType::Zero(PacketSize1,1) );
}
{
Matrix<Scalar,PacketSize,Dynamic> a(PacketSize,0);
Matrix<Scalar,Dynamic,1> b(0,1);
Matrix<Scalar,PacketSize,1> res;
VERIFY_IS_APPROX( (res=a*b), MatrixType::Zero(PacketSize,1) );
VERIFY_IS_APPROX( (res=a.lazyProduct(b)), MatrixType::Zero(PacketSize,1) );
}
{
Matrix<Scalar,PacketSize1,Dynamic> a(PacketSize1,0);
Matrix<Scalar,Dynamic,1> b(0,1);
Matrix<Scalar,PacketSize1,1> res;
VERIFY_IS_APPROX( (res=a*b), MatrixType::Zero(PacketSize1,1) );
VERIFY_IS_APPROX( (res=a.lazyProduct(b)), MatrixType::Zero(PacketSize1,1) );
}
} }
template<int> template<int>