mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-09 07:00:27 +08:00
trsv: add support for inner-stride!=1, reduce code instanciation, move implementation to a new products/XX.h file
This commit is contained in:
parent
fe1353080e
commit
0e6c1170ab
@ -315,6 +315,7 @@ using std::size_t;
|
|||||||
#include "src/Core/products/TriangularMatrixVector.h"
|
#include "src/Core/products/TriangularMatrixVector.h"
|
||||||
#include "src/Core/products/TriangularMatrixMatrix.h"
|
#include "src/Core/products/TriangularMatrixMatrix.h"
|
||||||
#include "src/Core/products/TriangularSolverMatrix.h"
|
#include "src/Core/products/TriangularSolverMatrix.h"
|
||||||
|
#include "src/Core/products/TriangularSolverVector.h"
|
||||||
#include "src/Core/BandMatrix.h"
|
#include "src/Core/BandMatrix.h"
|
||||||
|
|
||||||
#include "src/Core/BooleanRedux.h"
|
#include "src/Core/BooleanRedux.h"
|
||||||
|
@ -27,6 +27,15 @@
|
|||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
|
// Forward declarations:
|
||||||
|
// The following two routines are implemented in the products/TriangularSolver*.h files
|
||||||
|
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder>
|
||||||
|
struct triangular_solve_vector;
|
||||||
|
|
||||||
|
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder>
|
||||||
|
struct triangular_solve_matrix;
|
||||||
|
|
||||||
|
// small helper struct extracting some traits on the underlying solver operation
|
||||||
template<typename Lhs, typename Rhs, int Side>
|
template<typename Lhs, typename Rhs, int Side>
|
||||||
class trsolve_traits
|
class trsolve_traits
|
||||||
{
|
{
|
||||||
@ -51,111 +60,40 @@ template<typename Lhs, typename Rhs,
|
|||||||
>
|
>
|
||||||
struct triangular_solver_selector;
|
struct triangular_solver_selector;
|
||||||
|
|
||||||
// forward and backward substitution, row-major, rhs is a vector
|
template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
|
||||||
template<typename Lhs, typename Rhs, int Mode>
|
struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrder,1>
|
||||||
struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,RowMajor,1>
|
|
||||||
{
|
{
|
||||||
typedef typename Lhs::Scalar LhsScalar;
|
typedef typename Lhs::Scalar LhsScalar;
|
||||||
typedef typename Rhs::Scalar RhsScalar;
|
typedef typename Rhs::Scalar RhsScalar;
|
||||||
typedef blas_traits<Lhs> LhsProductTraits;
|
typedef blas_traits<Lhs> LhsProductTraits;
|
||||||
typedef typename LhsProductTraits::ExtractType ActualLhsType;
|
typedef typename LhsProductTraits::ExtractType ActualLhsType;
|
||||||
typedef typename Lhs::Index Index;
|
typedef Map<Matrix<RhsScalar,Dynamic,1>, Aligned> MappedRhs;
|
||||||
enum {
|
static void run(const Lhs& lhs, Rhs& rhs)
|
||||||
IsLower = ((Mode&Lower)==Lower)
|
|
||||||
};
|
|
||||||
static void run(const Lhs& lhs, Rhs& other)
|
|
||||||
{
|
{
|
||||||
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
|
|
||||||
ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
|
ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
|
||||||
|
|
||||||
const Index size = lhs.cols();
|
// FIXME find a way to allow an inner stride if packet_traits<Scalar>::size==1
|
||||||
for(Index pi=IsLower ? 0 : size;
|
|
||||||
IsLower ? pi<size : pi>0;
|
|
||||||
IsLower ? pi+=PanelWidth : pi-=PanelWidth)
|
|
||||||
{
|
|
||||||
Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth);
|
|
||||||
|
|
||||||
Index r = IsLower ? pi : size - pi; // remaining size
|
bool useRhsDirectly = Rhs::InnerStrideAtCompileTime==1 || rhs.innerStride()==1;
|
||||||
if (r > 0)
|
RhsScalar* actualRhs;
|
||||||
|
if(useRhsDirectly)
|
||||||
{
|
{
|
||||||
// let's directly call the low level product function because:
|
actualRhs = &rhs.coeffRef(0);
|
||||||
// 1 - it is faster to compile
|
}
|
||||||
// 2 - it is slighlty faster at runtime
|
else
|
||||||
Index startRow = IsLower ? pi : pi-actualPanelWidth;
|
{
|
||||||
Index startCol = IsLower ? 0 : pi;
|
actualRhs = ei_aligned_stack_new(RhsScalar,rhs.size());
|
||||||
|
MappedRhs(actualRhs,rhs.size()) = rhs;
|
||||||
general_matrix_vector_product<Index,LhsScalar,RowMajor,LhsProductTraits::NeedToConjugate,RhsScalar,false>::run(
|
|
||||||
actualPanelWidth, r,
|
|
||||||
&(actualLhs.const_cast_derived().coeffRef(startRow,startCol)), actualLhs.outerStride(),
|
|
||||||
&(other.coeffRef(startCol)), other.innerStride(),
|
|
||||||
&other.coeffRef(startRow), other.innerStride(),
|
|
||||||
RhsScalar(-1));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for(Index k=0; k<actualPanelWidth; ++k)
|
|
||||||
|
triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Mode, LhsProductTraits::NeedToConjugate, StorageOrder>
|
||||||
|
::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs);
|
||||||
|
|
||||||
|
if(!useRhsDirectly)
|
||||||
{
|
{
|
||||||
Index i = IsLower ? pi+k : pi-k-1;
|
rhs = MappedRhs(actualRhs, rhs.size());
|
||||||
Index s = IsLower ? pi : i+1;
|
ei_aligned_stack_delete(RhsScalar, actualRhs, rhs.size());
|
||||||
if (k>0)
|
|
||||||
other.coeffRef(i) -= (lhs.row(i).segment(s,k).transpose().cwiseProduct(other.segment(s,k))).sum();
|
|
||||||
|
|
||||||
if(!(Mode & UnitDiag))
|
|
||||||
other.coeffRef(i) /= lhs.coeff(i,i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// forward and backward substitution, column-major, rhs is a vector
|
|
||||||
template<typename Lhs, typename Rhs, int Mode>
|
|
||||||
struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,ColMajor,1>
|
|
||||||
{
|
|
||||||
typedef typename Lhs::Scalar LhsScalar;
|
|
||||||
typedef typename Rhs::Scalar RhsScalar;
|
|
||||||
typedef blas_traits<Lhs> LhsProductTraits;
|
|
||||||
typedef typename LhsProductTraits::ExtractType ActualLhsType;
|
|
||||||
typedef typename Lhs::Index Index;
|
|
||||||
enum {
|
|
||||||
IsLower = ((Mode&Lower)==Lower)
|
|
||||||
};
|
|
||||||
|
|
||||||
static void run(const Lhs& lhs, Rhs& other)
|
|
||||||
{
|
|
||||||
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
|
|
||||||
ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
|
|
||||||
|
|
||||||
const Index size = lhs.cols();
|
|
||||||
for(Index pi=IsLower ? 0 : size;
|
|
||||||
IsLower ? pi<size : pi>0;
|
|
||||||
IsLower ? pi+=PanelWidth : pi-=PanelWidth)
|
|
||||||
{
|
|
||||||
Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth);
|
|
||||||
Index startBlock = IsLower ? pi : pi-actualPanelWidth;
|
|
||||||
Index endBlock = IsLower ? pi + actualPanelWidth : 0;
|
|
||||||
|
|
||||||
for(Index k=0; k<actualPanelWidth; ++k)
|
|
||||||
{
|
|
||||||
Index i = IsLower ? pi+k : pi-k-1;
|
|
||||||
if(!(Mode & UnitDiag))
|
|
||||||
other.coeffRef(i) /= lhs.coeff(i,i);
|
|
||||||
|
|
||||||
Index r = actualPanelWidth - k - 1; // remaining size
|
|
||||||
Index s = IsLower ? i+1 : i-r;
|
|
||||||
if (r>0)
|
|
||||||
other.segment(s,r) -= other.coeffRef(i) * Block<Lhs,Dynamic,1>(lhs, s, i, r, 1);
|
|
||||||
}
|
|
||||||
Index r = IsLower ? size - endBlock : startBlock; // remaining size
|
|
||||||
if (r > 0)
|
|
||||||
{
|
|
||||||
// let's directly call the low level product function because:
|
|
||||||
// 1 - it is faster to compile
|
|
||||||
// 2 - it is slighlty faster at runtime
|
|
||||||
general_matrix_vector_product<Index,LhsScalar,ColMajor,LhsProductTraits::NeedToConjugate,RhsScalar,false>::run(
|
|
||||||
r, actualPanelWidth,
|
|
||||||
&(actualLhs.const_cast_derived().coeffRef(endBlock,startBlock)), actualLhs.outerStride(),
|
|
||||||
&other.coeff(startBlock), other.innerStride(),
|
|
||||||
&(other.coeffRef(endBlock, 0)), other.innerStride(), RhsScalar(-1));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -172,8 +110,6 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,Unrolling,StorageOrder
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder>
|
|
||||||
struct triangular_solve_matrix;
|
|
||||||
|
|
||||||
// the rhs is a matrix
|
// the rhs is a matrix
|
||||||
template<typename Lhs, typename Rhs, int Side, int Mode, int StorageOrder>
|
template<typename Lhs, typename Rhs, int Side, int Mode, int StorageOrder>
|
||||||
|
138
Eigen/src/Core/products/TriangularSolverVector.h
Normal file
138
Eigen/src/Core/products/TriangularSolverVector.h
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
// This file is part of Eigen, a lightweight C++ template library
|
||||||
|
// for linear algebra.
|
||||||
|
//
|
||||||
|
// Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
|
||||||
|
//
|
||||||
|
// Eigen is free software; you can redistribute it and/or
|
||||||
|
// modify it under the terms of the GNU Lesser General Public
|
||||||
|
// License as published by the Free Software Foundation; either
|
||||||
|
// version 3 of the License, or (at your option) any later version.
|
||||||
|
//
|
||||||
|
// Alternatively, you can redistribute it and/or
|
||||||
|
// modify it under the terms of the GNU General Public License as
|
||||||
|
// published by the Free Software Foundation; either version 2 of
|
||||||
|
// the License, or (at your option) any later version.
|
||||||
|
//
|
||||||
|
// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
|
||||||
|
// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||||
|
// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU Lesser General Public
|
||||||
|
// License and a copy of the GNU General Public License along with
|
||||||
|
// Eigen. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
#ifndef EIGEN_TRIANGULAR_SOLVER_VECTOR_H
|
||||||
|
#define EIGEN_TRIANGULAR_SOLVER_VECTOR_H
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
// forward and backward substitution, row-major, rhs is a vector
|
||||||
|
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
|
||||||
|
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, RowMajor>
|
||||||
|
{
|
||||||
|
enum {
|
||||||
|
IsLower = ((Mode&Lower)==Lower)
|
||||||
|
};
|
||||||
|
static void run(int size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
|
||||||
|
{
|
||||||
|
typedef Map<Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
|
||||||
|
const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));
|
||||||
|
typename internal::conditional<
|
||||||
|
Conjugate,
|
||||||
|
const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
|
||||||
|
const LhsMap&>
|
||||||
|
::type cjLhs(lhs);
|
||||||
|
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
|
||||||
|
for(Index pi=IsLower ? 0 : size;
|
||||||
|
IsLower ? pi<size : pi>0;
|
||||||
|
IsLower ? pi+=PanelWidth : pi-=PanelWidth)
|
||||||
|
{
|
||||||
|
Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth);
|
||||||
|
|
||||||
|
Index r = IsLower ? pi : size - pi; // remaining size
|
||||||
|
if (r > 0)
|
||||||
|
{
|
||||||
|
// let's directly call the low level product function because:
|
||||||
|
// 1 - it is faster to compile
|
||||||
|
// 2 - it is slighlty faster at runtime
|
||||||
|
Index startRow = IsLower ? pi : pi-actualPanelWidth;
|
||||||
|
Index startCol = IsLower ? 0 : pi;
|
||||||
|
|
||||||
|
general_matrix_vector_product<Index,LhsScalar,RowMajor,Conjugate,RhsScalar,false>::run(
|
||||||
|
actualPanelWidth, r,
|
||||||
|
&(lhs.coeff(startRow,startCol)), lhsStride,
|
||||||
|
rhs + startCol, 1,
|
||||||
|
rhs + startRow, 1,
|
||||||
|
RhsScalar(-1));
|
||||||
|
}
|
||||||
|
|
||||||
|
for(Index k=0; k<actualPanelWidth; ++k)
|
||||||
|
{
|
||||||
|
Index i = IsLower ? pi+k : pi-k-1;
|
||||||
|
Index s = IsLower ? pi : i+1;
|
||||||
|
if (k>0)
|
||||||
|
rhs[i] -= (cjLhs.row(i).segment(s,k).transpose().cwiseProduct(Map<Matrix<RhsScalar,Dynamic,1> >(rhs+s,k))).sum();
|
||||||
|
|
||||||
|
if(!(Mode & UnitDiag))
|
||||||
|
rhs[i] /= lhs(i,i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// forward and backward substitution, column-major, rhs is a vector
|
||||||
|
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
|
||||||
|
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, ColMajor>
|
||||||
|
{
|
||||||
|
enum {
|
||||||
|
IsLower = ((Mode&Lower)==Lower)
|
||||||
|
};
|
||||||
|
static void run(int size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
|
||||||
|
{
|
||||||
|
typedef Map<Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
|
||||||
|
const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));
|
||||||
|
typename internal::conditional<Conjugate,
|
||||||
|
const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
|
||||||
|
const LhsMap&
|
||||||
|
>::type cjLhs(lhs);
|
||||||
|
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
|
||||||
|
|
||||||
|
for(Index pi=IsLower ? 0 : size;
|
||||||
|
IsLower ? pi<size : pi>0;
|
||||||
|
IsLower ? pi+=PanelWidth : pi-=PanelWidth)
|
||||||
|
{
|
||||||
|
Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth);
|
||||||
|
Index startBlock = IsLower ? pi : pi-actualPanelWidth;
|
||||||
|
Index endBlock = IsLower ? pi + actualPanelWidth : 0;
|
||||||
|
|
||||||
|
for(Index k=0; k<actualPanelWidth; ++k)
|
||||||
|
{
|
||||||
|
Index i = IsLower ? pi+k : pi-k-1;
|
||||||
|
if(!(Mode & UnitDiag))
|
||||||
|
rhs[i] /= cjLhs.coeff(i,i);
|
||||||
|
|
||||||
|
Index r = actualPanelWidth - k - 1; // remaining size
|
||||||
|
Index s = IsLower ? i+1 : i-r;
|
||||||
|
if (r>0)
|
||||||
|
Map<Matrix<RhsScalar,Dynamic,1> >(rhs+s,r) -= rhs[i] * cjLhs.col(i).segment(s,r);
|
||||||
|
}
|
||||||
|
Index r = IsLower ? size - endBlock : startBlock; // remaining size
|
||||||
|
if (r > 0)
|
||||||
|
{
|
||||||
|
// let's directly call the low level product function because:
|
||||||
|
// 1 - it is faster to compile
|
||||||
|
// 2 - it is slighlty faster at runtime
|
||||||
|
general_matrix_vector_product<Index,LhsScalar,ColMajor,Conjugate,RhsScalar,false>::run(
|
||||||
|
r, actualPanelWidth,
|
||||||
|
&(lhs.coeff(endBlock,startBlock)), lhsStride,
|
||||||
|
rhs+startBlock, 1,
|
||||||
|
rhs+endBlock, 1, RhsScalar(-1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace internal
|
||||||
|
|
||||||
|
#endif // EIGEN_TRIANGULAR_SOLVER_VECTOR_H
|
@ -73,6 +73,10 @@ template<typename Scalar,int Size, int Cols> void trsolve(int size=Size,int cols
|
|||||||
|
|
||||||
VERIFY_TRSM_ONTHERIGHT(rmLhs .template triangularView<Lower>(), cmRhs);
|
VERIFY_TRSM_ONTHERIGHT(rmLhs .template triangularView<Lower>(), cmRhs);
|
||||||
VERIFY_TRSM_ONTHERIGHT(rmLhs.conjugate().template triangularView<UnitUpper>(), rmRhs);
|
VERIFY_TRSM_ONTHERIGHT(rmLhs.conjugate().template triangularView<UnitUpper>(), rmRhs);
|
||||||
|
|
||||||
|
int c = internal::random<int>(0,cols-1);
|
||||||
|
VERIFY_TRSM(rmLhs.template triangularView<Lower>(), rmRhs.col(c));
|
||||||
|
VERIFY_TRSM(cmLhs.template triangularView<Lower>(), rmRhs.col(c));
|
||||||
}
|
}
|
||||||
|
|
||||||
void test_product_trsolve()
|
void test_product_trsolve()
|
||||||
@ -86,6 +90,7 @@ void test_product_trsolve()
|
|||||||
CALL_SUBTEST_4((trsolve<std::complex<double>,Dynamic,Dynamic>(internal::random<int>(1,200),internal::random<int>(1,200))));
|
CALL_SUBTEST_4((trsolve<std::complex<double>,Dynamic,Dynamic>(internal::random<int>(1,200),internal::random<int>(1,200))));
|
||||||
|
|
||||||
// vectors
|
// vectors
|
||||||
|
CALL_SUBTEST_1((trsolve<float,Dynamic,1>(internal::random<int>(1,320))));
|
||||||
CALL_SUBTEST_5((trsolve<std::complex<double>,Dynamic,1>(internal::random<int>(1,320))));
|
CALL_SUBTEST_5((trsolve<std::complex<double>,Dynamic,1>(internal::random<int>(1,320))));
|
||||||
CALL_SUBTEST_6((trsolve<float,1,1>()));
|
CALL_SUBTEST_6((trsolve<float,1,1>()));
|
||||||
CALL_SUBTEST_7((trsolve<float,1,2>()));
|
CALL_SUBTEST_7((trsolve<float,1,2>()));
|
||||||
|
Loading…
Reference in New Issue
Block a user