mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-15 07:10:37 +08:00
add explicit "on the right" triangular solving,
=> no temporary when the rhs/unknowns is row major
This commit is contained in:
parent
62d9b9b7b5
commit
ff20a2ba94
@ -27,6 +27,7 @@
|
||||
|
||||
template<typename Lhs, typename Rhs,
|
||||
int Mode, // can be Upper/Lower | UnitDiag
|
||||
int Side, // can be OnTheLeft/OnTheRight
|
||||
int Unrolling = Rhs::IsVectorAtCompileTime && Rhs::SizeAtCompileTime <= 8 // FIXME
|
||||
? CompleteUnrolling : NoUnrolling,
|
||||
int StorageOrder = int(Lhs::Flags) & RowMajorBit,
|
||||
@ -36,7 +37,7 @@ struct ei_triangular_solver_selector;
|
||||
|
||||
// forward and backward substitution, row-major, rhs is a vector
|
||||
template<typename Lhs, typename Rhs, int Mode>
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor,1>
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,RowMajor,1>
|
||||
{
|
||||
typedef typename Rhs::Scalar Scalar;
|
||||
typedef ei_blas_traits<Lhs> LhsProductTraits;
|
||||
@ -89,7 +90,7 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor,1>
|
||||
|
||||
// forward and backward substitution, column-major, rhs is a vector
|
||||
template<typename Lhs, typename Rhs, int Mode>
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor,1>
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,ColMajor,1>
|
||||
{
|
||||
typedef typename Rhs::Scalar Scalar;
|
||||
typedef typename ei_packet_traits<Scalar>::type Packet;
|
||||
@ -142,12 +143,12 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor,1>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Scalar, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, int Mode>
|
||||
template <typename Scalar, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder>
|
||||
struct ei_triangular_solve_matrix;
|
||||
|
||||
// the rhs is a matrix
|
||||
template<typename Lhs, typename Rhs, int Mode, int StorageOrder, int RhsCols>
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,StorageOrder,RhsCols>
|
||||
template<typename Lhs, typename Rhs, int Side, int Mode, int StorageOrder, int RhsCols>
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,StorageOrder,RhsCols>
|
||||
{
|
||||
typedef typename Rhs::Scalar Scalar;
|
||||
typedef ei_blas_traits<Lhs> LhsProductTraits;
|
||||
@ -155,7 +156,8 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,StorageOrder,RhsCo
|
||||
static void run(const Lhs& lhs, Rhs& rhs)
|
||||
{
|
||||
const ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
|
||||
ei_triangular_solve_matrix<Scalar,StorageOrder,LhsProductTraits::NeedToConjugate,Rhs::Flags&RowMajorBit,Mode>
|
||||
ei_triangular_solve_matrix<Scalar,Side,Mode,LhsProductTraits::NeedToConjugate,StorageOrder,
|
||||
Rhs::Flags&RowMajorBit>
|
||||
::run(lhs.rows(), rhs.cols(), &actualLhs.coeff(0,0), actualLhs.stride(), &rhs.coeffRef(0,0), rhs.stride());
|
||||
}
|
||||
};
|
||||
@ -194,7 +196,7 @@ struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> {
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,CompleteUnrolling,StorageOrder,1> {
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,CompleteUnrolling,StorageOrder,1> {
|
||||
static void run(const Lhs& lhs, Rhs& rhs)
|
||||
{ ei_triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); }
|
||||
};
|
||||
@ -213,7 +215,7 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,CompleteUnrolling,StorageOrder
|
||||
* See TriangularView::solve() for the details.
|
||||
*/
|
||||
template<typename MatrixType, unsigned int Mode>
|
||||
template<typename RhsDerived>
|
||||
template<int Side, typename RhsDerived>
|
||||
void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>& _rhs) const
|
||||
{
|
||||
RhsDerived& rhs = _rhs.const_cast_derived();
|
||||
@ -228,7 +230,7 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
|
||||
RhsCopy rhsCopy(rhs);
|
||||
|
||||
ei_triangular_solver_selector<MatrixType, typename ei_unref<RhsCopy>::type,
|
||||
Mode>::run(_expression(), rhsCopy);
|
||||
Side, Mode>::run(_expression(), rhsCopy);
|
||||
|
||||
if (copy)
|
||||
rhs = rhsCopy;
|
||||
@ -266,12 +268,12 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
|
||||
* \sa TriangularView::solveInPlace()
|
||||
*/
|
||||
template<typename Derived, unsigned int Mode>
|
||||
template<typename RhsDerived>
|
||||
template<int Side, typename RhsDerived>
|
||||
typename ei_plain_matrix_type_column_major<RhsDerived>::type
|
||||
TriangularView<Derived,Mode>::solve(const MatrixBase<RhsDerived>& rhs) const
|
||||
{
|
||||
typename ei_plain_matrix_type_column_major<RhsDerived>::type res(rhs);
|
||||
solveInPlace(res);
|
||||
solveInPlace<Side>(res);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
@ -269,13 +269,23 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
|
||||
(lhs.derived(),rhs.m_matrix);
|
||||
}
|
||||
|
||||
template<typename OtherDerived>
|
||||
|
||||
template<int Side, typename OtherDerived>
|
||||
typename ei_plain_matrix_type_column_major<OtherDerived>::type
|
||||
solve(const MatrixBase<OtherDerived>& other) const;
|
||||
|
||||
template<typename OtherDerived>
|
||||
template<int Side, typename OtherDerived>
|
||||
void solveInPlace(const MatrixBase<OtherDerived>& other) const;
|
||||
|
||||
template<typename OtherDerived>
|
||||
typename ei_plain_matrix_type_column_major<OtherDerived>::type
|
||||
solve(const MatrixBase<OtherDerived>& other) const
|
||||
{ return solve<OnTheLeft>(other); }
|
||||
|
||||
template<typename OtherDerived>
|
||||
void solveInPlace(const MatrixBase<OtherDerived>& other) const
|
||||
{ return solveInPlace<OnTheLeft>(other); }
|
||||
|
||||
template<typename OtherDerived>
|
||||
void swap(const TriangularBase<OtherDerived>& other)
|
||||
{
|
||||
|
@ -347,23 +347,31 @@ struct ei_gebp_kernel
|
||||
//
|
||||
// 32 33 34 35 ...
|
||||
// 36 37 38 39 ...
|
||||
template<typename Scalar, int mr, int StorageOrder, bool Conjugate>
|
||||
template<typename Scalar, int mr, int StorageOrder, bool Conjugate, bool PanelMode>
|
||||
struct ei_gemm_pack_lhs
|
||||
{
|
||||
void operator()(Scalar* blockA, const EIGEN_RESTRICT Scalar* _lhs, int lhsStride, int depth, int rows)
|
||||
void operator()(Scalar* blockA, const EIGEN_RESTRICT Scalar* _lhs, int lhsStride, int depth, int rows,
|
||||
int stride=0, int offset=0)
|
||||
{
|
||||
ei_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
|
||||
ei_conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
|
||||
ei_const_blas_data_mapper<Scalar, StorageOrder> lhs(_lhs,lhsStride);
|
||||
int count = 0;
|
||||
const int peeled_mc = (rows/mr)*mr;
|
||||
for(int i=0; i<peeled_mc; i+=mr)
|
||||
{
|
||||
if(PanelMode) count += mr * offset;
|
||||
for(int k=0; k<depth; k++)
|
||||
for(int w=0; w<mr; w++)
|
||||
blockA[count++] = cj(lhs(i+w, k));
|
||||
if(PanelMode) count += mr * (stride-offset-depth);
|
||||
}
|
||||
for(int i=peeled_mc; i<rows; i++)
|
||||
{
|
||||
if(PanelMode) count += offset;
|
||||
for(int k=0; k<depth; k++)
|
||||
blockA[count++] = cj(lhs(i, k));
|
||||
if(PanelMode) count += (stride-offset-depth);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -26,34 +26,42 @@
|
||||
#define EIGEN_TRIANGULAR_SOLVER_MATRIX_H
|
||||
|
||||
// if the rhs is row major, we forward to the column-major kernel with the side and the upper/lower mode swapped, so that no temporary colmajor matrix is needed
|
||||
template <typename Scalar, int LhsStorageOrder, bool ConjugateLhs, int Mode>
|
||||
struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,RowMajor,Mode>
|
||||
template <typename Scalar, int Side, int Mode, bool Conjugate, int TriStorageOrder>
|
||||
struct ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,RowMajor>
|
||||
{
|
||||
static EIGEN_DONT_INLINE void run(
|
||||
int size, int cols,
|
||||
const Scalar* lhs, int lhsStride,
|
||||
Scalar* _rhs, int rhsStride)
|
||||
const Scalar* tri, int triStride,
|
||||
Scalar* _other, int otherStride)
|
||||
{
|
||||
Map<Matrix<Scalar,Dynamic,Dynamic> > rhs(_rhs, rhsStride, cols);
|
||||
Matrix<Scalar,Dynamic,Dynamic> aux = rhs.block(0,0,size,cols);
|
||||
ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,Mode>
|
||||
::run(size, cols, lhs, lhsStride, aux.data(), aux.stride());
|
||||
rhs.block(0,0,size,cols) = aux;
|
||||
|
||||
ei_triangular_solve_matrix<
|
||||
Scalar, Side==OnTheLeft?OnTheRight:OnTheLeft,
|
||||
// Keep the unit-diagonal flag and swap upper <-> lower. The inner parentheses are required:
// ?: binds more loosely than |, so without them the whole OR expression becomes the condition
// and the UnitDiagBit is lost from the resulting mode.
(Mode&UnitDiagBit) | ((Mode&UpperTriangular) ? LowerTriangular : UpperTriangular),
|
||||
!Conjugate, TriStorageOrder, ColMajor>
|
||||
::run(size, cols, tri, triStride, _other, otherStride);
|
||||
|
||||
// Map<Matrix<Scalar,Dynamic,Dynamic> > other(_other, otherStride, cols);
|
||||
// Matrix<Scalar,Dynamic,Dynamic> aux = other.block(0,0,size,cols);
|
||||
// ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,ColMajor>
|
||||
// ::run(size, cols, tri, triStride, aux.data(), aux.stride());
|
||||
// other.block(0,0,size,cols) = aux;
|
||||
}
|
||||
};
|
||||
|
||||
/* Optimized triangular solver with multiple right hand side (_TRSM)
|
||||
*/
|
||||
template <typename Scalar, int LhsStorageOrder, bool ConjugateLhs, int Mode>
|
||||
struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,Mode>
|
||||
template <typename Scalar, int Mode, bool Conjugate, int TriStorageOrder>
|
||||
struct ei_triangular_solve_matrix<Scalar,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor>
|
||||
{
|
||||
static EIGEN_DONT_INLINE void run(
|
||||
int size, int cols,
|
||||
const Scalar* _lhs, int lhsStride,
|
||||
Scalar* _rhs, int rhsStride)
|
||||
int size, int otherSize,
|
||||
const Scalar* _tri, int triStride,
|
||||
Scalar* _other, int otherStride)
|
||||
{
|
||||
ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride);
|
||||
ei_blas_data_mapper <Scalar, ColMajor> rhs(_rhs,rhsStride);
|
||||
int cols = otherSize;
|
||||
ei_const_blas_data_mapper<Scalar, TriStorageOrder> tri(_tri,triStride);
|
||||
ei_blas_data_mapper<Scalar, ColMajor> other(_other,otherStride);
|
||||
|
||||
typedef ei_product_blocking_traits<Scalar> Blocking;
|
||||
enum {
|
||||
@ -67,9 +75,9 @@ struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,M
|
||||
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
|
||||
Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize);
|
||||
|
||||
ei_conj_if<ConjugateLhs> conj;
|
||||
ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,false> > gebp_kernel;
|
||||
ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder> pack_lhs;
|
||||
ei_conj_if<Conjugate> conj;
|
||||
ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<Conjugate,false> > gebp_kernel;
|
||||
ei_gemm_pack_lhs<Scalar,Blocking::mr,TriStorageOrder> pack_lhs;
|
||||
|
||||
for(int k2=IsLowerTriangular ? 0 : size;
|
||||
IsLowerTriangular ? k2<size : k2>0;
|
||||
@ -103,25 +111,25 @@ struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,M
|
||||
int s = IsLowerTriangular ? k2+k1 : i+1;
|
||||
int rs = actualPanelWidth - k - 1; // remaining size
|
||||
|
||||
Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(lhs(i,i));
|
||||
Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(tri(i,i));
|
||||
for (int j=0; j<cols; ++j)
|
||||
{
|
||||
if (LhsStorageOrder==RowMajor)
|
||||
if (TriStorageOrder==RowMajor)
|
||||
{
|
||||
Scalar b = 0;
|
||||
const Scalar* l = &lhs(i,s);
|
||||
Scalar* r = &rhs(s,j);
|
||||
const Scalar* l = &tri(i,s);
|
||||
Scalar* r = &other(s,j);
|
||||
for (int i3=0; i3<k; ++i3)
|
||||
b += conj(l[i3]) * r[i3];
|
||||
|
||||
rhs(i,j) = (rhs(i,j) - b)*a;
|
||||
other(i,j) = (other(i,j) - b)*a;
|
||||
}
|
||||
else
|
||||
{
|
||||
int s = IsLowerTriangular ? i+1 : i-rs;
|
||||
Scalar b = (rhs(i,j) *= a);
|
||||
Scalar* r = &rhs(s,j);
|
||||
const Scalar* l = &lhs(s,i);
|
||||
Scalar b = (other(i,j) *= a);
|
||||
Scalar* r = &other(s,j);
|
||||
const Scalar* l = &tri(s,i);
|
||||
for (int i3=0;i3<rs;++i3)
|
||||
r[i3] -= b * conj(l[i3]);
|
||||
}
|
||||
@ -132,18 +140,18 @@ struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,M
|
||||
int startBlock = IsLowerTriangular ? k2+k1 : k2-k1-actualPanelWidth;
|
||||
int blockBOffset = IsLowerTriangular ? k1 : lengthTarget;
|
||||
|
||||
// update the respective rows of B from rhs
|
||||
// update the respective rows of B from other
|
||||
ei_gemm_pack_rhs<Scalar, Blocking::nr, ColMajor, true>()
|
||||
(blockB, _rhs+startBlock, rhsStride, -1, actualPanelWidth, cols, actual_kc, blockBOffset);
|
||||
(blockB, _other+startBlock, otherStride, -1, actualPanelWidth, cols, actual_kc, blockBOffset);
|
||||
|
||||
// GEBP
|
||||
if (lengthTarget>0)
|
||||
{
|
||||
int startTarget = IsLowerTriangular ? k2+k1+actualPanelWidth : k2-actual_kc;
|
||||
|
||||
pack_lhs(blockA, &lhs(startTarget,startBlock), lhsStride, actualPanelWidth, lengthTarget);
|
||||
pack_lhs(blockA, &tri(startTarget,startBlock), triStride, actualPanelWidth, lengthTarget);
|
||||
|
||||
gebp_kernel(_rhs+startTarget, rhsStride, blockA, blockB, lengthTarget, actualPanelWidth, cols,
|
||||
gebp_kernel(_other+startTarget, otherStride, blockA, blockB, lengthTarget, actualPanelWidth, cols,
|
||||
actualPanelWidth, actual_kc, 0, blockBOffset*Blocking::PacketSize);
|
||||
}
|
||||
}
|
||||
@ -158,9 +166,9 @@ struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,M
|
||||
const int actual_mc = std::min(mc,end-i2);
|
||||
if (actual_mc>0)
|
||||
{
|
||||
pack_lhs(blockA, &lhs(i2, IsLowerTriangular ? k2 : k2-kc), lhsStride, actual_kc, actual_mc);
|
||||
pack_lhs(blockA, &tri(i2, IsLowerTriangular ? k2 : k2-kc), triStride, actual_kc, actual_mc);
|
||||
|
||||
gebp_kernel(_rhs+i2, rhsStride, blockA, blockB, actual_mc, actual_kc, cols);
|
||||
gebp_kernel(_other+i2, otherStride, blockA, blockB, actual_mc, actual_kc, cols);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -171,4 +179,141 @@ struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,M
|
||||
}
|
||||
};
|
||||
|
||||
/* Optimized triangular solver with multiple left hand sides and the triangular matrix on the right
|
||||
*/
|
||||
template <typename Scalar, int Mode, bool Conjugate, int TriStorageOrder>
|
||||
struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor>
|
||||
{
|
||||
static EIGEN_DONT_INLINE void run(
|
||||
int size, int otherSize,
|
||||
const Scalar* _tri, int triStride,
|
||||
Scalar* _other, int otherStride)
|
||||
{
|
||||
int rows = otherSize;
|
||||
// ei_const_blas_data_mapper<Scalar, TriStorageOrder> rhs(_tri,triStride);
|
||||
// ei_blas_data_mapper<Scalar, ColMajor> lhs(_other,otherStride);
|
||||
|
||||
Map<Matrix<Scalar,Dynamic,Dynamic,TriStorageOrder> > rhs(_tri,size,size);
|
||||
Map<Matrix<Scalar,Dynamic,Dynamic,ColMajor> > lhs(_other,rows,size);
|
||||
|
||||
typedef ei_product_blocking_traits<Scalar> Blocking;
|
||||
enum {
|
||||
RhsStorageOrder = TriStorageOrder,
|
||||
SmallPanelWidth = EIGEN_ENUM_MAX(Blocking::mr,Blocking::nr),
|
||||
IsLowerTriangular = (Mode&LowerTriangular) == LowerTriangular
|
||||
};
|
||||
|
||||
int kc = std::min<int>(/*Blocking::Max_kc/4*/32,size); // cache block size along the K direction
|
||||
int mc = std::min<int>(/*Blocking::Max_mc*/32,size); // cache block size along the M direction
|
||||
|
||||
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
|
||||
Scalar* blockB = ei_aligned_stack_new(Scalar, kc*size*Blocking::PacketSize);
|
||||
|
||||
ei_conj_if<Conjugate> conj;
|
||||
ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<false,Conjugate> > gebp_kernel;
|
||||
ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder> pack_rhs;
|
||||
ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder,true> pack_rhs_panel;
|
||||
ei_gemm_pack_lhs<Scalar, Blocking::mr, ColMajor, false, true> pack_lhs_panel;
|
||||
ei_gemm_pack_lhs<Scalar, Blocking::mr, ColMajor, false> pack_lhs;
|
||||
|
||||
for(int k2=IsLowerTriangular ? size : 0;
|
||||
IsLowerTriangular ? k2>0 : k2<size;
|
||||
IsLowerTriangular ? k2-=kc : k2+=kc)
|
||||
{
|
||||
const int actual_kc = std::min(IsLowerTriangular ? k2 : size-k2, kc);
|
||||
int actual_k2 = IsLowerTriangular ? k2-actual_kc : k2 ;
|
||||
|
||||
int startPanel = IsLowerTriangular ? 0 : k2+actual_kc;
|
||||
int rs = IsLowerTriangular ? actual_k2 : size - actual_k2 - actual_kc;
|
||||
Scalar* geb = blockB+actual_kc*actual_kc*Blocking::PacketSize;
|
||||
|
||||
if (rs>0) pack_rhs(geb, &rhs(actual_k2,startPanel), triStride, -1, actual_kc, rs);
|
||||
|
||||
// triangular packing (we only pack the panels off the diagonal,
|
||||
// neglecting the blocks overlapping the diagonal)
|
||||
{
|
||||
for (int j2=0; j2<actual_kc; j2+=SmallPanelWidth)
|
||||
{
|
||||
int actualPanelWidth = std::min<int>(actual_kc-j2, SmallPanelWidth);
|
||||
int actual_j2 = actual_k2 + j2;
|
||||
int panelOffset = IsLowerTriangular ? j2+actualPanelWidth : 0;
|
||||
int panelLength = IsLowerTriangular ? actual_kc-j2-actualPanelWidth : j2;
|
||||
|
||||
// std::cerr << "$ " << k2 << " " << j2 << " " << actual_j2 << " " << panelOffset << " " << panelLength << "\n";
|
||||
|
||||
if (panelLength>0)
|
||||
pack_rhs_panel(blockB+j2*actual_kc*Blocking::PacketSize,
|
||||
&rhs(actual_k2+panelOffset, actual_j2), triStride, -1,
|
||||
panelLength, actualPanelWidth,
|
||||
actual_kc, panelOffset);
|
||||
}
|
||||
}
|
||||
|
||||
for(int i2=0; i2<rows; i2+=mc)
|
||||
{
|
||||
const int actual_mc = std::min(mc,rows-i2);
|
||||
|
||||
// triangular solver kernel
|
||||
{
|
||||
// for each small block of the diagonal (=> vertical panels of rhs)
|
||||
for (int j2 = IsLowerTriangular
|
||||
? (actual_kc - ((actual_kc%SmallPanelWidth) ? (actual_kc%SmallPanelWidth)
|
||||
: SmallPanelWidth))
|
||||
: 0;
|
||||
IsLowerTriangular ? j2>=0 : j2<actual_kc;
|
||||
IsLowerTriangular ? j2-=SmallPanelWidth : j2+=SmallPanelWidth)
|
||||
{
|
||||
int actualPanelWidth = std::min<int>(actual_kc-j2, SmallPanelWidth);
|
||||
int absolute_j2 = actual_k2 + j2;
|
||||
int panelOffset = IsLowerTriangular ? j2+actualPanelWidth : 0;
|
||||
int panelLength = IsLowerTriangular ? actual_kc - j2 - actualPanelWidth : j2;
|
||||
|
||||
// GEBP
|
||||
//if (lengthTarget>0)
|
||||
if(panelLength>0)
|
||||
{
|
||||
gebp_kernel(&lhs(i2,absolute_j2), otherStride,
|
||||
blockA, blockB+j2*actual_kc*Blocking::PacketSize,
|
||||
actual_mc, panelLength, actualPanelWidth,
|
||||
actual_kc, actual_kc, // strides
|
||||
panelOffset, panelOffset*Blocking::PacketSize); // offsets
|
||||
}
|
||||
|
||||
// unblocked triangular solve
|
||||
for (int k=0; k<actualPanelWidth; ++k)
|
||||
{
|
||||
int j = IsLowerTriangular ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k;
|
||||
|
||||
Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(rhs(j,j));
|
||||
for (int i=0; i<actual_mc; ++i)
|
||||
{
|
||||
int absolute_i = i2+i;
|
||||
Scalar b = 0;
|
||||
for (int k3=0; k3<k; ++k3)
|
||||
if(IsLowerTriangular)
|
||||
b += lhs(absolute_i,j+1+k3) * conj(rhs(j+1+k3,j));
|
||||
else
|
||||
b += lhs(absolute_i,absolute_j2+k3) * conj(rhs(absolute_j2+k3,j));
|
||||
lhs(absolute_i,j) = (lhs(absolute_i,j) - b)*a;
|
||||
}
|
||||
}
|
||||
|
||||
// pack the just computed part of lhs to A
|
||||
pack_lhs_panel(blockA, _other+absolute_j2*otherStride+i2, otherStride,
|
||||
actualPanelWidth, actual_mc,
|
||||
actual_kc, j2);
|
||||
}
|
||||
}
|
||||
|
||||
if (rs>0)
|
||||
gebp_kernel(_other+i2+startPanel*otherStride, otherStride, blockA, geb,
|
||||
actual_mc, actual_kc, rs);
|
||||
}
|
||||
}
|
||||
|
||||
ei_aligned_stack_delete(Scalar, blockA, kc*mc);
|
||||
ei_aligned_stack_delete(Scalar, blockB, kc*size*Blocking::PacketSize);
|
||||
}
|
||||
};
|
||||
|
||||
#endif // EIGEN_TRIANGULAR_SOLVER_MATRIX_H
|
||||
|
@ -35,7 +35,7 @@ struct ei_gebp_kernel;
|
||||
template<typename Scalar, int nr, int StorageOrder, bool PanelMode=false>
|
||||
struct ei_gemm_pack_rhs;
|
||||
|
||||
template<typename Scalar, int mr, int StorageOrder, bool Conjugate = false>
|
||||
template<typename Scalar, int mr, int StorageOrder, bool Conjugate = false, bool PanelMode = false>
|
||||
struct ei_gemm_pack_lhs;
|
||||
|
||||
template<
|
||||
|
@ -233,6 +233,12 @@ enum {
|
||||
DontAlign = 0x2
|
||||
};
|
||||
|
||||
// used for the solvers
|
||||
enum {
|
||||
OnTheLeft = 1,
|
||||
OnTheRight = 2
|
||||
};
|
||||
|
||||
/* the following could as well be written:
|
||||
* enum NoChange_t { NoChange };
|
||||
* but it feels dangerous to disambiguate overloaded functions on enum/integer types.
|
||||
|
Loading…
Reference in New Issue
Block a user