Let KroneckerProduct exploits the recently introduced generic InnerIterator class.

This commit is contained in:
Gael Guennebaud 2014-09-29 13:37:49 +02:00
parent abd3502e9e
commit 842e31cf5c
2 changed files with 22 additions and 23 deletions

View File

@ -157,40 +157,27 @@ void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
dst.resizeNonZeros(0);
// 1 - evaluate the operands if needed:
typedef typename internal::nested_eval<Lhs,10>::type Lhs1;
typedef typename internal::nested_eval<Lhs,Dynamic>::type Lhs1;
typedef typename internal::remove_all<Lhs1>::type Lhs1Cleaned;
const Lhs1 lhs1(m_A);
typedef typename internal::nested_eval<Rhs,10>::type Rhs1;
typedef typename internal::nested_eval<Rhs,Dynamic>::type Rhs1;
typedef typename internal::remove_all<Rhs1>::type Rhs1Cleaned;
const Rhs1 rhs1(m_B);
// 2 - construct a SparseView for dense operands
typedef typename internal::conditional<internal::is_same<typename internal::traits<Lhs1Cleaned>::StorageKind,Sparse>::value, Lhs1, SparseView<const Lhs1Cleaned> >::type Lhs2;
typedef typename internal::remove_all<Lhs2>::type Lhs2Cleaned;
const Lhs2 lhs2(lhs1);
typedef typename internal::conditional<internal::is_same<typename internal::traits<Rhs1Cleaned>::StorageKind,Sparse>::value, Rhs1, SparseView<const Rhs1Cleaned> >::type Rhs2;
typedef typename internal::remove_all<Rhs2>::type Rhs2Cleaned;
const Rhs2 rhs2(rhs1);
// 3 - construct respective evaluators
typedef typename internal::evaluator<Lhs2Cleaned>::type LhsEval;
LhsEval lhsEval(lhs2);
typedef typename internal::evaluator<Rhs2Cleaned>::type RhsEval;
RhsEval rhsEval(rhs2);
typedef typename LhsEval::InnerIterator LhsInnerIterator;
typedef typename RhsEval::InnerIterator RhsInnerIterator;
// 2 - construct respective iterators
typedef InnerIterator<Lhs1Cleaned> LhsInnerIterator;
typedef InnerIterator<Rhs1Cleaned> RhsInnerIterator;
// compute number of non-zeros per innervectors of dst
{
VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols());
for (typename Lhs::Index kA=0; kA < m_A.outerSize(); ++kA)
for (LhsInnerIterator itA(lhsEval,kA); itA; ++itA)
for (LhsInnerIterator itA(lhs1,kA); itA; ++itA)
nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++;
VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols());
for (typename Rhs::Index kB=0; kB < m_B.outerSize(); ++kB)
for (RhsInnerIterator itB(rhsEval,kB); itB; ++itB)
for (RhsInnerIterator itB(rhs1,kB); itB; ++itB)
nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++;
Matrix<int,Dynamic,Dynamic,ColMajor> nnzAB = nnzB * nnzA.transpose();
@ -201,9 +188,9 @@ void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
{
for (typename Rhs::Index kB=0; kB < m_B.outerSize(); ++kB)
{
for (LhsInnerIterator itA(lhsEval,kA); itA; ++itA)
for (LhsInnerIterator itA(lhs1,kA); itA; ++itA)
{
for (RhsInnerIterator itB(rhsEval,kB); itB; ++itB)
for (RhsInnerIterator itB(rhs1,kB); itB; ++itB)
{
const DestIndex
i = DestIndex(itA.row() * Br + itB.row()),

View File

@ -216,5 +216,17 @@ void test_kronecker_product()
sC2 = kroneckerProduct(sA,sB);
dC = kroneckerProduct(dA,dB);
VERIFY_IS_APPROX(MatrixXf(sC2),dC);
sC2 = kroneckerProduct(dA,sB);
dC = kroneckerProduct(dA,dB);
VERIFY_IS_APPROX(MatrixXf(sC2),dC);
sC2 = kroneckerProduct(sA,dB);
dC = kroneckerProduct(dA,dB);
VERIFY_IS_APPROX(MatrixXf(sC2),dC);
sC2 = kroneckerProduct(2*sA,sB);
dC = kroneckerProduct(2*dA,dB);
VERIFY_IS_APPROX(MatrixXf(sC2),dC);
}
}