diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index 671b2458d..f60ef1c03 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -27,6 +27,7 @@ template -struct ei_triangular_solver_selector +struct ei_triangular_solver_selector { typedef typename Rhs::Scalar Scalar; typedef ei_blas_traits LhsProductTraits; @@ -89,7 +90,7 @@ struct ei_triangular_solver_selector // forward and backward substitution, column-major, rhs is a vector template -struct ei_triangular_solver_selector +struct ei_triangular_solver_selector { typedef typename Rhs::Scalar Scalar; typedef typename ei_packet_traits::type Packet; @@ -142,12 +143,12 @@ struct ei_triangular_solver_selector } }; -template +template struct ei_triangular_solve_matrix; // the rhs is a matrix -template -struct ei_triangular_solver_selector +template +struct ei_triangular_solver_selector { typedef typename Rhs::Scalar Scalar; typedef ei_blas_traits LhsProductTraits; @@ -155,7 +156,8 @@ struct ei_triangular_solver_selector + ei_triangular_solve_matrix ::run(lhs.rows(), rhs.cols(), &actualLhs.coeff(0,0), actualLhs.stride(), &rhs.coeffRef(0,0), rhs.stride()); } }; @@ -194,7 +196,7 @@ struct ei_triangular_solver_unroller { }; template -struct ei_triangular_solver_selector { +struct ei_triangular_solver_selector { static void run(const Lhs& lhs, Rhs& rhs) { ei_triangular_solver_unroller::run(lhs,rhs); } }; @@ -213,7 +215,7 @@ struct ei_triangular_solver_selector -template +template void TriangularView::solveInPlace(const MatrixBase& _rhs) const { RhsDerived& rhs = _rhs.const_cast_derived(); @@ -228,7 +230,7 @@ void TriangularView::solveInPlace(const MatrixBase& RhsCopy rhsCopy(rhs); ei_triangular_solver_selector::type, - Mode>::run(_expression(), rhsCopy); + Side, Mode>::run(_expression(), rhsCopy); if (copy) rhs = rhsCopy; @@ -266,12 +268,12 @@ void TriangularView::solveInPlace(const MatrixBase& * \sa TriangularView::solveInPlace() */ template -template +template typename ei_plain_matrix_type_column_major::type TriangularView::solve(const MatrixBase& rhs) const { typename ei_plain_matrix_type_column_major::type res(rhs); - solveInPlace(res); + solveInPlace(res); return res; } diff --git a/Eigen/src/Core/TriangularMatrix.h b/Eigen/src/Core/TriangularMatrix.h index 861b738cb..a41adb190 100644 --- a/Eigen/src/Core/TriangularMatrix.h +++ b/Eigen/src/Core/TriangularMatrix.h @@ -269,13 +269,23 @@ template class TriangularView (lhs.derived(),rhs.m_matrix); } - template + + template typename ei_plain_matrix_type_column_major::type solve(const MatrixBase& other) const; - template + template void solveInPlace(const MatrixBase& other) const; + template + typename ei_plain_matrix_type_column_major::type + solve(const MatrixBase& other) const + { return solve(other); } + + template + void solveInPlace(const MatrixBase& other) const + { return solveInPlace(other); } + template void swap(const TriangularBase& other) { diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index 66fee793c..a9caa9cd3 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -347,23 +347,31 @@ struct ei_gebp_kernel // // 32 33 34 35 ... // 36 36 38 39 ... -template +template struct ei_gemm_pack_lhs { - void operator()(Scalar* blockA, const EIGEN_RESTRICT Scalar* _lhs, int lhsStride, int depth, int rows) + void operator()(Scalar* blockA, const EIGEN_RESTRICT Scalar* _lhs, int lhsStride, int depth, int rows, + int stride=0, int offset=0) { + ei_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride)); ei_conj_if::IsComplex && Conjugate> cj; ei_const_blas_data_mapper lhs(_lhs,lhsStride); int count = 0; const int peeled_mc = (rows/mr)*mr; for(int i=0; i -struct ei_triangular_solve_matrix +template +struct ei_triangular_solve_matrix { static EIGEN_DONT_INLINE void run( int size, int cols, - const Scalar* lhs, int lhsStride, - Scalar* _rhs, int rhsStride) + const Scalar* tri, int triStride, + Scalar* _other, int otherStride) { - Map > rhs(_rhs, rhsStride, cols); - Matrix aux = rhs.block(0,0,size,cols); - ei_triangular_solve_matrix - ::run(size, cols, lhs, lhsStride, aux.data(), aux.stride()); - rhs.block(0,0,size,cols) = aux; + + ei_triangular_solve_matrix< + Scalar, Side==OnTheLeft?OnTheRight:OnTheLeft, + (Mode&UnitDiagBit) | (Mode&UpperTriangular) ? LowerTriangular : UpperTriangular, + !Conjugate, TriStorageOrder, ColMajor> + ::run(size, cols, tri, triStride, _other, otherStride); + +// Map > other(_other, otherStride, cols); +// Matrix aux = other.block(0,0,size,cols); +// ei_triangular_solve_matrix +// ::run(size, cols, tri, triStride, aux.data(), aux.stride()); +// other.block(0,0,size,cols) = aux; } }; /* Optimized triangular solver with multiple right hand side (_TRSM) */ -template -struct ei_triangular_solve_matrix +template +struct ei_triangular_solve_matrix { static EIGEN_DONT_INLINE void run( - int size, int cols, - const Scalar* _lhs, int lhsStride, - Scalar* _rhs, int rhsStride) + int size, int otherSize, + const Scalar* _tri, int triStride, + Scalar* _other, int otherStride) { - ei_const_blas_data_mapper lhs(_lhs,lhsStride); - ei_blas_data_mapper rhs(_rhs,rhsStride); + int cols = otherSize; + ei_const_blas_data_mapper tri(_tri,triStride); + ei_blas_data_mapper other(_other,otherStride); typedef ei_product_blocking_traits Blocking; enum { @@ -67,9 +75,9 @@ struct ei_triangular_solve_matrix conj; - ei_gebp_kernel > gebp_kernel; - ei_gemm_pack_lhs pack_lhs; + ei_conj_if conj; + ei_gebp_kernel > gebp_kernel; + ei_gemm_pack_lhs pack_lhs; for(int k2=IsLowerTriangular ? 0 : size; IsLowerTriangular ? k20; @@ -103,25 +111,25 @@ struct ei_triangular_solve_matrix() - (blockB, _rhs+startBlock, rhsStride, -1, actualPanelWidth, cols, actual_kc, blockBOffset); + (blockB, _other+startBlock, otherStride, -1, actualPanelWidth, cols, actual_kc, blockBOffset); // GEBP if (lengthTarget>0) { int startTarget = IsLowerTriangular ? k2+k1+actualPanelWidth : k2-actual_kc; - pack_lhs(blockA, &lhs(startTarget,startBlock), lhsStride, actualPanelWidth, lengthTarget); + pack_lhs(blockA, &tri(startTarget,startBlock), triStride, actualPanelWidth, lengthTarget); - gebp_kernel(_rhs+startTarget, rhsStride, blockA, blockB, lengthTarget, actualPanelWidth, cols, + gebp_kernel(_other+startTarget, otherStride, blockA, blockB, lengthTarget, actualPanelWidth, cols, actualPanelWidth, actual_kc, 0, blockBOffset*Blocking::PacketSize); } } @@ -158,9 +166,9 @@ struct ei_triangular_solve_matrix0) { - pack_lhs(blockA, &lhs(i2, IsLowerTriangular ? k2 : k2-kc), lhsStride, actual_kc, actual_mc); + pack_lhs(blockA, &tri(i2, IsLowerTriangular ? k2 : k2-kc), triStride, actual_kc, actual_mc); - gebp_kernel(_rhs+i2, rhsStride, blockA, blockB, actual_mc, actual_kc, cols); + gebp_kernel(_other+i2, otherStride, blockA, blockB, actual_mc, actual_kc, cols); } } } @@ -171,4 +179,141 @@ struct ei_triangular_solve_matrix +struct ei_triangular_solve_matrix +{ + static EIGEN_DONT_INLINE void run( + int size, int otherSize, + const Scalar* _tri, int triStride, + Scalar* _other, int otherStride) + { + int rows = otherSize; +// ei_const_blas_data_mapper rhs(_tri,triStride); +// ei_blas_data_mapper lhs(_other,otherStride); + + Map > rhs(_tri,size,size); + Map > lhs(_other,rows,size); + + typedef ei_product_blocking_traits Blocking; + enum { + RhsStorageOrder = TriStorageOrder, + SmallPanelWidth = EIGEN_ENUM_MAX(Blocking::mr,Blocking::nr), + IsLowerTriangular = (Mode&LowerTriangular) == LowerTriangular + }; + + int kc = std::min(/*Blocking::Max_kc/4*/32,size); // cache block size along the K direction + int mc = std::min(/*Blocking::Max_mc*/32,size); // cache block size along the M direction + + Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); + Scalar* blockB = ei_aligned_stack_new(Scalar, kc*size*Blocking::PacketSize); + + ei_conj_if conj; + ei_gebp_kernel > gebp_kernel; + ei_gemm_pack_rhs pack_rhs; + ei_gemm_pack_rhs pack_rhs_panel; + ei_gemm_pack_lhs pack_lhs_panel; + ei_gemm_pack_lhs pack_lhs; + + for(int k2=IsLowerTriangular ? size : 0; + IsLowerTriangular ? k2>0 : k20) pack_rhs(geb, &rhs(actual_k2,startPanel), triStride, -1, actual_kc, rs); + + // triangular packing (we only pack the panels off the diagonal, + // neglecting the blocks overlapping the diagonal + { + for (int j2=0; j2(actual_kc-j2, SmallPanelWidth); + int actual_j2 = actual_k2 + j2; + int panelOffset = IsLowerTriangular ? j2+actualPanelWidth : 0; + int panelLength = IsLowerTriangular ? actual_kc-j2-actualPanelWidth : j2; + +// std::cerr << "$ " << k2 << " " << j2 << " " << actual_j2 << " " << panelOffset << " " << panelLength << "\n"; + + if (panelLength>0) + pack_rhs_panel(blockB+j2*actual_kc*Blocking::PacketSize, + &rhs(actual_k2+panelOffset, actual_j2), triStride, -1, + panelLength, actualPanelWidth, + actual_kc, panelOffset); + } + } + + for(int i2=0; i2 vertical panels of rhs) + for (int j2 = IsLowerTriangular + ? (actual_kc - ((actual_kc%SmallPanelWidth) ? (actual_kc%SmallPanelWidth) + : SmallPanelWidth)) + : 0; + IsLowerTriangular ? j2>=0 : j2(actual_kc-j2, SmallPanelWidth); + int absolute_j2 = actual_k2 + j2; + int panelOffset = IsLowerTriangular ? j2+actualPanelWidth : 0; + int panelLength = IsLowerTriangular ? actual_kc - j2 - actualPanelWidth : j2; + + // GEBP + //if (lengthTarget>0) + if(panelLength>0) + { + gebp_kernel(&lhs(i2,absolute_j2), otherStride, + blockA, blockB+j2*actual_kc*Blocking::PacketSize, + actual_mc, panelLength, actualPanelWidth, + actual_kc, actual_kc, // strides + panelOffset, panelOffset*Blocking::PacketSize); // offsets + } + + // unblocked triangular solve + for (int k=0; k0) + gebp_kernel(_other+i2+startPanel*otherStride, otherStride, blockA, geb, + actual_mc, actual_kc, rs); + } + } + + ei_aligned_stack_delete(Scalar, blockA, kc*mc); + ei_aligned_stack_delete(Scalar, blockB, kc*size*Blocking::PacketSize); + } +}; + #endif // EIGEN_TRIANGULAR_SOLVER_MATRIX_H diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h index c1b80acff..216f2dd69 100644 --- a/Eigen/src/Core/util/BlasUtil.h +++ b/Eigen/src/Core/util/BlasUtil.h @@ -35,7 +35,7 @@ struct ei_gebp_kernel; template struct ei_gemm_pack_rhs; -template +template struct ei_gemm_pack_lhs; template< diff --git a/Eigen/src/Core/util/Constants.h b/Eigen/src/Core/util/Constants.h index 0c251022b..98b3bd2e4 100644 --- a/Eigen/src/Core/util/Constants.h +++ b/Eigen/src/Core/util/Constants.h @@ -233,6 +233,12 @@ enum { DontAlign = 0x2 }; +// used for the solvers +enum { + OnTheLeft = 1, + OnTheRight = 2 +}; + /* the following could as well be written: * enum NoChange_t { NoChange }; * but it feels dangerous to disambiguate overloaded functions on enum/integer types.