mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-01 18:26:24 +08:00
avoid dynamic allocation for fixed size triangular solving
This commit is contained in:
parent
bc580bbffb
commit
924c7a9300
@ -100,12 +100,22 @@ struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,Dynamic>
|
||||
typedef typename Rhs::Index Index;
|
||||
typedef blas_traits<Lhs> LhsProductTraits;
|
||||
typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType;
|
||||
|
||||
static void run(const Lhs& lhs, Rhs& rhs)
|
||||
{
|
||||
typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsProductTraits::extract(lhs);
|
||||
|
||||
const Index size = lhs.rows();
|
||||
const Index othersize = Side==OnTheLeft? rhs.cols() : rhs.rows();
|
||||
|
||||
typedef internal::gemm_blocking_space<(Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar,
|
||||
Rhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxRowsAtCompileTime,4> BlockingType;
|
||||
|
||||
BlockingType blocking(rhs.rows(), rhs.cols(), size);
|
||||
|
||||
triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor,
|
||||
(Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor>
|
||||
::run(lhs.rows(), Side==OnTheLeft? rhs.cols() : rhs.rows(), &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride());
|
||||
::run(size, othersize, &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride(), blocking);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -79,7 +79,7 @@ static void run(Index rows, Index cols, Index depth,
|
||||
|
||||
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
|
||||
|
||||
Index kc = blocking.kc(); // cache block size along the K direction
|
||||
Index kc = blocking.kc(); // cache block size along the K direction
|
||||
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
|
||||
//Index nc = blocking.nc(); // cache block size along the N direction
|
||||
|
||||
@ -249,7 +249,7 @@ struct gemm_functor
|
||||
BlockingType& m_blocking;
|
||||
};
|
||||
|
||||
template<int StorageOrder, typename LhsScalar, typename RhsScalar, int MaxRows, int MaxCols, int MaxDepth,
|
||||
template<int StorageOrder, typename LhsScalar, typename RhsScalar, int MaxRows, int MaxCols, int MaxDepth, int KcFactor=1,
|
||||
bool FiniteAtCompileTime = MaxRows!=Dynamic && MaxCols!=Dynamic && MaxDepth != Dynamic> class gemm_blocking_space;
|
||||
|
||||
template<typename _LhsScalar, typename _RhsScalar>
|
||||
@ -282,8 +282,8 @@ class level3_blocking
|
||||
inline RhsScalar* blockW() { return m_blockW; }
|
||||
};
|
||||
|
||||
template<int StorageOrder, typename _LhsScalar, typename _RhsScalar, int MaxRows, int MaxCols, int MaxDepth>
|
||||
class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, MaxDepth, true>
|
||||
template<int StorageOrder, typename _LhsScalar, typename _RhsScalar, int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
|
||||
class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, MaxDepth, KcFactor, true>
|
||||
: public level3_blocking<
|
||||
typename conditional<StorageOrder==RowMajor,_RhsScalar,_LhsScalar>::type,
|
||||
typename conditional<StorageOrder==RowMajor,_LhsScalar,_RhsScalar>::type>
|
||||
@ -324,8 +324,8 @@ class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, M
|
||||
inline void allocateAll() {}
|
||||
};
|
||||
|
||||
template<int StorageOrder, typename _LhsScalar, typename _RhsScalar, int MaxRows, int MaxCols, int MaxDepth>
|
||||
class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, MaxDepth, false>
|
||||
template<int StorageOrder, typename _LhsScalar, typename _RhsScalar, int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
|
||||
class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, MaxDepth, KcFactor, false>
|
||||
: public level3_blocking<
|
||||
typename conditional<StorageOrder==RowMajor,_RhsScalar,_LhsScalar>::type,
|
||||
typename conditional<StorageOrder==RowMajor,_LhsScalar,_RhsScalar>::type>
|
||||
@ -349,7 +349,7 @@ class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, M
|
||||
this->m_nc = Transpose ? rows : cols;
|
||||
this->m_kc = depth;
|
||||
|
||||
computeProductBlockingSizes<LhsScalar,RhsScalar>(this->m_kc, this->m_mc, this->m_nc);
|
||||
computeProductBlockingSizes<LhsScalar,RhsScalar,KcFactor>(this->m_kc, this->m_mc, this->m_nc);
|
||||
m_sizeA = this->m_mc * this->m_kc;
|
||||
m_sizeB = this->m_kc * this->m_nc;
|
||||
m_sizeW = this->m_kc*Traits::WorkSpaceFactor;
|
||||
|
@ -36,14 +36,15 @@ struct triangular_solve_matrix<Scalar,Index,Side,Mode,Conjugate,TriStorageOrder,
|
||||
static EIGEN_DONT_INLINE void run(
|
||||
Index size, Index cols,
|
||||
const Scalar* tri, Index triStride,
|
||||
Scalar* _other, Index otherStride)
|
||||
Scalar* _other, Index otherStride,
|
||||
level3_blocking<Scalar,Scalar>& blocking)
|
||||
{
|
||||
triangular_solve_matrix<
|
||||
Scalar, Index, Side==OnTheLeft?OnTheRight:OnTheLeft,
|
||||
(Mode&UnitDiag) | ((Mode&Upper) ? Lower : Upper),
|
||||
NumTraits<Scalar>::IsComplex && Conjugate,
|
||||
TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor>
|
||||
::run(size, cols, tri, triStride, _other, otherStride);
|
||||
::run(size, cols, tri, triStride, _other, otherStride, blocking);
|
||||
}
|
||||
};
|
||||
|
||||
@ -55,7 +56,8 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO
|
||||
static EIGEN_DONT_INLINE void run(
|
||||
Index size, Index otherSize,
|
||||
const Scalar* _tri, Index triStride,
|
||||
Scalar* _other, Index otherStride)
|
||||
Scalar* _other, Index otherStride,
|
||||
level3_blocking<Scalar,Scalar>& blocking)
|
||||
{
|
||||
Index cols = otherSize;
|
||||
const_blas_data_mapper<Scalar, Index, TriStorageOrder> tri(_tri,triStride);
|
||||
@ -67,17 +69,16 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO
|
||||
IsLower = (Mode&Lower) == Lower
|
||||
};
|
||||
|
||||
Index kc = size; // cache block size along the K direction
|
||||
Index mc = size; // cache block size along the M direction
|
||||
Index nc = cols; // cache block size along the N direction
|
||||
computeProductBlockingSizes<Scalar,Scalar,4>(kc, mc, nc);
|
||||
Index kc = blocking.kc(); // cache block size along the K direction
|
||||
Index mc = (std::min)(size,blocking.mc()); // cache block size along the M direction
|
||||
|
||||
std::size_t sizeA = kc*mc;
|
||||
std::size_t sizeB = kc*cols;
|
||||
std::size_t sizeW = kc*Traits::WorkSpaceFactor;
|
||||
std::size_t sizeB = sizeW + kc*cols;
|
||||
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, kc*mc, 0);
|
||||
ei_declare_aligned_stack_constructed_variable(Scalar, allocatedBlockB, sizeB, 0);
|
||||
Scalar* blockB = allocatedBlockB + sizeW;
|
||||
Scalar* blockW = allocatedBlockB;
|
||||
|
||||
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
|
||||
ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
|
||||
ei_declare_aligned_stack_constructed_variable(Scalar, blockW, sizeW, blocking.blockW());
|
||||
|
||||
conj_if<Conjugate> conj;
|
||||
gebp_kernel<Scalar, Scalar, Index, Traits::mr, Traits::nr, Conjugate, false> gebp_kernel;
|
||||
@ -181,7 +182,7 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO
|
||||
{
|
||||
pack_lhs(blockA, &tri(i2, IsLower ? k2 : k2-kc), triStride, actual_kc, actual_mc);
|
||||
|
||||
gebp_kernel(_other+i2, otherStride, blockA, blockB, actual_mc, actual_kc, cols, Scalar(-1));
|
||||
gebp_kernel(_other+i2, otherStride, blockA, blockB, actual_mc, actual_kc, cols, Scalar(-1), -1, -1, 0, 0, blockW);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -197,7 +198,8 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage
|
||||
static EIGEN_DONT_INLINE void run(
|
||||
Index size, Index otherSize,
|
||||
const Scalar* _tri, Index triStride,
|
||||
Scalar* _other, Index otherStride)
|
||||
Scalar* _other, Index otherStride,
|
||||
level3_blocking<Scalar,Scalar>& blocking)
|
||||
{
|
||||
Index rows = otherSize;
|
||||
const_blas_data_mapper<Scalar, Index, TriStorageOrder> rhs(_tri,triStride);
|
||||
@ -210,19 +212,16 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage
|
||||
IsLower = (Mode&Lower) == Lower
|
||||
};
|
||||
|
||||
// Index kc = std::min<Index>(Traits::Max_kc/4,size); // cache block size along the K direction
|
||||
// Index mc = std::min<Index>(Traits::Max_mc,size); // cache block size along the M direction
|
||||
// check that !!!!
|
||||
Index kc = size; // cache block size along the K direction
|
||||
Index mc = size; // cache block size along the M direction
|
||||
Index nc = rows; // cache block size along the N direction
|
||||
computeProductBlockingSizes<Scalar,Scalar,4>(kc, mc, nc);
|
||||
Index kc = blocking.kc(); // cache block size along the K direction
|
||||
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
|
||||
|
||||
std::size_t sizeA = kc*mc;
|
||||
std::size_t sizeB = kc*size;
|
||||
std::size_t sizeW = kc*Traits::WorkSpaceFactor;
|
||||
std::size_t sizeB = sizeW + kc*size;
|
||||
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, kc*mc, 0);
|
||||
ei_declare_aligned_stack_constructed_variable(Scalar, allocatedBlockB, sizeB, 0);
|
||||
Scalar* blockB = allocatedBlockB + sizeW;
|
||||
|
||||
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
|
||||
ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
|
||||
ei_declare_aligned_stack_constructed_variable(Scalar, blockW, sizeW, blocking.blockW());
|
||||
|
||||
conj_if<Conjugate> conj;
|
||||
gebp_kernel<Scalar,Scalar, Index, Traits::mr, Traits::nr, false, Conjugate> gebp_kernel;
|
||||
@ -289,7 +288,7 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage
|
||||
Scalar(-1),
|
||||
actual_kc, actual_kc, // strides
|
||||
panelOffset, panelOffset, // offsets
|
||||
allocatedBlockB); // workspace
|
||||
blockW); // workspace
|
||||
}
|
||||
|
||||
// unblocked triangular solve
|
||||
@ -320,7 +319,7 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage
|
||||
if (rs>0)
|
||||
gebp_kernel(_other+i2+startPanel*otherStride, otherStride, blockA, geb,
|
||||
actual_mc, actual_kc, rs, Scalar(-1),
|
||||
-1, -1, 0, 0, allocatedBlockB);
|
||||
-1, -1, 0, 0, blockW);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user