diff --git a/src/MatrixBase.h b/src/MatrixBase.h index bcc747aca..860b9272d 100644 --- a/src/MatrixBase.h +++ b/src/MatrixBase.h @@ -67,6 +67,7 @@ template class MatrixRef { public: typedef typename ForwardDecl::Scalar Scalar; + typedef MatrixXpr > Xpr; MatrixRef(MatrixType& matrix) : m_matrix(matrix) {} MatrixRef(const MatrixRef& other) : m_matrix(other.m_matrix) {} @@ -92,6 +93,11 @@ template class MatrixRef MatrixType& matrix() { return m_matrix; } + Xpr xpr() + { + return Xpr(*this); + } + protected: MatrixType& m_matrix; }; @@ -106,6 +112,7 @@ class MatrixBase typedef MatrixRef > Ref; typedef MatrixConstXpr ConstXpr; typedef MatrixXpr Xpr; + typedef MatrixAlias Alias; Ref ref() { @@ -127,6 +134,8 @@ class MatrixBase return ConstXpr(constRef()); } + Alias alias(); + static bool hasDynamicNumRows() { return Derived::_hasDynamicNumRows(); @@ -175,29 +184,16 @@ class MatrixBase } template - void operator=(const MatrixConstXpr &xpr) + MatrixBase& operator=(const MatrixConstXpr &otherXpr) { - resize(xpr.rows(), xpr.cols()); - for(int i = 0; i < rows(); i++) - for(int j = 0; j < cols(); j++) - this->operator()(i, j) = xpr(i, j); + resize(otherXpr.rows(), otherXpr.cols()); + xpr() = otherXpr; + return *this; } - void operator=(const MatrixBase &other) + MatrixBase& operator=(const MatrixBase &other) { - resize(other.rows(), other.cols()); - for(int i = 0; i < rows(); i++) - for(int j = 0; j < cols(); j++) - this->operator()(i, j) = other(i, j); - } - - template - void operator<<(const MatrixConstXpr &xpr) - { - Derived tmp(xpr.rows(), xpr.cols()); - MatrixBase *ptr = static_cast(&tmp); - *ptr = xpr; - *this = *ptr; + return *this = other.constXpr(); } MatrixConstXpr > row(int i) const; @@ -213,12 +209,13 @@ class MatrixBase template template -void MatrixXpr::operator=(const MatrixBase& matrix) +MatrixXpr& MatrixXpr::operator=(const MatrixBase& matrix) { assert(rows() == matrix.rows() && cols() == matrix.cols()); for(int i = 0; i < rows(); i++) for(int j = 0; j < cols(); j++) this->operator()(i, j) = matrix(i, j); + return *this; } template @@ -252,6 +249,67 @@ std::ostream & operator << (std::ostream & s, return s; } +template class MatrixAlias +{ + public: + typedef typename Derived::Scalar Scalar; + typedef MatrixRef > Ref; + typedef MatrixXpr Xpr; + + MatrixAlias(Derived& matrix) : m_ref(matrix), m_tmp(matrix) {} + MatrixAlias(const MatrixAlias& other) : m_ref(other.m_ref), m_tmp(other.m_tmp) {} + + ~MatrixAlias() + { + m_ref.xpr() = m_tmp; + } + + Xpr xpr() + { + return Xpr(ref()); + } + + static bool hasDynamicNumRows() + { + return MatrixBase::hasDynamicNumRows(); + } + + static bool hasDynamicNumCols() + { + return MatrixBase::hasDynamicNumCols(); + } + + int rows() const { return m_tmp.rows(); } + int cols() const { return m_tmp.cols(); } + + Scalar& operator()(int row, int col) + { + return m_tmp(row, col); + } + + Ref ref() + { + return Ref(*this); + } + + template + void operator=(const MatrixConstXpr &xpr) + { + ref().xpr() = xpr; + } + + protected: + MatrixRef > m_ref; + Derived m_tmp; +}; + +template +typename MatrixBase::Alias +MatrixBase::alias() +{ + return Alias(*static_cast(this)); +} + } // namespace Eigen #endif // EIGEN_MATRIXBASE_H diff --git a/src/MatrixXpr.h b/src/MatrixXpr.h index 9a1ab1085..937d4e280 100644 --- a/src/MatrixXpr.h +++ b/src/MatrixXpr.h @@ -97,16 +97,17 @@ template class MatrixXpr } template - void operator=(const MatrixConstXpr &other) + MatrixXpr& operator=(const MatrixConstXpr &other) { assert(rows() == other.rows() && cols() == other.cols()); for(int i = 0; i < rows(); i++) for(int j = 0; j < cols(); j++) this->operator()(i, j) = other(i, j); + return *this; } template - void operator=(const MatrixBase& matrix); + MatrixXpr& operator=(const MatrixBase& matrix); MatrixXpr > > row(int i); MatrixXpr > > col(int i); diff --git a/src/ScalarOps.h b/src/ScalarOps.h index 26698c751..5d0a16ec9 100644 --- a/src/ScalarOps.h +++ b/src/ScalarOps.h @@ -116,6 +116,30 @@ operator *(typename Derived::Scalar scalar, return XprType(ProductType(matrix.constRef(), scalar)); } +template +const MatrixConstXpr< + const ScalarProduct< + MatrixConstXpr + > +> +operator /(const MatrixConstXpr& xpr, + typename Content::Scalar scalar) +{ + return xpr * (static_cast(1) / scalar); +} + +template +const MatrixConstXpr< + const ScalarProduct< + MatrixConstRef > + > +> +operator /(const MatrixBase& matrix, + typename Derived::Scalar scalar) +{ + return matrix * (static_cast(1) / scalar); +} + } // namespace Eigen #endif // EIGEN_SCALAROPS_H diff --git a/src/Util.h b/src/Util.h index b789b891e..f1b51bbe1 100644 --- a/src/Util.h +++ b/src/Util.h @@ -48,6 +48,7 @@ template class MatrixX; template class Vector; template class VectorX; template class MatrixBase; +template class MatrixAlias; template struct ForwardDecl; template struct ForwardDecl< Matrix > @@ -60,12 +61,20 @@ template struct ForwardDecl< VectorX > { typedef T Scalar; }; template struct ForwardDecl< MatrixBase > > { typedef T Scalar; }; +template struct ForwardDecl< MatrixAlias > > +{ typedef T Scalar; }; template struct ForwardDecl< MatrixBase > > { typedef T Scalar; }; +template struct ForwardDecl< MatrixAlias > > +{ typedef T Scalar; }; template struct ForwardDecl< MatrixBase > > { typedef T Scalar; }; +template struct ForwardDecl< MatrixAlias > > +{ typedef T Scalar; }; template struct ForwardDecl< MatrixBase > > { typedef T Scalar; }; +template struct ForwardDecl< MatrixAlias > > +{ typedef T Scalar; }; template class MatrixRef; diff --git a/test/matrixmanip.cpp b/test/matrixmanip.cpp index 153af365c..5ce8fa615 100644 --- a/test/matrixmanip.cpp +++ b/test/matrixmanip.cpp @@ -37,6 +37,7 @@ template void matrixManip(const MatrixType& m) a.block(1, rows-1, 1, cols-1); a.xpr().row(i) = b.row(i); a.xpr().minor(i, j) = b.block(1, rows-1, 1, cols-1); + a.alias().xpr().minor(i, j) = a.block(1, rows-1, 1, cols-1); } void EigenTest::testMatrixManip() diff --git a/test/matrixops.cpp b/test/matrixops.cpp index 6e95f7893..a9034dcf2 100644 --- a/test/matrixops.cpp +++ b/test/matrixops.cpp @@ -44,7 +44,7 @@ template void vectorOps(const VectorType& v) a = b; a = b + c; a = s * (b - c); - a << a + b; + a.alias() = a + b; } void EigenTest::testVectorOps()