Unwind Block of Blocks

This commit is contained in:
Charles Schlosser 2023-08-29 17:21:41 +00:00 committed by Rasmus Munk Larsen
parent 81b48065ea
commit 18018ed013
3 changed files with 110 additions and 5 deletions

View File

@ -17,9 +17,10 @@
namespace Eigen {
namespace internal {
template<typename XprType, int BlockRows, int BlockCols, bool InnerPanel>
struct traits<Block<XprType, BlockRows, BlockCols, InnerPanel> > : traits<XprType>
template<typename XprType_, int BlockRows, int BlockCols, bool InnerPanel_>
struct traits<Block<XprType_, BlockRows, BlockCols, InnerPanel_> > : traits<XprType_>
{
typedef XprType_ XprType;
typedef typename traits<XprType>::Scalar Scalar;
typedef typename traits<XprType>::StorageKind StorageKind;
typedef typename traits<XprType>::XprKind XprKind;
@ -53,12 +54,13 @@ struct traits<Block<XprType, BlockRows, BlockCols, InnerPanel> > : traits<XprTyp
// FIXME, this traits is rather specialized for dense object and it needs to be cleaned further
FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0,
FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0,
Flags = (traits<XprType>::Flags & (DirectAccessBit | (InnerPanel?CompressedAccessBit:0))) | FlagsLvalueBit | FlagsRowMajorBit,
Flags = (traits<XprType>::Flags & (DirectAccessBit | (InnerPanel_?CompressedAccessBit:0))) | FlagsLvalueBit | FlagsRowMajorBit,
// FIXME DirectAccessBit should not be handled by expressions
//
// Alignment is needed by MapBase's assertions
// We can sefely set it to false here. Internal alignment errors will be detected by an eigen_internal_assert in the respective evaluator
Alignment = 0
Alignment = 0,
InnerPanel = InnerPanel_ ? 1 : 0
};
};
@ -107,6 +109,7 @@ template<typename XprType, int BlockRows, int BlockCols, bool InnerPanel> class
: public BlockImpl<XprType, BlockRows, BlockCols, InnerPanel, typename internal::traits<XprType>::StorageKind>
{
typedef BlockImpl<XprType, BlockRows, BlockCols, InnerPanel, typename internal::traits<XprType>::StorageKind> Impl;
using BlockHelper = internal::block_xpr_helper<Block>;
public:
//typedef typename Impl::Base Base;
typedef Impl Base;
@ -149,9 +152,25 @@ template<typename XprType, int BlockRows, int BlockCols, bool InnerPanel> class
eigen_assert(startRow >= 0 && blockRows >= 0 && startRow <= xpr.rows() - blockRows
&& startCol >= 0 && blockCols >= 0 && startCol <= xpr.cols() - blockCols);
}
// convert nested blocks (e.g. Block<Block<MatrixType>>) to a simple block expression (Block<MatrixType>)
using ConstUnwindReturnType = Block<const typename BlockHelper::BaseType, BlockRows, BlockCols, InnerPanel>;
using UnwindReturnType = Block<typename BlockHelper::BaseType, BlockRows, BlockCols, InnerPanel>;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ConstUnwindReturnType unwind() const {
return ConstUnwindReturnType(BlockHelper::base(*this), BlockHelper::row(*this, 0), BlockHelper::col(*this, 0),
this->rows(), this->cols());
}
template <typename T = Block, typename EnableIf = std::enable_if_t<!std::is_const<T>::value>>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE UnwindReturnType unwind() {
return UnwindReturnType(BlockHelper::base(*this), BlockHelper::row(*this, 0), BlockHelper::col(*this, 0),
this->rows(), this->cols());
}
};
// The generic default implementation for dense block simplu forward to the internal::BlockImpl_dense
// The generic default implementation for dense block simply forward to the internal::BlockImpl_dense
// that must be specialized for direct and non-direct access...
template<typename XprType, int BlockRows, int BlockCols, bool InnerPanel>
class BlockImpl<XprType, BlockRows, BlockCols, InnerPanel, Dense>

View File

@ -809,6 +809,54 @@ std::string demangle_flags(int f)
}
#endif
template<typename XprType>
struct is_block_xpr : std::false_type {};
template<typename XprType, int BlockRows, int BlockCols, bool InnerPanel>
struct is_block_xpr<Block<XprType, BlockRows, BlockCols, InnerPanel>> : std::true_type {};
template <typename XprType, int BlockRows, int BlockCols, bool InnerPanel>
struct is_block_xpr<const Block<XprType, BlockRows, BlockCols, InnerPanel>> : std::true_type {};
// Helper utility for constructing non-recursive block expressions.
template<typename XprType>
struct block_xpr_helper {
using BaseType = XprType;
// For regular block expressions, simply forward along the InnerPanel argument,
// which is set when calling row/column expressions.
static constexpr bool is_inner_panel(bool inner_panel) { return inner_panel; };
// Only enable non-const base function if XprType is not const (otherwise we get a duplicate definition).
template<typename T = XprType, typename EnableIf=std::enable_if_t<!std::is_const<T>::value>>
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BaseType& base(XprType& xpr) { return xpr; }
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const BaseType& base(const XprType& xpr) { return xpr; }
static constexpr EIGEN_ALWAYS_INLINE Index row(const XprType& /*xpr*/, Index r) { return r; }
static constexpr EIGEN_ALWAYS_INLINE Index col(const XprType& /*xpr*/, Index c) { return c; }
};
template<typename XprType, int BlockRows, int BlockCols, bool InnerPanel>
struct block_xpr_helper<Block<XprType, BlockRows, BlockCols, InnerPanel>> {
using BlockXprType = Block<XprType, BlockRows, BlockCols, InnerPanel>;
// Recursive helper in case of explicit block-of-block expression.
using NestedXprHelper = block_xpr_helper<XprType>;
using BaseType = typename NestedXprHelper::BaseType;
// For block-of-block expressions, we need to combine the InnerPannel trait
// with that of the block subexpression.
static constexpr bool is_inner_panel(bool inner_panel) { return InnerPanel && inner_panel; }
// Only enable non-const base function if XprType is not const (otherwise we get a duplicates definition).
template<typename T = XprType, typename EnableIf=std::enable_if_t<!std::is_const<T>::value>>
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BaseType& base(BlockXprType& xpr) { return NestedXprHelper::base(xpr.nestedExpression()); }
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const BaseType& base(const BlockXprType& xpr) { return NestedXprHelper::base(xpr.nestedExpression()); }
static constexpr EIGEN_ALWAYS_INLINE Index row(const BlockXprType& xpr, Index r) { return xpr.startRow() + NestedXprHelper::row(xpr.nestedExpression(), r); }
static constexpr EIGEN_ALWAYS_INLINE Index col(const BlockXprType& xpr, Index c) { return xpr.startCol() + NestedXprHelper::col(xpr.nestedExpression(), c); }
};
template<typename XprType, int BlockRows, int BlockCols, bool InnerPanel>
struct block_xpr_helper<const Block<XprType, BlockRows, BlockCols, InnerPanel>> : block_xpr_helper<Block<XprType, BlockRows, BlockCols, InnerPanel>> {};
} // end namespace internal

View File

@ -306,6 +306,43 @@ void data_and_stride(const MatrixType& m)
compare_using_data_and_stride(m1.col(c1).transpose());
}
template <typename BaseXpr, typename Xpr = BaseXpr, int Depth = 0>
struct unwind_test_impl {
static void run(Xpr& xpr) {
Index startRow = internal::random<Index>(0, xpr.rows() / 5);
Index startCol = internal::random<Index>(0, xpr.cols() / 6);
Index rows = xpr.rows() / 3;
Index cols = xpr.cols() / 2;
// test equivalence of const expressions
const Block<const Xpr> constNestedBlock(xpr, startRow, startCol, rows, cols);
const Block<const BaseXpr> constUnwoundBlock = constNestedBlock.unwind();
VERIFY_IS_CWISE_EQUAL(constNestedBlock, constUnwoundBlock);
// modify a random element in each representation and test equivalence of non-const expressions
Block<Xpr> nestedBlock(xpr, startRow, startCol, rows, cols);
Block<BaseXpr> unwoundBlock = nestedBlock.unwind();
Index r1 = internal::random<Index>(0, rows - 1);
Index c1 = internal::random<Index>(0, cols - 1);
Index r2 = internal::random<Index>(0, rows - 1);
Index c2 = internal::random<Index>(0, cols - 1);
nestedBlock.coeffRef(r1, c1) = internal::random<typename DenseBase<Xpr>::Scalar>();
unwoundBlock.coeffRef(r2, c2) = internal::random<typename DenseBase<Xpr>::Scalar>();
VERIFY_IS_CWISE_EQUAL(nestedBlock, unwoundBlock);
unwind_test_impl<BaseXpr, Block<Xpr>, Depth + 1>::run(nestedBlock);
}
};
template <typename BaseXpr, typename Xpr>
struct unwind_test_impl<BaseXpr, Xpr, 4> {
static void run(const Xpr&) {}
};
template <typename BaseXpr>
void unwind_test(const BaseXpr&) {
BaseXpr xpr = BaseXpr::Random(100, 100);
unwind_test_impl<BaseXpr>::run(xpr);
}
EIGEN_DECLARE_TEST(block)
{
for(int i = 0; i < g_repeat; i++) {
@ -320,6 +357,7 @@ EIGEN_DECLARE_TEST(block)
CALL_SUBTEST_7( block(Matrix<int,Dynamic,Dynamic,RowMajor>(internal::random(2,50), internal::random(2,50))) );
CALL_SUBTEST_8( block(Matrix<float,Dynamic,4>(3, 4)) );
CALL_SUBTEST_9( unwind_test(MatrixXf()));
#ifndef EIGEN_DEFAULT_TO_ROW_MAJOR
CALL_SUBTEST_6( data_and_stride(MatrixXf(internal::random(5,50), internal::random(5,50))) );