trsv: add support for inner-stride!=1, reduce code instanciation, move implementation to a new products/XX.h file

This commit is contained in:
Gael Guennebaud 2010-11-05 12:43:14 +01:00
parent fe1353080e
commit 0e6c1170ab
4 changed files with 174 additions and 94 deletions

View File

@ -315,6 +315,7 @@ using std::size_t;
#include "src/Core/products/TriangularMatrixVector.h" #include "src/Core/products/TriangularMatrixVector.h"
#include "src/Core/products/TriangularMatrixMatrix.h" #include "src/Core/products/TriangularMatrixMatrix.h"
#include "src/Core/products/TriangularSolverMatrix.h" #include "src/Core/products/TriangularSolverMatrix.h"
#include "src/Core/products/TriangularSolverVector.h"
#include "src/Core/BandMatrix.h" #include "src/Core/BandMatrix.h"
#include "src/Core/BooleanRedux.h" #include "src/Core/BooleanRedux.h"

View File

@ -27,6 +27,15 @@
namespace internal { namespace internal {
// Forward declarations:
// The following two routines are implemented in the products/TriangularSolver*.h files
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder>
struct triangular_solve_vector;
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder>
struct triangular_solve_matrix;
// small helper struct extracting some traits on the underlying solver operation
template<typename Lhs, typename Rhs, int Side> template<typename Lhs, typename Rhs, int Side>
class trsolve_traits class trsolve_traits
{ {
@ -51,111 +60,40 @@ template<typename Lhs, typename Rhs,
> >
struct triangular_solver_selector; struct triangular_solver_selector;
// forward and backward substitution, row-major, rhs is a vector template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
template<typename Lhs, typename Rhs, int Mode> struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrder,1>
struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,RowMajor,1>
{ {
typedef typename Lhs::Scalar LhsScalar; typedef typename Lhs::Scalar LhsScalar;
typedef typename Rhs::Scalar RhsScalar; typedef typename Rhs::Scalar RhsScalar;
typedef blas_traits<Lhs> LhsProductTraits; typedef blas_traits<Lhs> LhsProductTraits;
typedef typename LhsProductTraits::ExtractType ActualLhsType; typedef typename LhsProductTraits::ExtractType ActualLhsType;
typedef typename Lhs::Index Index; typedef Map<Matrix<RhsScalar,Dynamic,1>, Aligned> MappedRhs;
enum { static void run(const Lhs& lhs, Rhs& rhs)
IsLower = ((Mode&Lower)==Lower)
};
static void run(const Lhs& lhs, Rhs& other)
{ {
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
ActualLhsType actualLhs = LhsProductTraits::extract(lhs); ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
const Index size = lhs.cols(); // FIXME find a way to allow an inner stride if packet_traits<Scalar>::size==1
for(Index pi=IsLower ? 0 : size;
IsLower ? pi<size : pi>0;
IsLower ? pi+=PanelWidth : pi-=PanelWidth)
{
Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth);
Index r = IsLower ? pi : size - pi; // remaining size bool useRhsDirectly = Rhs::InnerStrideAtCompileTime==1 || rhs.innerStride()==1;
if (r > 0) RhsScalar* actualRhs;
if(useRhsDirectly)
{ {
// let's directly call the low level product function because: actualRhs = &rhs.coeffRef(0);
// 1 - it is faster to compile }
// 2 - it is slighlty faster at runtime else
Index startRow = IsLower ? pi : pi-actualPanelWidth; {
Index startCol = IsLower ? 0 : pi; actualRhs = ei_aligned_stack_new(RhsScalar,rhs.size());
MappedRhs(actualRhs,rhs.size()) = rhs;
general_matrix_vector_product<Index,LhsScalar,RowMajor,LhsProductTraits::NeedToConjugate,RhsScalar,false>::run(
actualPanelWidth, r,
&(actualLhs.const_cast_derived().coeffRef(startRow,startCol)), actualLhs.outerStride(),
&(other.coeffRef(startCol)), other.innerStride(),
&other.coeffRef(startRow), other.innerStride(),
RhsScalar(-1));
} }
for(Index k=0; k<actualPanelWidth; ++k)
triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Mode, LhsProductTraits::NeedToConjugate, StorageOrder>
::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs);
if(!useRhsDirectly)
{ {
Index i = IsLower ? pi+k : pi-k-1; rhs = MappedRhs(actualRhs, rhs.size());
Index s = IsLower ? pi : i+1; ei_aligned_stack_delete(RhsScalar, actualRhs, rhs.size());
if (k>0)
other.coeffRef(i) -= (lhs.row(i).segment(s,k).transpose().cwiseProduct(other.segment(s,k))).sum();
if(!(Mode & UnitDiag))
other.coeffRef(i) /= lhs.coeff(i,i);
}
}
}
};
// forward and backward substitution, column-major, rhs is a vector
template<typename Lhs, typename Rhs, int Mode>
struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,ColMajor,1>
{
typedef typename Lhs::Scalar LhsScalar;
typedef typename Rhs::Scalar RhsScalar;
typedef blas_traits<Lhs> LhsProductTraits;
typedef typename LhsProductTraits::ExtractType ActualLhsType;
typedef typename Lhs::Index Index;
enum {
IsLower = ((Mode&Lower)==Lower)
};
static void run(const Lhs& lhs, Rhs& other)
{
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
const Index size = lhs.cols();
for(Index pi=IsLower ? 0 : size;
IsLower ? pi<size : pi>0;
IsLower ? pi+=PanelWidth : pi-=PanelWidth)
{
Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth);
Index startBlock = IsLower ? pi : pi-actualPanelWidth;
Index endBlock = IsLower ? pi + actualPanelWidth : 0;
for(Index k=0; k<actualPanelWidth; ++k)
{
Index i = IsLower ? pi+k : pi-k-1;
if(!(Mode & UnitDiag))
other.coeffRef(i) /= lhs.coeff(i,i);
Index r = actualPanelWidth - k - 1; // remaining size
Index s = IsLower ? i+1 : i-r;
if (r>0)
other.segment(s,r) -= other.coeffRef(i) * Block<Lhs,Dynamic,1>(lhs, s, i, r, 1);
}
Index r = IsLower ? size - endBlock : startBlock; // remaining size
if (r > 0)
{
// let's directly call the low level product function because:
// 1 - it is faster to compile
// 2 - it is slighlty faster at runtime
general_matrix_vector_product<Index,LhsScalar,ColMajor,LhsProductTraits::NeedToConjugate,RhsScalar,false>::run(
r, actualPanelWidth,
&(actualLhs.const_cast_derived().coeffRef(endBlock,startBlock)), actualLhs.outerStride(),
&other.coeff(startBlock), other.innerStride(),
&(other.coeffRef(endBlock, 0)), other.innerStride(), RhsScalar(-1));
}
} }
} }
}; };
@ -172,8 +110,6 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,Unrolling,StorageOrder
} }
}; };
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder>
struct triangular_solve_matrix;
// the rhs is a matrix // the rhs is a matrix
template<typename Lhs, typename Rhs, int Side, int Mode, int StorageOrder> template<typename Lhs, typename Rhs, int Side, int Mode, int StorageOrder>

View File

@ -0,0 +1,138 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// Eigen is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 3 of the License, or (at your option) any later version.
//
// Alternatively, you can redistribute it and/or
// modify it under the terms of the GNU General Public License as
// published by the Free Software Foundation; either version 2 of
// the License, or (at your option) any later version.
//
// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License and a copy of the GNU General Public License along with
// Eigen. If not, see <http://www.gnu.org/licenses/>.
#ifndef EIGEN_TRIANGULAR_SOLVER_VECTOR_H
#define EIGEN_TRIANGULAR_SOLVER_VECTOR_H
namespace internal {
// forward and backward substitution, row-major, rhs is a vector
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, RowMajor>
{
enum {
IsLower = ((Mode&Lower)==Lower)
};
static void run(int size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
{
typedef Map<Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));
typename internal::conditional<
Conjugate,
const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
const LhsMap&>
::type cjLhs(lhs);
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
for(Index pi=IsLower ? 0 : size;
IsLower ? pi<size : pi>0;
IsLower ? pi+=PanelWidth : pi-=PanelWidth)
{
Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth);
Index r = IsLower ? pi : size - pi; // remaining size
if (r > 0)
{
// let's directly call the low level product function because:
// 1 - it is faster to compile
// 2 - it is slighlty faster at runtime
Index startRow = IsLower ? pi : pi-actualPanelWidth;
Index startCol = IsLower ? 0 : pi;
general_matrix_vector_product<Index,LhsScalar,RowMajor,Conjugate,RhsScalar,false>::run(
actualPanelWidth, r,
&(lhs.coeff(startRow,startCol)), lhsStride,
rhs + startCol, 1,
rhs + startRow, 1,
RhsScalar(-1));
}
for(Index k=0; k<actualPanelWidth; ++k)
{
Index i = IsLower ? pi+k : pi-k-1;
Index s = IsLower ? pi : i+1;
if (k>0)
rhs[i] -= (cjLhs.row(i).segment(s,k).transpose().cwiseProduct(Map<Matrix<RhsScalar,Dynamic,1> >(rhs+s,k))).sum();
if(!(Mode & UnitDiag))
rhs[i] /= lhs(i,i);
}
}
}
};
// forward and backward substitution, column-major, rhs is a vector
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, ColMajor>
{
enum {
IsLower = ((Mode&Lower)==Lower)
};
static void run(int size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
{
typedef Map<Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));
typename internal::conditional<Conjugate,
const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
const LhsMap&
>::type cjLhs(lhs);
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
for(Index pi=IsLower ? 0 : size;
IsLower ? pi<size : pi>0;
IsLower ? pi+=PanelWidth : pi-=PanelWidth)
{
Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth);
Index startBlock = IsLower ? pi : pi-actualPanelWidth;
Index endBlock = IsLower ? pi + actualPanelWidth : 0;
for(Index k=0; k<actualPanelWidth; ++k)
{
Index i = IsLower ? pi+k : pi-k-1;
if(!(Mode & UnitDiag))
rhs[i] /= cjLhs.coeff(i,i);
Index r = actualPanelWidth - k - 1; // remaining size
Index s = IsLower ? i+1 : i-r;
if (r>0)
Map<Matrix<RhsScalar,Dynamic,1> >(rhs+s,r) -= rhs[i] * cjLhs.col(i).segment(s,r);
}
Index r = IsLower ? size - endBlock : startBlock; // remaining size
if (r > 0)
{
// let's directly call the low level product function because:
// 1 - it is faster to compile
// 2 - it is slighlty faster at runtime
general_matrix_vector_product<Index,LhsScalar,ColMajor,Conjugate,RhsScalar,false>::run(
r, actualPanelWidth,
&(lhs.coeff(endBlock,startBlock)), lhsStride,
rhs+startBlock, 1,
rhs+endBlock, 1, RhsScalar(-1));
}
}
}
};
} // end namespace internal
#endif // EIGEN_TRIANGULAR_SOLVER_VECTOR_H

View File

@ -73,6 +73,10 @@ template<typename Scalar,int Size, int Cols> void trsolve(int size=Size,int cols
VERIFY_TRSM_ONTHERIGHT(rmLhs .template triangularView<Lower>(), cmRhs); VERIFY_TRSM_ONTHERIGHT(rmLhs .template triangularView<Lower>(), cmRhs);
VERIFY_TRSM_ONTHERIGHT(rmLhs.conjugate().template triangularView<UnitUpper>(), rmRhs); VERIFY_TRSM_ONTHERIGHT(rmLhs.conjugate().template triangularView<UnitUpper>(), rmRhs);
int c = internal::random<int>(0,cols-1);
VERIFY_TRSM(rmLhs.template triangularView<Lower>(), rmRhs.col(c));
VERIFY_TRSM(cmLhs.template triangularView<Lower>(), rmRhs.col(c));
} }
void test_product_trsolve() void test_product_trsolve()
@ -86,6 +90,7 @@ void test_product_trsolve()
CALL_SUBTEST_4((trsolve<std::complex<double>,Dynamic,Dynamic>(internal::random<int>(1,200),internal::random<int>(1,200)))); CALL_SUBTEST_4((trsolve<std::complex<double>,Dynamic,Dynamic>(internal::random<int>(1,200),internal::random<int>(1,200))));
// vectors // vectors
CALL_SUBTEST_1((trsolve<float,Dynamic,1>(internal::random<int>(1,320))));
CALL_SUBTEST_5((trsolve<std::complex<double>,Dynamic,1>(internal::random<int>(1,320)))); CALL_SUBTEST_5((trsolve<std::complex<double>,Dynamic,1>(internal::random<int>(1,320))));
CALL_SUBTEST_6((trsolve<float,1,1>())); CALL_SUBTEST_6((trsolve<float,1,1>()));
CALL_SUBTEST_7((trsolve<float,1,2>())); CALL_SUBTEST_7((trsolve<float,1,2>()));