diff --git a/include/rotgen/algebra/svd/dynamic.hpp b/include/rotgen/algebra/svd/dynamic.hpp index 12d98b5..48b3f0b 100644 --- a/include/rotgen/algebra/svd/dynamic.hpp +++ b/include/rotgen/algebra/svd/dynamic.hpp @@ -25,20 +25,20 @@ namespace rotgen int rank() const { return parent::rank(); } - m_type U() const { return parent::U(); } + m_type matrixU() const { return parent::U(); } - m_type D() const { return parent::D(); } + m_type matrixD() const { return parent::D(); } - m_type V() const { return parent::V(); } + m_type matrixV() const { return parent::V(); } - d_type singular_values() const { return parent::singular_values(); } + d_type singularValues() const { return parent::singular_values(); } - m_type U(int r) const { return parent::U(r); } + m_type matrixU(int r) const { return parent::U(r); } - m_type D(int r) const { return parent::D(r); } + m_type matrixD(int r) const { return parent::D(r); } - m_type V(int r) const { return parent::V(r); } + m_type matrixV(int r) const { return parent::V(r); } - m_type singular_values(int r) const { return parent::singular_values(r); } + m_type singularValues(int r) const { return parent::singular_values(r); } }; } diff --git a/include/rotgen/algebra/svd/fixed.hpp b/include/rotgen/algebra/svd/fixed.hpp index 5adcbaf..0230898 100644 --- a/include/rotgen/algebra/svd/fixed.hpp +++ b/include/rotgen/algebra/svd/fixed.hpp @@ -31,28 +31,28 @@ namespace rotgen int rank() const { return svd_.rank(); } - auto singular_values() const + auto singularValues() const { if constexpr (!use_expression_templates) return detail::as_concrete_t{svd_.singularValues()}; else return svd_.singularValues(); } - auto U() const + auto matrixU() const { if constexpr (!use_expression_templates) return detail::as_concrete_t{svd_.matrixU()}; else return svd_.matrixU(); } - auto V() const + auto matrixV() const { if constexpr (!use_expression_templates) return detail::as_concrete_t{svd_.matrixV()}; else return svd_.matrixV(); } - auto D() const + auto matrixD() const { auto d = svd_.singularValues().asDiagonal(); if constexpr (!use_expression_templates) @@ -61,7 +61,7 @@ namespace rotgen else return d; } - auto singular_values(int r) const + auto singularValues(int r) const { auto that = svd_.singularValues().head(r); if constexpr (!use_expression_templates) @@ -69,7 +69,7 @@ namespace rotgen else return svd_.singularValues(); } - auto U(int r) const + auto matrixU(int r) const { auto that = svd_.matrixU().leftCols(r); if constexpr (!use_expression_templates) @@ -77,7 +77,7 @@ namespace rotgen else return that; } - auto V(int r) const + auto matrixV(int r) const { auto that = svd_.matrixV().leftCols(r); if constexpr (!use_expression_templates) @@ -85,7 +85,7 @@ namespace rotgen else return that; } - auto D(int r) const + auto matrixD(int r) const { auto d = svd_.singularValues().head(r).asDiagonal(); if constexpr (!use_expression_templates) diff --git a/include/rotgen/concepts.hpp b/include/rotgen/concepts.hpp index 9029597..26f4773 100644 --- a/include/rotgen/concepts.hpp +++ b/include/rotgen/concepts.hpp @@ -12,6 +12,26 @@ namespace rotgen::concepts { + //================================================================================================ + //! @brief Check if a type is a Rotgen block. + //================================================================================================ + template + concept block = + requires { typename std::remove_cvref_t::rotgen_block_tag; }; + + //================================================================================================ + //! @brief Check if a type is a Rotgen matrix. + //================================================================================================ + template + concept matrix = + requires { typename std::remove_cvref_t::rotgen_matrix_tag; }; + + //================================================================================================ + //! @brief Check if a type is a Rotgen map. + //================================================================================================ + template + concept map = requires { typename std::remove_cvref_t::rotgen_map_tag; }; + //================================================================================================ //! @brief Check if a type is a Rotgen reference. //================================================================================================ @@ -36,10 +56,8 @@ namespace rotgen::concepts //! @brief Check if a type is a ROTGEN type. //================================================================================================ template - concept entity = requires(T const&) { - typename std::remove_cvref_t::rotgen_tag; - typename std::remove_cvref_t::parent; - }; + concept entity = + requires(T const&) { typename std::remove_cvref_t::rotgen_tag; }; //================================================================================================ //! @brief Check if a type is an EIGEN type. diff --git a/include/rotgen/container/block/dynamic.hpp b/include/rotgen/container/block/dynamic.hpp index 504a15e..e9b5830 100644 --- a/include/rotgen/container/block/dynamic.hpp +++ b/include/rotgen/container/block/dynamic.hpp @@ -7,11 +7,12 @@ //================================================================================================== #pragma once +#include + #include #include #include -#include #include namespace rotgen @@ -76,14 +77,37 @@ namespace rotgen using parent::operator=; - block& operator=(concepts::entity auto const& other) + template + block& operator=(Src const& other) requires(!is_immutable) { - assert(parent::rows() == other.rows() && parent::cols() == other.cols()); - for (rotgen::Index r = 0; r < parent::rows(); ++r) - for (rotgen::Index c = 0; c < parent::cols(); ++c) - (*this)(r, c) = other(r, c); + if constexpr (IsVectorAtCompileTime && Src::IsVectorAtCompileTime) + { + ROTGEN_ASSERT(parent::size() == other.size(), + "Block assignment from 1D source doesn't match size"); + for (rotgen::Index i = 0; i < parent::size(); ++i) + (*this)[i] = other[i]; + } + else if constexpr (IsVectorAtCompileTime && !Src::IsVectorAtCompileTime) + { + auto r = other.rows(); + auto c = other.cols(); + ROTGEN_ASSERT((r == 1 || c == 1), "Block assignment from dynamic sized " + "source doesn't match static size"); + + for (rotgen::Index i = 0; i < parent::size(); ++i) + (*this)[i] = other(r == 1 ? 0 : i, c == 1 ? 0 : i); + } + else + { + ROTGEN_ASSERT(parent::rows() == other.rows() && + parent::cols() == other.cols(), + "Block assignment size mismatch"); + for (rotgen::Index r = 0; r < parent::rows(); ++r) + for (rotgen::Index c = 0; c < parent::cols(); ++c) + (*this)(r, c) = other(r, c); + } return *this; } @@ -167,13 +191,6 @@ namespace rotgen block(parent const& base) : parent(base) {} - bool is_contiguous_linear() const - { - if (parent::innerStride() != 1) return false; - if constexpr (IsRowMajor) return parent::outerStride() == parent::cols(); - else return parent::outerStride() == parent::rows(); - } - value_type& operator()(Index i, Index j) requires(!is_immutable) { @@ -183,7 +200,6 @@ namespace rotgen value_type& operator()(Index i) requires(!is_immutable && IsVectorAtCompileTime) { - assert(is_contiguous_linear()); return parent::operator()(i); } @@ -201,7 +217,6 @@ namespace rotgen value_type operator()(Index i) const requires(IsVectorAtCompileTime) { - assert(is_contiguous_linear()); return parent::operator()(i); } @@ -211,7 +226,7 @@ namespace rotgen return (*this)(i); } - auto evaluate() const { return *this; } + concrete_type evaluate() const { return concrete_type{*this}; } decltype(auto) noalias() const { return *this; } @@ -278,17 +293,19 @@ namespace rotgen return static_cast(lhs) == static_cast(rhs); } - block& operator+=(block const& rhs) + template + block& operator+=(E const& rhs) requires(!is_immutable) { - base() += static_cast(rhs); + base() += rhs.base(); return *this; } - block& operator-=(block const& rhs) + template + block& operator-=(E const& rhs) requires(!is_immutable) { - base() -= static_cast(rhs); + base() -= rhs.base(); return *this; } @@ -297,10 +314,11 @@ namespace rotgen return concrete_type(static_cast(*this).operator-()); } - block& operator*=(block const& rhs) + template + block& operator*=(E const& rhs) requires(!is_immutable) { - base() *= static_cast(rhs); + base() *= rhs.base(); return *this; } @@ -351,11 +369,11 @@ namespace rotgen static concrete_type Zero(int rows, int cols) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Zero(rows, cols); } @@ -368,11 +386,11 @@ namespace rotgen static concrete_type Ones(int rows, int cols) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Ones(rows, cols); } @@ -385,11 +403,11 @@ namespace rotgen static concrete_type Constant(int rows, int cols, value_type value) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Constant(rows, cols, static_cast(value)); } @@ -402,11 +420,11 @@ namespace rotgen static concrete_type Random(int rows, int cols) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Random(rows, cols); } @@ -419,11 +437,11 @@ namespace rotgen static concrete_type Identity(int rows, int cols) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Identity(rows, cols); } @@ -464,7 +482,7 @@ namespace rotgen template value_type lpNorm() const { - assert(P == 1 || P == 2 || P == Infinity); + static_assert(P == 1 || P == 2 || P == Infinity); return parent::lpNorm(P); } diff --git a/include/rotgen/container/block/dynamic/indirect.hpp b/include/rotgen/container/block/dynamic/indirect.hpp index ad7ae77..b081f55 100644 --- a/include/rotgen/container/block/dynamic/indirect.hpp +++ b/include/rotgen/container/block/dynamic/indirect.hpp @@ -2,28 +2,36 @@ #define TYPE double #define CLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _col) +#define CLASSCONSTNAME ROTGEN_MATRIX_NAME(block_const_impl, SIZE, _col) #define TRANSCLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _row) #define SOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _col) #define TRANSNAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _row) #define MAPNAME ROTGEN_MATRIX_NAME(BASEMAP, SIZE, _col) +#define TRANSMAPNAME ROTGEN_MATRIX_NAME(BASEMAP, SIZE, _row) #include #undef CLASSNAME +#undef CLASSCONSTNAME #undef TRANSCLASSNAME #undef TRANSNAME #undef SOURCENAME #undef MAPNAME +#undef TRANSMAPNAME #define CLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _row) +#define CLASSCONSTNAME ROTGEN_MATRIX_NAME(block_const_impl, SIZE, _row) #define TRANSCLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _col) #define SOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _row) #define TRANSNAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _col) #define MAPNAME ROTGEN_MATRIX_NAME(BASEMAP, SIZE, _row) +#define TRANSMAPNAME ROTGEN_MATRIX_NAME(BASEMAP, SIZE, _col) #include #undef CLASSNAME +#undef CLASSCONSTNAME #undef TRANSCLASSNAME #undef TRANSNAME #undef SOURCENAME #undef MAPNAME +#undef TRANSMAPNAME #undef SIZE #undef TYPE @@ -32,28 +40,36 @@ #define TYPE float #define CLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _col) +#define CLASSCONSTNAME ROTGEN_MATRIX_NAME(block_const_impl, SIZE, _col) #define TRANSCLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _row) #define SOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _col) #define TRANSNAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _row) #define MAPNAME ROTGEN_MATRIX_NAME(BASEMAP, SIZE, _col) +#define TRANSMAPNAME ROTGEN_MATRIX_NAME(BASEMAP, SIZE, _row) #include #undef CLASSNAME +#undef CLASSCONSTNAME #undef TRANSCLASSNAME #undef TRANSNAME #undef SOURCENAME #undef MAPNAME +#undef TRANSMAPNAME #define CLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _row) +#define CLASSCONSTNAME ROTGEN_MATRIX_NAME(block_const_impl, SIZE, _row) #define TRANSCLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _col) #define SOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _row) #define TRANSNAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _col) #define MAPNAME ROTGEN_MATRIX_NAME(BASEMAP, SIZE, _row) +#define TRANSMAPNAME ROTGEN_MATRIX_NAME(BASEMAP, SIZE, _col) #include #undef CLASSNAME +#undef CLASSCONSTNAME #undef TRANSCLASSNAME #undef TRANSNAME #undef SOURCENAME #undef MAPNAME +#undef TRANSMAPNAME #undef SIZE #undef TYPE diff --git a/include/rotgen/container/block/dynamic/model.hpp b/include/rotgen/container/block/dynamic/model.hpp index ed04c01..24785fd 100644 --- a/include/rotgen/container/block/dynamic/model.hpp +++ b/include/rotgen/container/block/dynamic/model.hpp @@ -23,6 +23,7 @@ public: CLASSNAME(MAPNAME CONST& r, Index i0, Index j0, Index ni, Index nj); CLASSNAME(CLASSNAME CONST& r, Index i0, Index j0, Index ni, Index nj); CLASSNAME(TRANSCLASSNAME CONST& r, Index i0, Index j0, Index ni, Index nj); + CLASSNAME(TRANSMAPNAME CONST& r, Index i0, Index j0, Index ni, Index nj); CLASSNAME(CLASSNAME const& other); CLASSNAME(CLASSNAME&&) noexcept; @@ -81,8 +82,17 @@ public: TYPE& operator()(Index i, Index j); TYPE& operator()(Index index); CLASSNAME& operator+=(CLASSNAME const& rhs); + CLASSNAME& operator+=(CLASSCONSTNAME const& rhs); + CLASSNAME& operator+=(SOURCENAME const& rhs); + CLASSNAME& operator+=(TRANSNAME const& rhs); CLASSNAME& operator-=(CLASSNAME const& rhs); + CLASSNAME& operator-=(CLASSCONSTNAME const& rhs); + CLASSNAME& operator-=(SOURCENAME const& rhs); + CLASSNAME& operator-=(TRANSNAME const& rhs); CLASSNAME& operator*=(CLASSNAME const& rhs); + CLASSNAME& operator*=(CLASSCONSTNAME const& rhs); + CLASSNAME& operator*=(SOURCENAME const& rhs); + CLASSNAME& operator*=(TRANSNAME const& rhs); CLASSNAME& operator*=(TYPE d); CLASSNAME& operator/=(TYPE d); #endif diff --git a/include/rotgen/container/block/fixed.hpp b/include/rotgen/container/block/fixed.hpp index 9263b2d..5cb1e3d 100644 --- a/include/rotgen/container/block/fixed.hpp +++ b/include/rotgen/container/block/fixed.hpp @@ -193,11 +193,7 @@ namespace rotgen parent const& base() const { return static_cast(*this); } - auto evaluate() const - { - auto res = base().eval(); - return as_concrete_type(res); - } + auto evaluate() const { return concrete_type(base().eval()); } decltype(auto) noalias() const { @@ -298,11 +294,11 @@ namespace rotgen static concrete_type Constant(int rows, int cols, value_type value) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Constant(rows, cols, static_cast(value)); } @@ -315,11 +311,11 @@ namespace rotgen static concrete_type Identity(int rows, int cols) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Identity(rows, cols); } @@ -332,11 +328,11 @@ namespace rotgen static concrete_type Zero(int rows, int cols) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Zero(rows, cols); } @@ -349,11 +345,11 @@ namespace rotgen static concrete_type Ones(int rows, int cols) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Ones(rows, cols); } @@ -366,11 +362,11 @@ namespace rotgen static concrete_type Random(int rows, int cols) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Random(rows, cols); } @@ -523,17 +519,19 @@ namespace rotgen Index startCol() const { return base().startCol(); } - block& operator+=(block const& rhs) + template + block& operator+=(E const& rhs) requires(!is_immutable) { - static_cast(*this) += static_cast(rhs); + base() += rhs.base(); return *this; } - block& operator-=(block const& rhs) + template + block& operator-=(E const& rhs) requires(!is_immutable) { - static_cast(*this) -= static_cast(rhs); + base() -= rhs.base(); return *this; } @@ -542,10 +540,11 @@ namespace rotgen return concrete_type(static_cast(*this).operator-()); } - block& operator*=(block const& rhs) + template + block& operator*=(E const& rhs) requires(!is_immutable) { - static_cast(*this) *= static_cast(rhs); + base() *= rhs.base(); return *this; } diff --git a/include/rotgen/container/map/dynamic.hpp b/include/rotgen/container/map/dynamic.hpp index e9c7b55..f87e261 100644 --- a/include/rotgen/container/map/dynamic.hpp +++ b/include/rotgen/container/map/dynamic.hpp @@ -7,14 +7,14 @@ //================================================================================================== #pragma once +#include #include +#include #include #include #include -#include - namespace rotgen { namespace detail @@ -35,6 +35,7 @@ namespace rotgen using parent = find_map; using rotgen_tag = void; + using rotgen_map_tag = void; using value_type = typename std::remove_const_t::value_type; using concrete_type = typename std::remove_const_t::concrete_type; @@ -52,6 +53,7 @@ namespace rotgen static constexpr int ColsAtCompileTime = Ref::ColsAtCompileTime; static constexpr int MaxRowsAtCompileTime = Ref::MaxRowsAtCompileTime; static constexpr int MaxColsAtCompileTime = Ref::MaxColsAtCompileTime; + static constexpr int SizeAtCompileTime = Ref::SizeAtCompileTime; static constexpr bool IsVectorAtCompileTime = Ref::IsVectorAtCompileTime; static constexpr bool is_defined_static = RowsAtCompileTime != -1 && ColsAtCompileTime != -1; @@ -75,12 +77,12 @@ namespace rotgen : parent(ptr, r, c, strides(s, r, c)) { if constexpr (RowsAtCompileTime != -1) - assert(r == RowsAtCompileTime && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(r == RowsAtCompileTime, + "Mismatched between dynamic and static row size"); if constexpr (ColsAtCompileTime != -1) - assert(c == ColsAtCompileTime && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(c == ColsAtCompileTime, + "Mismatched between dynamic and static column size"); } // Used to properly delay ref dynamic construction @@ -135,24 +137,50 @@ namespace rotgen map(map const& other) : parent(other) {} + map(map&& other) : parent(std::move(other)) {} + map& operator=(map const& other) requires(!is_immutable) { - base() = static_cast(other); + base() = other.base(); return *this; } - map& operator=(concepts::entity auto const& other) + map& operator=(map&& other) requires(!is_immutable) { - assert(parent::rows() == other.rows() && parent::cols() == other.cols()); - if constexpr (IsVectorAtCompileTime) + base() = std::move(other.base()); + return *this; + } + + template + map& operator=(Src const& other) + requires(!is_immutable) + { + if constexpr (IsVectorAtCompileTime && Src::IsVectorAtCompileTime) { + ROTGEN_ASSERT(parent::size() == other.size(), + "Map assignment from 1D source doesn't match size"); for (rotgen::Index i = 0; i < parent::size(); ++i) - (*this)(i) = other(i); + (*this)[i] = other[i]; + } + else if constexpr (IsVectorAtCompileTime && !Src::IsVectorAtCompileTime) + { + auto r = other.rows(); + auto c = other.cols(); + + ROTGEN_ASSERT( + (r == 1 || c == 1), + "Map assignment from dynamic sized source doesn't match static size"); + + for (rotgen::Index i = 0; i < parent::size(); ++i) + (*this)[i] = other(r == 1 ? 0 : i, c == 1 ? 0 : i); } else { + ROTGEN_ASSERT(parent::rows() == other.rows() && + parent::cols() == other.cols(), + "Map assignment size mismatch"); for (rotgen::Index r = 0; r < parent::rows(); ++r) for (rotgen::Index c = 0; c < parent::cols(); ++c) (*this)(r, c) = other(r, c); @@ -196,7 +224,7 @@ namespace rotgen return (*this)(i); } - auto evaluate() const { return *this; } + concrete_type evaluate() const { return concrete_type{*this}; } decltype(auto) noalias() const { return *this; } @@ -275,7 +303,7 @@ namespace rotgen concrete_type cross(map const& other) const { concrete_type that; - if constexpr (RowsAtCompileTime == -1) + if constexpr (ColsAtCompileTime == 1) { that(0, 0) = (*this)(1, 0) * other(2, 0) - (*this)(2, 0) * other(1, 0); that(1, 0) = (*this)(2, 0) * other(0, 0) - (*this)(0, 0) * other(2, 0); @@ -312,6 +340,7 @@ namespace rotgen map& operator+=(map const& rhs) requires(!is_immutable) { + validate_compound_operator(rhs); base() += rhs.base(); return *this; } @@ -320,6 +349,7 @@ namespace rotgen map& operator-=(map const& rhs) requires(!is_immutable) { + validate_compound_operator(rhs); base() -= rhs.base(); return *this; } @@ -333,6 +363,9 @@ namespace rotgen map& operator*=(map const& rhs) requires(!is_immutable) { + ROTGEN_ASSERT(parent::cols() == rhs.rows() && + parent::cols() == rhs.cols(), + "Incompatible dimensions for compound matrix-product"); base() *= rhs.base(); return *this; } @@ -464,7 +497,7 @@ namespace rotgen template value_type lpNorm() const { - assert(P == 1 || P == 2 || P == Infinity); + static_assert(P == 1 || P == 2 || P == Infinity); return parent::lpNorm(P); } @@ -476,6 +509,32 @@ namespace rotgen { return concrete_type(base().qr_solve(rhs.base())); }; + + template + void validate_compound_operator(map const& rhs) + { + if constexpr (IsVectorAtCompileTime && R2::IsVectorAtCompileTime) + { + if constexpr (is_defined_static && R2::is_defined_static) + { + static_assert( + SizeAtCompileTime == R2::SizeAtCompileTime, + "Compile-time size mismatch in compound assignment operator"); + } + else + { + ROTGEN_ASSERT(parent::size() == rhs.size(), + "Size mismatch in compound assignment operator"); + } + } + else + { + ROTGEN_ASSERT(parent::rows() == rhs.rows(), + "Mismatched rows count in compound assignment operator"); + ROTGEN_ASSERT(parent::cols() == rhs.cols(), + "Mismatched cols count in compound assignment operator"); + } + } }; template @@ -499,18 +558,20 @@ namespace rotgen } template - matrix - operator*(map const& lhs, map const& rhs) + auto operator*(map const& lhs, map const& rhs) { using map1_type = map; using map2_type = map; using concrete_type = matrix; - return concrete_type(map1_type(lhs).base().mul(map2_type(rhs).base())); + if constexpr (concrete_type::SizeAtCompileTime == 0) return concrete_type{}; + else + { + auto p = concrete_type(map1_type(lhs).base().mul(map2_type(rhs).base())); + if constexpr (concrete_type::SizeAtCompileTime == 1) return product{p}; + else return p; + } } template diff --git a/include/rotgen/container/map/dynamic/indirect.hpp b/include/rotgen/container/map/dynamic/indirect.hpp index 90e0c50..ebd2801 100644 --- a/include/rotgen/container/map/dynamic/indirect.hpp +++ b/include/rotgen/container/map/dynamic/indirect.hpp @@ -3,27 +3,39 @@ #define CLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _col) #define TRANSCLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _row) +#define TRANSCLASSCONSTNAME ROTGEN_MATRIX_NAME(map_const_impl, SIZE, _row) +#define TRANSCLASSNONCONSTNAME ROTGEN_MATRIX_NAME(map_impl, SIZE, _row) #define CLASSCONSTNAME ROTGEN_MATRIX_NAME(map_const_impl, SIZE, _col) +#define CLASSNONCONSTNAME ROTGEN_MATRIX_NAME(map_impl, SIZE, _col) #define TRANSSOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _row) #define SOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _col) #include #undef CLASSNAME #undef TRANSCLASSNAME +#undef TRANSCLASSCONSTNAME +#undef TRANSCLASSNONCONSTNAME #undef TRANSSOURCENAME #undef SOURCENAME #undef CLASSCONSTNAME +#undef CLASSNONCONSTNAME #define CLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _row) #define TRANSCLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _col) +#define TRANSCLASSCONSTNAME ROTGEN_MATRIX_NAME(map_const_impl, SIZE, _col) +#define TRANSCLASSNONCONSTNAME ROTGEN_MATRIX_NAME(map_impl, SIZE, _col) #define CLASSCONSTNAME ROTGEN_MATRIX_NAME(map_const_impl, SIZE, _row) +#define CLASSNONCONSTNAME ROTGEN_MATRIX_NAME(map_impl, SIZE, _row) #define TRANSSOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _col) #define SOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _row) #include #undef CLASSNAME #undef TRANSCLASSNAME +#undef TRANSCLASSCONSTNAME +#undef TRANSCLASSNONCONSTNAME #undef TRANSSOURCENAME #undef SOURCENAME #undef CLASSCONSTNAME +#undef CLASSNONCONSTNAME #undef SIZE #undef TYPE @@ -33,27 +45,39 @@ #define CLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _col) #define TRANSCLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _row) +#define TRANSCLASSCONSTNAME ROTGEN_MATRIX_NAME(map_const_impl, SIZE, _row) +#define TRANSCLASSNONCONSTNAME ROTGEN_MATRIX_NAME(map_impl, SIZE, _row) #define CLASSCONSTNAME ROTGEN_MATRIX_NAME(map_const_impl, SIZE, _col) +#define CLASSNONCONSTNAME ROTGEN_MATRIX_NAME(map_impl, SIZE, _col) #define TRANSSOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _row) #define SOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _col) #include #undef CLASSNAME #undef TRANSCLASSNAME +#undef TRANSCLASSCONSTNAME +#undef TRANSCLASSNONCONSTNAME #undef TRANSSOURCENAME #undef SOURCENAME #undef CLASSCONSTNAME +#undef CLASSNONCONSTNAME #define CLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _row) #define TRANSCLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _col) +#define TRANSCLASSCONSTNAME ROTGEN_MATRIX_NAME(map_const_impl, SIZE, _col) +#define TRANSCLASSNONCONSTNAME ROTGEN_MATRIX_NAME(map_impl, SIZE, _col) #define CLASSCONSTNAME ROTGEN_MATRIX_NAME(map_const_impl, SIZE, _row) +#define CLASSNONCONSTNAME ROTGEN_MATRIX_NAME(map_impl, SIZE, _row) #define TRANSSOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _col) #define SOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _row) #include #undef CLASSNAME #undef TRANSCLASSNAME +#undef TRANSCLASSCONSTNAME +#undef TRANSCLASSNONCONSTNAME #undef TRANSSOURCENAME #undef SOURCENAME #undef CLASSCONSTNAME +#undef CLASSNONCONSTNAME #undef SIZE #undef TYPE diff --git a/include/rotgen/container/map/dynamic/model.hpp b/include/rotgen/container/map/dynamic/model.hpp index 7790efd..a8bb0f3 100644 --- a/include/rotgen/container/map/dynamic/model.hpp +++ b/include/rotgen/container/map/dynamic/model.hpp @@ -66,8 +66,10 @@ public: TYPE minCoeff() const; TYPE maxCoeff(Index*, Index*) const; TYPE minCoeff(Index*, Index*) const; - TYPE dot(CLASSNAME const&) const; - TYPE dot(TRANSCLASSNAME const&) const; + TYPE dot(CLASSNONCONSTNAME const&) const; + TYPE dot(CLASSCONSTNAME const&) const; + TYPE dot(TRANSCLASSCONSTNAME const&) const; + TYPE dot(TRANSCLASSNONCONSTNAME const&) const; TYPE squaredNorm() const; TYPE norm() const; @@ -86,10 +88,16 @@ public: #if !defined(USE_CONST) CLASSNAME& operator+=(CLASSNAME const& rhs); CLASSNAME& operator+=(CLASSCONSTNAME const& rhs); + CLASSNAME& operator+=(TRANSCLASSCONSTNAME const& rhs); + CLASSNAME& operator+=(TRANSCLASSNONCONSTNAME const& rhs); CLASSNAME& operator-=(CLASSNAME const& rhs); CLASSNAME& operator-=(CLASSCONSTNAME const& rhs); + CLASSNAME& operator-=(TRANSCLASSCONSTNAME const& rhs); + CLASSNAME& operator-=(TRANSCLASSNONCONSTNAME const& rhs); CLASSNAME& operator*=(CLASSNAME const& rhs); CLASSNAME& operator*=(CLASSCONSTNAME const& rhs); + CLASSNAME& operator*=(TRANSCLASSCONSTNAME const& rhs); + CLASSNAME& operator*=(TRANSCLASSNONCONSTNAME const& rhs); CLASSNAME& operator*=(TYPE d); CLASSNAME& operator/=(TYPE d); #endif diff --git a/include/rotgen/container/map/fixed.hpp b/include/rotgen/container/map/fixed.hpp index 70af3a2..9997a59 100644 --- a/include/rotgen/container/map/fixed.hpp +++ b/include/rotgen/container/map/fixed.hpp @@ -7,6 +7,8 @@ //================================================================================================== #pragma once +#include + #include #include @@ -41,6 +43,7 @@ namespace rotgen { public: using rotgen_tag = void; + using rotgen_map_tag = void; using parent = detail:: map_type, Options, std::is_const_v, Stride>; using value_type = typename std::remove_const_t::value_type; @@ -50,6 +53,7 @@ namespace rotgen static constexpr int ColsAtCompileTime = Ref::ColsAtCompileTime; static constexpr int MaxRowsAtCompileTime = Ref::MaxRowsAtCompileTime; static constexpr int MaxColsAtCompileTime = Ref::MaxColsAtCompileTime; + static constexpr int SizeAtCompileTime = Ref::SizeAtCompileTime; static constexpr bool IsVectorAtCompileTime = Ref::IsVectorAtCompileTime; static constexpr bool has_static_storage = Ref::has_static_storage; static constexpr bool IsRowMajor = Ref::IsRowMajor; @@ -112,21 +116,25 @@ namespace rotgen return *this; } + template + map& operator=(Eigen::MatrixBase const& other) + { + parent::operator=(other); + return *this; + } + + template + map& operator=(Eigen::EigenBase const& other) + { + parent::operator=(other); + return *this; + } + parent& base() { return static_cast(*this); } parent const& base() const { return static_cast(*this); } - auto evaluate() const - { - auto res = static_cast(*this).eval(); - return as_concrete_type(res); - } - - auto evaluate() - { - auto res = static_cast(*this).eval(); - return as_concrete_type(res); - } + auto evaluate() const { return concrete_type(base().eval()); } decltype(auto) noalias() const { @@ -176,6 +184,7 @@ namespace rotgen map& operator+=(map const& rhs) requires(!is_immutable) { + validate_compound_operator(rhs); base() += rhs.base(); return *this; } @@ -184,6 +193,7 @@ namespace rotgen map& operator-=(map const& rhs) requires(!is_immutable) { + validate_compound_operator(rhs); base() -= rhs.base(); return *this; } @@ -192,6 +202,9 @@ namespace rotgen map& operator*=(map const& rhs) requires(!is_immutable) { + ROTGEN_ASSERT(parent::cols() == rhs.rows() && + parent::cols() == rhs.cols(), + "Incompatible dimensions for compound matrix-product"); base() *= rhs.base(); return *this; } @@ -463,6 +476,32 @@ namespace rotgen static_assert(P == 1 || P == 2 || P == Infinity); return parent::template lpNorm

(); } + + template + void validate_compound_operator(map const& rhs) + { + if constexpr (IsVectorAtCompileTime && R2::IsVectorAtCompileTime) + { + if constexpr (is_defined_static && R2::is_defined_static) + { + static_assert( + SizeAtCompileTime == R2::SizeAtCompileTime, + "Compile-time size mismatch in compound assignment operator"); + } + else + { + ROTGEN_ASSERT(parent::size() == rhs.size(), + "Size mismatch in compound assignment operator"); + } + } + else + { + ROTGEN_ASSERT(parent::rows() == rhs.rows(), + "Mismatched rows count in compound assignment operator"); + ROTGEN_ASSERT(parent::cols() == rhs.cols(), + "Mismatched cols count in compound assignment operator"); + } + } }; template @@ -534,13 +573,14 @@ namespace rotgen } template - matrix - operator*(map const& lhs, map const& rhs) + auto operator*(map const& lhs, map const& rhs) { - using concrete_type = matrix; + auto p = lhs.base() * rhs.base(); + using concrete_type = detail::as_concrete_t; - return concrete_type(lhs.base() * rhs.base()); + if constexpr (concrete_type::SizeAtCompileTime == 1) + return product{concrete_type{p}}; + else return concrete_type{p}; } template diff --git a/include/rotgen/container/matrix/dynamic.hpp b/include/rotgen/container/matrix/dynamic.hpp index ae0d1a6..14e66d3 100644 --- a/include/rotgen/container/matrix/dynamic.hpp +++ b/include/rotgen/container/matrix/dynamic.hpp @@ -7,11 +7,12 @@ //================================================================================================== #pragma once +#include #include + #include #include -#include #include namespace rotgen @@ -27,6 +28,7 @@ namespace rotgen public: using parent = find_matrix; using rotgen_tag = void; + using rotgen_matrix_tag = void; using concrete_type = matrix; using value_type = Scalar; @@ -48,29 +50,40 @@ namespace rotgen static constexpr bool has_static_storage = false; static constexpr bool is_immutable = false; static constexpr int InnerStrideAtCompileTime = 1; - static constexpr int OuterStrideAtCompileTime = - IsRowMajor ? ColsAtCompileTime : RowsAtCompileTime; + static constexpr int OuterStrideAtCompileTime = IsRowMajor ? Cols : Rows; using transposed_type = matrix; - matrix() : parent(Rows == -1 ? 0 : Rows, Cols == -1 ? 0 : Cols) {} + static constexpr int AllocatedRows = + Rows == -1 ? (MaxRows == -1 ? 0 : MaxRows) : Rows; + static constexpr int AllocatedCols = + Cols == -1 ? (MaxCols == -1 ? 0 : MaxCols) : Cols; + + matrix() : parent(AllocatedRows, AllocatedCols) {} matrix(Index r, Index c) : parent(r, c) { if constexpr (Rows != -1) - assert(r == Rows && "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(r == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(c == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(c == Cols, + "Mismatched between dynamic and static column size"); } matrix(Index n) requires(IsVectorAtCompileTime && (Rows != 1 || Cols != 1)) - : parent(Rows != -1 ? 1 : n, Cols != -1 ? 1 : n) + : parent(Rows != 1 ? n : 1, Cols != 1 ? n : 1) { + if constexpr (Rows == 1 && Cols != -1) + ROTGEN_ASSERT(Cols == n, + "Mismatched between dynamic and static col size"); + if constexpr (Cols == 1 && Rows != -1) + ROTGEN_ASSERT(Rows == n, + "Mismatched between dynamic and static row size"); } - matrix(Scalar v) + explicit matrix(Scalar v) requires(Rows == 1 && Cols == 1) : parent(1, 1, {v}) { @@ -89,14 +102,14 @@ namespace rotgen : parent(init) { if constexpr (Rows != -1) - assert(init.size() == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(init.size() == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) { [[maybe_unused]] Index c = 0; if (init.size()) c = init.begin()->size(); - assert(c == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(c == Cols, + "Mismatched between dynamic and static column size"); } } @@ -106,31 +119,28 @@ namespace rotgen { } - matrix(concepts::entity auto const& e) : parent(e.rows(), e.cols()) + template matrix(Src const& other) : parent() { - if constexpr (Rows != -1) - assert(e.rows() == Rows && - "Mismatched between dynamic and static row size"); - if constexpr (Cols != -1) - assert(e.cols() == Cols && - "Mismatched between dynamic and static col size"); - for (rotgen::Index r = 0; r < e.rows(); ++r) - for (rotgen::Index c = 0; c < e.cols(); ++c) (*this)(r, c) = e(r, c); + if constexpr (IsVectorAtCompileTime && Src::IsVectorAtCompileTime) + { + resize(other.size()); + for (rotgen::Index i = 0; i < parent::size(); ++i) + (*this)[i] = other[i]; + } + else + { + resize(other.rows(), other.cols()); + + for (rotgen::Index r = 0; r < parent::rows(); ++r) + for (rotgen::Index c = 0; c < parent::cols(); ++c) + (*this)(r, c) = other(r, c); + } } - matrix& operator=(concepts::entity auto const& e) + matrix& operator=(concepts::entity auto const& other) { - if constexpr (Rows != -1) - assert(e.rows() == Rows && - "Mismatched between dynamic and static row size"); - if constexpr (Cols != -1) - assert(e.cols() == Cols && - "Mismatched between dynamic and static col size"); - resize(e.rows(), e.cols()); - - for (rotgen::Index r = 0; r < e.rows(); ++r) - for (rotgen::Index c = 0; c < e.cols(); ++c) (*this)(r, c) = e(r, c); - + matrix local(other); + swap(local); return *this; } @@ -146,6 +156,13 @@ namespace rotgen return (*this)(i); } + void swap(matrix& other) + { + // TODO: Swap elements per elements if statically defined to preserve + // data location in memory as with actual statically defines matrix + base().swap(other.base()); + } + auto evaluate() const { return *this; } decltype(auto) noalias() const { return *this; } @@ -156,17 +173,22 @@ namespace rotgen Index outerStride() const noexcept { - return IsVectorAtCompileTime ? this->size() - : IsRowMajor ? this->cols() - : this->rows(); + if constexpr (IsVectorAtCompileTime) return this->size(); + else + { + if constexpr (IsRowMajor) return this->cols(); + else return this->rows(); + } } void resize(int r, int c) { - if constexpr (Rows == 1) - assert(c == Cols && "Mismatched between dynamic and static col size"); - if constexpr (Cols == 1) - assert(r == Rows && "Mismatched between dynamic and static row size"); + if constexpr (Cols != -1) + ROTGEN_ASSERT(c == Cols, + "Mismatched between dynamic and static col size"); + if constexpr (Rows != -1) + ROTGEN_ASSERT(r == Rows, + "Mismatched between dynamic and static row size"); parent::resize(r, c); } @@ -179,10 +201,12 @@ namespace rotgen void conservativeResize(int r, int c) { - if constexpr (Rows == 1) - assert(c == Cols && "Mismatched between dynamic and static col size"); - if constexpr (Cols == 1) - assert(r == Rows && "Mismatched between dynamic and static row size"); + if constexpr (Cols != -1) + ROTGEN_ASSERT(c == Cols, + "Mismatched between dynamic and static col size"); + if constexpr (Rows != -1) + ROTGEN_ASSERT(r == Rows, + "Mismatched between dynamic and static row size"); parent::conservativeResize(r, c); } @@ -231,26 +255,26 @@ namespace rotgen friend bool operator==(matrix const& lhs, matrix const& rhs) { - return static_cast(lhs) == static_cast(rhs); + return lhs.base() == rhs.base(); } matrix& operator+=(matrix const& rhs) { - base() += static_cast(rhs); + base() += rhs.base(); return *this; } matrix& operator-=(matrix const& rhs) { - base() -= static_cast(rhs); + base() -= rhs.base(); return *this; } - matrix operator-() const { return matrix(base().operator-()); } + matrix operator-() const { return -base(); } matrix& operator*=(matrix const& rhs) { - base() *= static_cast(rhs); + base() *= rhs.base(); return *this; } @@ -299,11 +323,11 @@ namespace rotgen static matrix Ones(int rows, int cols) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Ones(rows, cols); } @@ -316,11 +340,11 @@ namespace rotgen static matrix Zero(int rows, int cols) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Zero(rows, cols); } @@ -333,11 +357,11 @@ namespace rotgen static matrix Constant(int rows, int cols, Scalar value) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Constant(rows, cols, static_cast(value)); } @@ -350,11 +374,11 @@ namespace rotgen static matrix Random(int rows, int cols) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Random(rows, cols); } @@ -367,11 +391,11 @@ namespace rotgen static matrix Identity(int rows, int cols) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Identity(rows, cols); } @@ -446,51 +470,4 @@ namespace rotgen parent const& base() const { return static_cast(*this); } }; - - template - matrix operator+(matrix const& lhs, - matrix const& rhs) - { - matrix that(lhs); - return that += rhs; - } - - template - matrix operator-(matrix const& lhs, - matrix const& rhs) - { - matrix that(lhs); - return that -= rhs; - } - - template - matrix operator*(matrix const& lhs, - matrix const& rhs) - { - matrix that(lhs); - return that *= rhs; - } - - template - matrix operator*(matrix const& lhs, - double rhs) - { - matrix that(lhs); - return that *= rhs; - } - - template - matrix operator*(double lhs, - matrix const& rhs) - { - return rhs * lhs; - } - - template - matrix operator/(matrix const& lhs, - double rhs) - { - matrix that(lhs); - return that /= rhs; - } } diff --git a/include/rotgen/container/matrix/dynamic/model.hpp b/include/rotgen/container/matrix/dynamic/model.hpp index 2520239..df5c427 100644 --- a/include/rotgen/container/matrix/dynamic/model.hpp +++ b/include/rotgen/container/matrix/dynamic/model.hpp @@ -99,6 +99,8 @@ public: void setRandom(Index rows, Index cols); void setIdentity(Index rows, Index cols); + void swap(CLASSNAME& other) { storage_.swap(other.storage_); } + private: struct payload; std::unique_ptr storage_; diff --git a/include/rotgen/container/matrix/fixed.hpp b/include/rotgen/container/matrix/fixed.hpp index 6c11b67..7b8981a 100644 --- a/include/rotgen/container/matrix/fixed.hpp +++ b/include/rotgen/container/matrix/fixed.hpp @@ -7,6 +7,7 @@ //================================================================================================== #pragma once +#include #include #include @@ -40,6 +41,7 @@ namespace rotgen { public: using rotgen_tag = void; + using rotgen_matrix_tag = void; using parent = detail::storage_type; using value_type = Scalar; @@ -49,6 +51,9 @@ namespace rotgen using concrete_type = matrix; using concrete_dynamic_type = matrix; + using exact_base = + Eigen::Matrix; + static constexpr int RowsAtCompileTime = Rows; static constexpr int ColsAtCompileTime = Cols; static constexpr int SizeAtCompileTime = detail::static_size(); @@ -75,6 +80,11 @@ namespace rotgen static constexpr bool has_static_storage = storage_status; + static constexpr int AllocatedRows = + Rows == -1 ? (MaxRows == -1 ? 0 : MaxRows) : Rows; + static constexpr int AllocatedCols = + Cols == -1 ? (MaxCols == -1 ? 0 : MaxCols) : Cols; + public: matrix() requires(has_static_storage) @@ -83,19 +93,19 @@ namespace rotgen matrix() requires(!has_static_storage) - : parent(Rows > 0 ? Rows : 0, Cols > 0 ? Cols : 0) + : parent(AllocatedRows, AllocatedCols) { } matrix(Index r, Index c) : parent(r, c) { if constexpr (RowsAtCompileTime != -1) - assert(r == RowsAtCompileTime && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(r == RowsAtCompileTime, + "Mismatched between dynamic and static row size"); if constexpr (ColsAtCompileTime != -1) - assert(c == ColsAtCompileTime && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(c == ColsAtCompileTime, + "Mismatched between dynamic and static column size"); } matrix(matrix const& other) = default; @@ -181,11 +191,7 @@ namespace rotgen parent const& base() const { return static_cast(*this); } - auto evaluate() const - { - auto res = base().eval(); - return as_concrete_type(res); - } + auto evaluate() const { return *this; } decltype(auto) noalias() const { @@ -273,38 +279,36 @@ namespace rotgen void resize(int s) requires(IsVectorAtCompileTime) { - if constexpr (Rows == 1) - assert(s == Cols && "Mismatched between dynamic and static col size"); - if constexpr (Cols == 1) - assert(s == Rows && "Mismatched between dynamic and static row size"); - parent::resize(s); + if constexpr (Rows == 1) parent::resize(1, s); + else parent::resize(s, 1); } void resize(int r, int c) { - if constexpr (Rows == 1) - assert(c == Cols && "Mismatched between dynamic and static col size"); - if constexpr (Cols == 1) - assert(r == Rows && "Mismatched between dynamic and static row size"); + if constexpr (Cols != -1) + ROTGEN_ASSERT(c == Cols, + "Mismatched between dynamic and static col size"); + if constexpr (Rows != -1) + ROTGEN_ASSERT(r == Rows, + "Mismatched between dynamic and static row size"); parent::resize(r, c); } void conservativeResize(int s) requires(IsVectorAtCompileTime) { - if constexpr (Rows == 1) - assert(s == Cols && "Mismatched between dynamic and static col size"); - if constexpr (Cols == 1) - assert(s == Rows && "Mismatched between dynamic and static row size"); - parent::conservativeResize(s); + if constexpr (Rows == 1) parent::conservativeResize(1, s); + else parent::conservativeResize(s, 1); } void conservativeResize(int r, int c) { - if constexpr (Rows == 1) - assert(c == Cols && "Mismatched between dynamic and static col size"); - if constexpr (Cols == 1) - assert(r == Rows && "Mismatched between dynamic and static row size"); + if constexpr (Cols != -1) + ROTGEN_ASSERT(c == Cols, + "Mismatched between dynamic and static col size"); + if constexpr (Rows != -1) + ROTGEN_ASSERT(r == Rows, + "Mismatched between dynamic and static row size"); parent::conservativeResize(r, c); } @@ -317,11 +321,11 @@ namespace rotgen static matrix Constant(int rows, int cols, Scalar value) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Constant(rows, cols, static_cast(value)); } @@ -334,11 +338,11 @@ namespace rotgen static matrix Identity(int rows, int cols) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Identity(rows, cols); } @@ -351,11 +355,11 @@ namespace rotgen static matrix Ones(int rows, int cols) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Ones(rows, cols); } @@ -368,11 +372,11 @@ namespace rotgen static matrix Zero(int rows, int cols) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Zero(rows, cols); } @@ -385,11 +389,11 @@ namespace rotgen static matrix Random(int rows, int cols) { if constexpr (Rows != -1) - assert(rows == Rows && - "Mismatched between dynamic and static row size"); + ROTGEN_ASSERT(rows == Rows, + "Mismatched between dynamic and static row size"); if constexpr (Cols != -1) - assert(cols == Cols && - "Mismatched between dynamic and static column size"); + ROTGEN_ASSERT(cols == Cols, + "Mismatched between dynamic and static column size"); return parent::Random(rows, cols); } @@ -519,7 +523,7 @@ namespace rotgen return *this; } - matrix operator-() const { return matrix(base()(*this).operator-()); } + matrix operator-() const { return -base(); } matrix& operator*=(matrix const& rhs) { diff --git a/include/rotgen/container/ref.hpp b/include/rotgen/container/ref.hpp index 2d99197..001dbbb 100644 --- a/include/rotgen/container/ref.hpp +++ b/include/rotgen/container/ref.hpp @@ -7,12 +7,13 @@ //================================================================================================== #pragma once +#include + #include #include #include #include -#include #include #include diff --git a/include/rotgen/container/ref/dynamic.hpp b/include/rotgen/container/ref/dynamic.hpp index 015e587..6be959a 100644 --- a/include/rotgen/container/ref/dynamic.hpp +++ b/include/rotgen/container/ref/dynamic.hpp @@ -7,7 +7,8 @@ //================================================================================================== #pragma once -#include +#include + #include #if !defined(ROTGEN_FORCE_DYNAMIC) @@ -15,6 +16,7 @@ #endif #include +#include namespace rotgen { @@ -149,6 +151,7 @@ namespace rotgen using parent::Zero; using parent::operator=; + using parent::operator-; template auto qr_solve(ref rhs) const @@ -166,12 +169,18 @@ namespace rotgen parent& as_map() { return base(); } + template + requires(is_immutable) + ref(product const& m) : ref(m.storage_) + { + } + template requires(detail::accept_as_ref>) ref(matrix& m) : parent(detail::postpone{}) { [[maybe_unused]] bool correct_ref_setup = detail::validate_ref(*this, m); - assert(correct_ref_setup); + ROTGEN_ASSERT(correct_ref_setup, "Invalid reference binding"); } template @@ -202,7 +211,7 @@ namespace rotgen ref(block&& b) : parent(detail::postpone{}) { [[maybe_unused]] bool correct_ref_setup = detail::validate_ref(*this, b); - assert(correct_ref_setup); + ROTGEN_ASSERT(correct_ref_setup, "Invalid reference binding"); } template @@ -210,7 +219,7 @@ namespace rotgen ref(block& b) : parent(detail::postpone{}) { [[maybe_unused]] bool correct_ref_setup = detail::validate_ref(*this, b); - assert(correct_ref_setup); + ROTGEN_ASSERT(correct_ref_setup, "Invalid reference binding"); } template @@ -229,7 +238,7 @@ namespace rotgen ref(map& b) : parent(detail::postpone{}) { [[maybe_unused]] bool correct_ref_setup = detail::validate_ref(*this, b); - assert(correct_ref_setup); + ROTGEN_ASSERT(correct_ref_setup, "Invalid reference binding"); } template @@ -259,15 +268,15 @@ namespace rotgen ref(ref& b) : parent(detail::postpone{}) { [[maybe_unused]] bool correct_ref_setup = detail::validate_ref(*this, b); - assert(correct_ref_setup); + ROTGEN_ASSERT(correct_ref_setup, "Invalid reference binding"); } template - requires(detail::same_scalar>) + requires(detail::accept_as_ref>) ref(ref const& b) : parent(detail::postpone{}) { [[maybe_unused]] bool correct_ref_setup = detail::validate_ref(*this, b); - assert(correct_ref_setup); + ROTGEN_ASSERT(correct_ref_setup, "Invalid reference binding"); } ref(parent& m) : parent(m.data(), m.rows(), m.cols()) {} diff --git a/include/rotgen/container/ref/fixed.hpp b/include/rotgen/container/ref/fixed.hpp index 3d5521d..a49c0cf 100644 --- a/include/rotgen/container/ref/fixed.hpp +++ b/include/rotgen/container/ref/fixed.hpp @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include @@ -52,6 +53,9 @@ namespace rotgen template using compile_ref_t = typename compile_ref::type; + + template + using compile_base_t = typename compile_ref::base; } template @@ -59,6 +63,7 @@ namespace rotgen { public: using parent = detail::compile_ref_t; + using exact_base = detail::compile_base_t; using referee = std::remove_const_t; using value_type = typename referee::value_type; using rotgen_tag = void; @@ -91,17 +96,7 @@ namespace rotgen using parent::size; // Aliasing handling - auto evaluate() const - { - auto res = static_cast(*this).eval(); - return as_concrete_type(res); - } - - auto evaluate() - { - auto res = static_cast(*this).eval(); - return as_concrete_type(res); - } + auto evaluate() const { return T(base().eval()); } decltype(auto) noalias() const { @@ -116,10 +111,27 @@ namespace rotgen } // Numeric functions - using parent::cwiseAbs; - using parent::cwiseAbs2; - using parent::cwiseInverse; - using parent::cwiseSqrt; + auto operator-() const { return detail::concretize(-base()); } + + auto cwiseAbs() const + { + return detail::concretize(base().cwiseAbs()); + } + + auto cwiseAbs2() const + { + return detail::concretize(base().cwiseAbs2()); + } + + auto cwiseInverse() const + { + return detail::concretize(base().cwiseInverse()); + } + + auto cwiseSqrt() const + { + return detail::concretize(base().cwiseSqrt()); + } // Reductions using parent::lpNorm; @@ -172,10 +184,26 @@ namespace rotgen } // Shape modifications - using parent::adjoint; - using parent::conjugate; - using parent::normalized; - using parent::transpose; + auto normalized() const + requires(IsVectorAtCompileTime) + { + return detail::concretize(base().normalized()); + } + + auto transpose() const + { + return detail::concretize(base().transpose()); + } + + auto adjoint() const + { + return detail::concretize(base().adjoint()); + } + + auto conjugate() const + { + return detail::concretize(base().conjugate()); + } // In-place Shape modifications using parent::adjointInPlace; @@ -183,16 +211,79 @@ namespace rotgen using parent::transposeInPlace; // Generators - using parent::Constant; - using parent::Identity; - using parent::Ones; - using parent::Random; - using parent::setConstant; - using parent::setIdentity; - using parent::setOnes; - using parent::setRandom; - using parent::setZero; - using parent::Zero; + static auto Zero() { return detail::concretize(parent::Zero()); } + + static auto Zero(int rows, int cols) + { + return detail::concretize(parent::Zero(rows, cols)); + } + + static auto Ones() { return detail::concretize(parent::Ones()); } + + static auto Ones(int rows, int cols) + { + return detail::concretize(parent::Ones(rows, cols)); + } + + static auto Constant(value_type value) + { + return detail::concretize(parent::Constant(value)); + } + + static auto Constant(int rows, int cols, value_type value) + { + return detail::concretize(parent::Constant(rows, cols, value)); + } + + static auto Random() + { + return detail::concretize(parent::Random()); + } + + static auto Random(int rows, int cols) + { + return detail::concretize(parent::Random(rows, cols)); + } + + static auto Identity() + { + return detail::concretize(parent::Identity()); + } + + static auto Identity(int rows, int cols) + { + return detail::concretize(parent::Identity(rows, cols)); + } + + ref& setOnes() + { + base() = parent::Ones(base().rows(), base().cols()); + return *this; + } + + ref& setZero() + { + base() = parent::Zero(base().rows(), base().cols()); + return *this; + } + + ref& setConstant(value_type value) + { + base() = parent::Constant(base().rows(), base().cols(), value); + return *this; + } + + ref& setRandom() + { + base() = parent::Random(base().rows(), base().cols()); + return *this; + } + + ref& setIdentity() + { + base() = parent::Identity(base().rows(), base().cols()); + return *this; + } auto qr_solve(auto const& rhs) const { @@ -212,6 +303,12 @@ namespace rotgen parent& base() { return static_cast(*this); } + template + requires(is_immutable) + ref(product const& m) : ref(m.storage_) + { + } + template ref(matrix& m) requires(requires { parent(m.base()); }) @@ -307,13 +404,14 @@ namespace rotgen // Deduction Guides //============================================================================ template - ref(matrix&) -> ref>; + ref(matrix&) -> ref>; template ref(block& b) -> ref; template - ref(matrix const&) -> ref const>; + ref(matrix const&) + -> ref const>; template ref(block const& b) -> ref; @@ -336,7 +434,14 @@ namespace rotgen template auto operator*(ref lhs, ref rhs) { - return detail::concretize(lhs.base() * rhs.base()); + auto p = lhs.base() * rhs.base(); + using concrete_type = detail::as_concrete_t; + + if constexpr (concrete_type::SizeAtCompileTime == 1) + return product{concrete_type{p}}; + else if constexpr (concrete_type::SizeAtCompileTime == 0) + return concrete_type{}; + else return concrete_type{p}; } template @@ -411,7 +516,6 @@ namespace rotgen template auto cross(ref lhs, ref rhs) - -> decltype(lhs.base().cross(rhs.base())) { return detail::concretize(lhs.base().cross(rhs.base())); } @@ -432,13 +536,8 @@ namespace rotgen using type = std::conditional_t, ref>; }; - template auto const& base_of(T const& a) + template decltype(auto) base_of(T&& a) { - return a; - } - - template auto& base_of(T& a) - { - return a; + return ROTGEN_FWD(a); } } diff --git a/include/rotgen/container/ref/functions.hpp b/include/rotgen/container/ref/functions.hpp index 4fa278d..eb62875 100644 --- a/include/rotgen/container/ref/functions.hpp +++ b/include/rotgen/container/ref/functions.hpp @@ -12,46 +12,49 @@ namespace rotgen { template - bool operator==(ref lhs, ref rhs) + bool operator==(ref const& lhs, ref const& rhs) { return lhs.base() == rhs.base(); } template - bool operator!=(ref lhs, ref rhs) + bool operator!=(ref const& lhs, ref const& rhs) { return lhs.base() != rhs.base(); } template auto operator*(std::convertible_to auto s, - ref rhs) + ref const& rhs) { + // void* _ = rhs; return rhs * s; } template - auto dot(ref lhs, ref rhs) + auto dot(ref const& lhs, ref const& rhs) { return lhs.base().dot(rhs.base()); } template - auto mul(ref lhs, std::convertible_to auto s) + auto mul(ref const& lhs, + std::convertible_to auto s) -> decltype(lhs * s) { return lhs * s; } template - auto mul(std::convertible_to auto s, ref rhs) - -> decltype(s * rhs) + auto mul(std::convertible_to auto s, + ref const& rhs) -> decltype(s * rhs) { return s * rhs; } template - auto div(ref lhs, std::convertible_to auto s) + auto div(ref const& lhs, + std::convertible_to auto s) -> decltype(lhs / s) { return lhs / s; diff --git a/include/rotgen/container/ref/generalize.hpp b/include/rotgen/container/ref/generalize.hpp index b5ad520..93a24bb 100644 --- a/include/rotgen/container/ref/generalize.hpp +++ b/include/rotgen/container/ref/generalize.hpp @@ -32,11 +32,12 @@ namespace rotgen template struct generalize { - static constexpr bool is_const = std::is_const_v; - using base = matrix; + static constexpr bool is_const = + std::is_const_v>; + using base = matrix::value_type, + std::remove_cvref_t::RowsAtCompileTime, + std::remove_cvref_t::ColsAtCompileTime, + std::remove_cvref_t::storage_order>; using type = std::conditional_t, ref>; }; diff --git a/include/rotgen/detail/accept_as_ref.hpp b/include/rotgen/detail/accept_as_ref.hpp index 0d71a22..b738224 100644 --- a/include/rotgen/detail/accept_as_ref.hpp +++ b/include/rotgen/detail/accept_as_ref.hpp @@ -7,7 +7,8 @@ //================================================================================================== #pragma once -#include +#include + #include namespace rotgen::detail @@ -77,22 +78,26 @@ namespace rotgen::detail if (Ref::RowsAtCompileTime == 1) { - assert(in.rows() == 1 || in.cols() == 1); + ROTGEN_ASSERT(in.rows() == 1 || in.cols() == 1, + "Incompatible rows/cols in ref binding"); rows = 1; cols = in.size(); } else if (Ref::ColsAtCompileTime == 1) { - assert(in.rows() == 1 || in.cols() == 1); + ROTGEN_ASSERT(in.rows() == 1 || in.cols() == 1, + "Incompatible rows/cols in ref binding"); rows = in.size(); cols = 1; } // Verify that the sizes are valid. - assert((Ref::RowsAtCompileTime == Dynamic) || - (Ref::RowsAtCompileTime == rows)); - assert((Ref::ColsAtCompileTime == Dynamic) || - (Ref::ColsAtCompileTime == cols)); + ROTGEN_ASSERT((Ref::RowsAtCompileTime == Dynamic) || + (Ref::RowsAtCompileTime == rows), + "Incompatible static rows/cols in ref binding"); + ROTGEN_ASSERT((Ref::ColsAtCompileTime == Dynamic) || + (Ref::ColsAtCompileTime == cols), + "Incompatible static rows/cols in ref binding"); // Swap stride if we are a vector and we changed rows as such bool transpose = Ref::IsVectorAtCompileTime && (rows != in.rows()); diff --git a/include/rotgen/detail/assert.hpp b/include/rotgen/detail/assert.hpp new file mode 100644 index 0000000..7469c98 --- /dev/null +++ b/include/rotgen/detail/assert.hpp @@ -0,0 +1,22 @@ +//================================================================================================== +/* + ROTGEN - Runtime Overlay for Eigen + Copyright : CODE RECKONS + SPDX-License-Identifier: BSL-1.0 +*/ +//================================================================================================== +#pragma once + +#if defined(ROTGEN_USE_LIBASSERT) +#include +#define ROTGEN_ASSERT(COND, ...) DEBUG_ASSERT(COND, __VA_ARGS__) + +#else +#include + +#if !defined(NDEBUG) +#define ROTGEN_ASSERT(COND, MSG, ...) assert((COND) && (MSG)) +#else +#define ROTGEN_ASSERT(COND, ...) (void)(COND) +#endif +#endif diff --git a/include/rotgen/detail/helpers.hpp b/include/rotgen/detail/helpers.hpp index 17fd00c..5c839d9 100644 --- a/include/rotgen/detail/helpers.hpp +++ b/include/rotgen/detail/helpers.hpp @@ -31,7 +31,7 @@ namespace rotgen::detail using type = Wrapper; }; diff --git a/include/rotgen/detail/product.hpp b/include/rotgen/detail/product.hpp new file mode 100644 index 0000000..f5f1293 --- /dev/null +++ b/include/rotgen/detail/product.hpp @@ -0,0 +1,57 @@ +//================================================================================================== +/* + ROTGEN - Runtime Overlay for Eigen + Copyright : CODE RECKONS + SPDX-License-Identifier: BSL-1.0 +*/ +//================================================================================================== +#pragma once + +#include + +#include + +namespace rotgen +{ + // Emulate EIGEN 1x1 pseudo product type + template struct product + { + using rotgen_tag = void; + + using value_type = typename M::value_type; + using concrete_type = matrix; + + static constexpr auto storage_order = concrete_type::storage_order; + static constexpr int RowsAtCompileTime = 1; + static constexpr int ColsAtCompileTime = 1; + static constexpr int SizeAtCompileTime = 1; + static constexpr int MaxRowsAtCompileTime = 1; + static constexpr int MaxColsAtCompileTime = 1; + static constexpr bool IsVectorAtCompileTime = + concrete_type::IsVectorAtCompileTime; + + product(M const& m) : storage_(m) {} + + auto size() const { return storage_.size(); } + + auto rows() const { return storage_.rows(); } + + auto cols() const { return storage_.cols(); } + + auto operator()(int i) const { return storage_(i); } + + auto operator[](int i) const { return storage_(i); } + + auto operator()(int r, int c) const { return storage_(r, c); } + + auto sum() const { return storage_.sum(); } + + auto const& base() const { return storage_.base(); } + + operator value_type const() const { return storage_(0); } + + operator concrete_type() const { return storage_; } + + concrete_type storage_; + }; +} diff --git a/include/rotgen/functions/extract.hpp b/include/rotgen/functions/extract.hpp index 73e7e4f..e10c457 100644 --- a/include/rotgen/functions/extract.hpp +++ b/include/rotgen/functions/extract.hpp @@ -21,10 +21,12 @@ namespace rotgen [[maybe_unused]] Index ni, [[maybe_unused]] Index nj) { - assert(i0 >= 0 && "block extraction uses negative row index."); - assert(j0 >= 0 && "block extraction uses negative col index."); - assert(i0 + ni <= e.rows() && "block extraction rows is out of range."); - assert(j0 + nj <= e.cols() && "block extraction cols is out of range."); + ROTGEN_ASSERT(i0 >= 0, "block extraction uses negative row index."); + ROTGEN_ASSERT(j0 >= 0, "block extraction uses negative col index."); + ROTGEN_ASSERT(i0 + ni <= e.rows(), + "block extraction rows is out of range."); + ROTGEN_ASSERT(j0 + nj <= e.cols(), + "block extraction cols is out of range."); } } diff --git a/include/rotgen/functions/generators.hpp b/include/rotgen/functions/generators.hpp index b91e2d1..96c6f34 100644 --- a/include/rotgen/functions/generators.hpp +++ b/include/rotgen/functions/generators.hpp @@ -7,10 +7,24 @@ //================================================================================================== #pragma once +#include #include namespace rotgen { + //----------------------------------------------------------------------------------------------- + //----------------------------------------------------------------------------------------------- + template void initialize_with(T&& m, Args... v) + { + using type = typename std::remove_cvref_t::value_type; + using map_t = rotgen::map>; + + ROTGEN_ASSERT(sizeof...(v) == m.size(), + "Incorrect quantity of coefficients for initialization"); + type data[] = {static_cast(v)...}; + m = map_t(data, m.rows(), m.cols()); + } + //----------------------------------------------------------------------------------------------- // Generators //----------------------------------------------------------------------------------------------- diff --git a/include/rotgen/functions/operators.hpp b/include/rotgen/functions/operators.hpp index e56fb82..afd66bc 100644 --- a/include/rotgen/functions/operators.hpp +++ b/include/rotgen/functions/operators.hpp @@ -7,9 +7,10 @@ //================================================================================================== #pragma once +#include + #include -#include #include namespace rotgen @@ -80,30 +81,27 @@ namespace rotgen //------------------------------------------------------------------------------------------------ // Compounds operators across types - template - auto operator+=(A& a, B const& b) - requires(concepts::entity && concepts::entity) + template + decltype(auto) operator+=(A&& a, B const& b) + requires(!concepts::block) { - if constexpr (!use_expression_templates) - return generalize_t(a) += generalize_t(b); - else return base_of(a) += base_of(b); + generalize_t(ROTGEN_FWD(a)) += generalize_t(b); + return ROTGEN_FWD(a); } - template - auto operator-=(A& a, B const& b) - requires(concepts::entity && concepts::entity) + template + decltype(auto) operator-=(A&& a, B const& b) + requires(!concepts::block) { - if constexpr (!use_expression_templates) - return generalize_t(a) -= generalize_t(b); - else return base_of(a) -= base_of(b); + generalize_t(ROTGEN_FWD(a)) -= generalize_t(b); + return ROTGEN_FWD(a); } - template - auto operator*=(A& a, B const& b) - requires(concepts::entity && concepts::entity) + template + decltype(auto) operator*=(A&& a, B const& b) + requires(!concepts::block) { - if constexpr (!use_expression_templates) - return generalize_t(a) *= generalize_t(b); - else return base_of(a) *= base_of(b); + generalize_t(ROTGEN_FWD(a)) *= generalize_t(b); + return ROTGEN_FWD(a); } } diff --git a/src/block/generate.cpp b/src/block/generate.cpp index 3d3faf0..43b9bdc 100644 --- a/src/block/generate.cpp +++ b/src/block/generate.cpp @@ -10,30 +10,38 @@ #define STORAGE_ORDER Eigen::ColMajor #define CLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _col) +#define CLASSCONSTNAME ROTGEN_MATRIX_NAME(block_const_impl, SIZE, _col) #define TRANSCLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _row) #define SOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _col) #define TRANSNAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _row) #define MAPNAME ROTGEN_MATRIX_NAME(BASEMAP, SIZE, _col) +#define TRANSMAPNAME ROTGEN_MATRIX_NAME(BASEMAP, SIZE, _row) #include "model.cpp" #undef CLASSNAME +#undef CLASSCONSTNAME #undef TRANSNAME #undef TRANSCLASSNAME #undef SOURCENAME #undef MAPNAME +#undef TRANSMAPNAME #undef STORAGE_ORDER #define STORAGE_ORDER Eigen::RowMajor #define CLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _row) +#define CLASSCONSTNAME ROTGEN_MATRIX_NAME(block_const_impl, SIZE, _row) #define TRANSCLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _col) #define SOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _row) #define TRANSNAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _col) #define MAPNAME ROTGEN_MATRIX_NAME(BASEMAP, SIZE, _row) +#define TRANSMAPNAME ROTGEN_MATRIX_NAME(BASEMAP, SIZE, _col) #include "model.cpp" #undef CLASSNAME +#undef CLASSCONSTNAME #undef TRANSNAME #undef TRANSCLASSNAME #undef SOURCENAME #undef MAPNAME +#undef TRANSMAPNAME #undef STORAGE_ORDER #undef SIZE @@ -44,30 +52,38 @@ #define STORAGE_ORDER Eigen::ColMajor #define CLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _col) +#define CLASSCONSTNAME ROTGEN_MATRIX_NAME(block_const_impl, SIZE, _col) #define TRANSCLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _row) #define SOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _col) #define TRANSNAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _row) #define MAPNAME ROTGEN_MATRIX_NAME(BASEMAP, SIZE, _col) +#define TRANSMAPNAME ROTGEN_MATRIX_NAME(BASEMAP, SIZE, _row) #include "model.cpp" #undef CLASSNAME +#undef CLASSCONSTNAME #undef TRANSNAME #undef TRANSCLASSNAME #undef SOURCENAME #undef MAPNAME +#undef TRANSMAPNAME #undef STORAGE_ORDER #define STORAGE_ORDER Eigen::RowMajor #define CLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _row) +#define CLASSCONSTNAME ROTGEN_MATRIX_NAME(block_const_impl, SIZE, _row) #define TRANSCLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _col) #define SOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _row) #define TRANSNAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _col) #define MAPNAME ROTGEN_MATRIX_NAME(BASEMAP, SIZE, _row) +#define TRANSMAPNAME ROTGEN_MATRIX_NAME(BASEMAP, SIZE, _col) #include "model.cpp" #undef CLASSNAME +#undef CLASSCONSTNAME #undef TRANSNAME #undef TRANSCLASSNAME #undef SOURCENAME #undef MAPNAME +#undef TRANSMAPNAME #undef STORAGE_ORDER #undef SIZE diff --git a/src/block/model.cpp b/src/block/model.cpp index dc01f1d..4c8ac4e 100644 --- a/src/block/model.cpp +++ b/src/block/model.cpp @@ -33,6 +33,12 @@ CLASSNAME::CLASSNAME(MAPNAME CONST& r, Index i0, Index j0, Index ni, Index nj) { } +CLASSNAME::CLASSNAME( + TRANSMAPNAME CONST& r, Index i0, Index j0, Index ni, Index nj) + : storage_(std::make_unique(r, i0, j0, ni, nj, map_t{}, trans_t{})) +{ +} + // We're building a block from a block - So we have to dig around the internals CLASSNAME::CLASSNAME( TRANSCLASSNAME CONST& p, Index i0, Index j0, Index ni, Index nj) @@ -197,17 +203,17 @@ TYPE CLASSNAME::operator()(Index i, Index j) const #if !defined(USE_CONST) TYPE& CLASSNAME::operator()(Index index) { - TYPE* ptr = nullptr; - storage_->apply([&](auto& blk) { ptr = blk.data() + index; }); - return *ptr; + auto r = rows() == 1 ? 0 : index; + auto c = cols() == 1 ? 0 : index; + return (*this)(r, c); } #endif TYPE CLASSNAME::operator()(Index index) const { - TYPE ptr; - storage_->apply([&](auto const& blk) { ptr = *(blk.data() + index); }); - return ptr; + auto r = rows() == 1 ? 0 : index; + auto c = cols() == 1 ? 0 : index; + return (*this)(r, c); } // Raw pointer access @@ -448,6 +454,28 @@ CLASSNAME& CLASSNAME::operator+=(CLASSNAME const& rhs) return *this; } +CLASSNAME& CLASSNAME::operator+=(CLASSCONSTNAME const& rhs) +{ + std::visit( + [](auto& lhs_blk, auto const& rhs_blk) { lhs_blk.first += rhs_blk.first; }, + storage_->data, rhs.storage()->data); + return *this; +} + +CLASSNAME& CLASSNAME::operator+=(SOURCENAME const& rhs) +{ + std::visit([&](auto& lhs_blk) { lhs_blk.first += rhs.storage()->data; }, + storage_->data); + return *this; +} + +CLASSNAME& CLASSNAME::operator+=(TRANSNAME const& rhs) +{ + std::visit([&](auto& lhs_blk) { lhs_blk.first += rhs.storage()->data; }, + storage_->data); + return *this; +} + CLASSNAME& CLASSNAME::operator-=(CLASSNAME const& rhs) { std::visit( @@ -456,6 +484,28 @@ CLASSNAME& CLASSNAME::operator-=(CLASSNAME const& rhs) return *this; } +CLASSNAME& CLASSNAME::operator-=(CLASSCONSTNAME const& rhs) +{ + std::visit( + [](auto& lhs_blk, auto const& rhs_blk) { lhs_blk.first -= rhs_blk.first; }, + storage_->data, rhs.storage()->data); + return *this; +} + +CLASSNAME& CLASSNAME::operator-=(SOURCENAME const& rhs) +{ + std::visit([&](auto& lhs_blk) { lhs_blk.first -= rhs.storage()->data; }, + storage_->data); + return *this; +} + +CLASSNAME& CLASSNAME::operator-=(TRANSNAME const& rhs) +{ + std::visit([&](auto& lhs_blk) { lhs_blk.first -= rhs.storage()->data; }, + storage_->data); + return *this; +} + CLASSNAME& CLASSNAME::operator*=(CLASSNAME const& rhs) { std::visit( @@ -464,6 +514,28 @@ CLASSNAME& CLASSNAME::operator*=(CLASSNAME const& rhs) return *this; } +CLASSNAME& CLASSNAME::operator*=(CLASSCONSTNAME const& rhs) +{ + std::visit( + [](auto& lhs_blk, auto const& rhs_blk) { lhs_blk.first *= rhs_blk.first; }, + storage_->data, rhs.storage()->data); + return *this; +} + +CLASSNAME& CLASSNAME::operator*=(SOURCENAME const& rhs) +{ + std::visit([&](auto& lhs_blk) { lhs_blk.first *= rhs.storage()->data; }, + storage_->data); + return *this; +} + +CLASSNAME& CLASSNAME::operator*=(TRANSNAME const& rhs) +{ + std::visit([&](auto& lhs_blk) { lhs_blk.first *= rhs.storage()->data; }, + storage_->data); + return *this; +} + CLASSNAME& CLASSNAME::operator*=(TYPE s) { storage_->apply([&](auto& blk) { blk *= s; }); diff --git a/src/map/generate.cpp b/src/map/generate.cpp index c6ede29..6cdf555 100644 --- a/src/map/generate.cpp +++ b/src/map/generate.cpp @@ -11,29 +11,41 @@ #define CLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _col) #define TRANSCLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _row) +#define TRANSCLASSCONSTNAME ROTGEN_MATRIX_NAME(map_const_impl, SIZE, _row) +#define TRANSCLASSNONCONSTNAME ROTGEN_MATRIX_NAME(map_impl, SIZE, _row) #define TRANSSOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _row) #define CLASSCONSTNAME ROTGEN_MATRIX_NAME(map_const_impl, SIZE, _col) +#define CLASSNONCONSTNAME ROTGEN_MATRIX_NAME(map_impl, SIZE, _col) #define SOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _col) #include "model.cpp" #undef CLASSNAME #undef TRANSCLASSNAME +#undef TRANSCLASSCONSTNAME +#undef TRANSCLASSNONCONSTNAME #undef TRANSSOURCENAME #undef CLASSCONSTNAME +#undef CLASSNONCONSTNAME #undef SOURCENAME #undef STORAGE_ORDER #define STORAGE_ORDER Eigen::RowMajor #define CLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _row) #define TRANSCLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _col) +#define TRANSCLASSCONSTNAME ROTGEN_MATRIX_NAME(map_const_impl, SIZE, _col) +#define TRANSCLASSNONCONSTNAME ROTGEN_MATRIX_NAME(map_impl, SIZE, _col) #define TRANSSOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _col) #define CLASSCONSTNAME ROTGEN_MATRIX_NAME(map_const_impl, SIZE, _row) +#define CLASSNONCONSTNAME ROTGEN_MATRIX_NAME(map_impl, SIZE, _row) #define SOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _row) #include "model.cpp" #undef CLASSNAME #undef TRANSCLASSNAME +#undef TRANSCLASSCONSTNAME +#undef TRANSCLASSNONCONSTNAME #undef TRANSSOURCENAME #undef SOURCENAME #undef CLASSCONSTNAME +#undef CLASSNONCONSTNAME #undef STORAGE_ORDER #undef SIZE @@ -45,28 +57,40 @@ #define CLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _col) #define TRANSCLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _row) +#define TRANSCLASSCONSTNAME ROTGEN_MATRIX_NAME(map_const_impl, SIZE, _row) +#define TRANSCLASSNONCONSTNAME ROTGEN_MATRIX_NAME(map_impl, SIZE, _row) #define TRANSSOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _row) #define CLASSCONSTNAME ROTGEN_MATRIX_NAME(map_const_impl, SIZE, _col) +#define CLASSNONCONSTNAME ROTGEN_MATRIX_NAME(map_impl, SIZE, _col) #define SOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _col) #include "model.cpp" #undef CLASSNAME #undef TRANSCLASSNAME +#undef TRANSCLASSCONSTNAME +#undef TRANSCLASSNONCONSTNAME #undef TRANSSOURCENAME #undef CLASSCONSTNAME +#undef CLASSNONCONSTNAME #undef SOURCENAME #undef STORAGE_ORDER #define STORAGE_ORDER Eigen::RowMajor #define CLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _row) #define TRANSCLASSNAME ROTGEN_MATRIX_NAME(BASENAME, SIZE, _col) +#define TRANSCLASSCONSTNAME ROTGEN_MATRIX_NAME(map_const_impl, SIZE, _col) +#define TRANSCLASSNONCONSTNAME ROTGEN_MATRIX_NAME(map_impl, SIZE, _col) #define TRANSSOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _col) #define CLASSCONSTNAME ROTGEN_MATRIX_NAME(map_const_impl, SIZE, _row) +#define CLASSNONCONSTNAME ROTGEN_MATRIX_NAME(map_impl, SIZE, _row) #define SOURCENAME ROTGEN_MATRIX_NAME(matrix_impl, SIZE, _row) #include "model.cpp" #undef CLASSNAME #undef TRANSCLASSNAME +#undef TRANSCLASSCONSTNAME +#undef TRANSCLASSNONCONSTNAME #undef TRANSSOURCENAME #undef CLASSCONSTNAME +#undef CLASSNONCONSTNAME #undef SOURCENAME #undef STORAGE_ORDER diff --git a/src/map/model.cpp b/src/map/model.cpp index 1aefec6..cc31a69 100644 --- a/src/map/model.cpp +++ b/src/map/model.cpp @@ -239,12 +239,22 @@ void CLASSNAME::adjointInPlace() } #endif -TYPE CLASSNAME::dot(CLASSNAME const& rhs) const +TYPE CLASSNAME::dot(CLASSNONCONSTNAME const& rhs) const { return storage_->data.reshaped().dot(rhs.storage()->data.reshaped()); } -TYPE CLASSNAME::dot(TRANSCLASSNAME const& rhs) const +TYPE CLASSNAME::dot(CLASSCONSTNAME const& rhs) const +{ + return storage_->data.reshaped().dot(rhs.storage()->data.reshaped()); +} + +TYPE CLASSNAME::dot(TRANSCLASSNONCONSTNAME const& rhs) const +{ + return storage_->data.reshaped().dot(rhs.storage()->data.reshaped()); +} + +TYPE CLASSNAME::dot(TRANSCLASSCONSTNAME const& rhs) const { return storage_->data.reshaped().dot(rhs.storage()->data.reshaped()); } @@ -378,7 +388,7 @@ SOURCENAME CLASSNAME::operator-() const #if !defined(USE_CONST) CLASSNAME& CLASSNAME::operator+=(CLASSNAME const& rhs) { - storage_->data += rhs.storage_->data; + storage_->data += rhs.storage()->data; return *this; } @@ -388,9 +398,21 @@ CLASSNAME& CLASSNAME::operator+=(CLASSCONSTNAME const& rhs) return *this; } +CLASSNAME& CLASSNAME::operator+=(TRANSCLASSNONCONSTNAME const& rhs) +{ + storage_->data.reshaped() += rhs.storage()->data.reshaped(); + return *this; +} + +CLASSNAME& CLASSNAME::operator+=(TRANSCLASSCONSTNAME const& rhs) +{ + storage_->data.reshaped() += rhs.storage()->data.reshaped(); + return *this; +} + CLASSNAME& CLASSNAME::operator-=(CLASSNAME const& rhs) { - storage_->data -= rhs.storage_->data; + storage_->data -= rhs.storage()->data; return *this; } @@ -400,9 +422,21 @@ CLASSNAME& CLASSNAME::operator-=(CLASSCONSTNAME const& rhs) return *this; } +CLASSNAME& CLASSNAME::operator-=(TRANSCLASSNONCONSTNAME const& rhs) +{ + storage_->data.reshaped() -= rhs.storage()->data.reshaped(); + return *this; +} + +CLASSNAME& CLASSNAME::operator-=(TRANSCLASSCONSTNAME const& rhs) +{ + storage_->data.reshaped() -= rhs.storage()->data.reshaped(); + return *this; +} + CLASSNAME& CLASSNAME::operator*=(CLASSNAME const& rhs) { - storage_->data *= rhs.storage_->data; + storage_->data *= rhs.storage()->data; return *this; } @@ -412,6 +446,18 @@ CLASSNAME& CLASSNAME::operator*=(CLASSCONSTNAME const& rhs) return *this; } +CLASSNAME& CLASSNAME::operator*=(TRANSCLASSNONCONSTNAME const& rhs) +{ + storage_->data *= rhs.storage()->data; + return *this; +} + +CLASSNAME& CLASSNAME::operator*=(TRANSCLASSCONSTNAME const& rhs) +{ + storage_->data *= rhs.storage()->data; + return *this; +} + CLASSNAME& CLASSNAME::operator*=(TYPE s) { storage_->data *= s; diff --git a/test/integration/extract.cpp b/test/integration/extract.cpp index 3e0a33f..0884374 100644 --- a/test/integration/extract.cpp +++ b/test/integration/extract.cpp @@ -77,3 +77,39 @@ TTS_CASE("Extraction of ref/ref const") for (rotgen::Index r = 0; r < 4; r++) for (rotgen::Index c = 0; c < 3; c++) TTS_EQUAL(sliced(r, c), 5.f); }; + +void process_col(rotgen::matrix& m, + rotgen::matrix const& n, + int c) +{ + col(m, c) += n; +} + +TTS_CASE("Compound operators on extractions") +{ + rotgen::matrix m; + rotgen::matrix reference; + rotgen::matrix n; + setZero(m); + setConstant(n, 10); + setConstant(reference, 10); + + for (int i = 0; i < m.cols(); i++) process_col(m, n, i); + + TTS_EQUAL(m, reference); +}; + +TTS_CASE("Compatibility of 1D blocks") +{ + rotgen::matrix V; + setConstant(V, 99); + rotgen::matrix C{12, 34, 56}; + + extract<1, 2>(V, 0, 0) = segment<2>(C, 0); + extract<2, 1>(V, 1, 0) = segment<2>(C, 1); + + TTS_EQUAL(V(0, 0), C(0)); + TTS_EQUAL(V(0, 1), C(1)); + TTS_EQUAL(V(1, 0), C(1)); + TTS_EQUAL(V(2, 0), C(2)); +}; diff --git a/test/integration/initialize_with.cpp b/test/integration/initialize_with.cpp new file mode 100644 index 0000000..4e918f0 --- /dev/null +++ b/test/integration/initialize_with.cpp @@ -0,0 +1,123 @@ +//================================================================================================== +/* + ROTGEN - Runtime Overlay for Eigen + Copyright : CODE RECKONS + SPDX-License-Identifier: BSL-1.0 +*/ +//================================================================================================== +#include + +#include "unit/tests.hpp" +#include + +TTS_CASE_TPL("Initialize a matrix with a list of scalars", rotgen::tests::types) + +(tts::type>) +{ + using eigen_t = Eigen::Matrix; + using rotgen_t = rotgen::matrix; + + eigen_t reference(3, 3); + reference << 1, 2, 3, 4, 5, 6, 7, 8, 9; + + rotgen_t values(3, 3); + initialize_with(values, 1, 2, 3, 4, 5, 6, 7, 8, 9); + + for (int r = 0; r < 3; r++) + for (int c = 0; c < 3; c++) TTS_EQUAL(values(r, c), reference(r, c)); +}; + +TTS_CASE_TPL("Initialize a sub-matrix with a list of scalars", + rotgen::tests::types) + +(tts::type>) +{ + using eigen_t = Eigen::Matrix; + using rotgen_t = rotgen::matrix; + + eigen_t reference(6, 6); + reference.block(1, 1, 3, 3) << 1, 2, 3, 4, 5, 6, 7, 8, 9; + + rotgen_t values(6, 6); + initialize_with(extract(values, 1, 1, 3, 3), 1, 2, 3, 4, 5, 6, 7, 8, 9); + + for (int r = 0; r < 3; r++) + for (int c = 0; c < 3; c++) + TTS_EQUAL(values(r + 1, c + 1), reference(r + 1, c + 1)); +}; + +TTS_CASE_TPL("Initialize a map with a list of scalars", rotgen::tests::types) + +(tts::type>) +{ + using eigen_t = Eigen::Matrix; + using rotgen_t = rotgen::matrix; + + T eigen_data[9] = {}; + T rotgen_data[9] = {}; + + Eigen::Map reference(eigen_data, 3, 3); + reference << 1, 2, 3, 4, 5, 6, 7, 8, 9; + + rotgen::map values(rotgen_data, 3, 3); + initialize_with(values, 1, 2, 3, 4, 5, 6, 7, 8, 9); + + for (int i = 0; i < 9; i++) TTS_EQUAL(eigen_data[i], rotgen_data[i]); +}; + +TTS_CASE_TPL("Initialize a sub-map with a list of scalars", + rotgen::tests::types) + +(tts::type>) +{ + using eigen_t = Eigen::Matrix; + using rotgen_t = rotgen::matrix; + + T eigen_data[36] = {}; + T rotgen_data[36] = {}; + + Eigen::Map reference(eigen_data, 6, 6); + reference.block(1, 1, 3, 3) << 1, 2, 3, 4, 5, 6, 7, 8, 9; + + rotgen::map values(rotgen_data, 6, 6); + initialize_with(extract(values, 1, 1, 3, 3), 1, 2, 3, 4, 5, 6, 7, 8, 9); + + for (int i = 0; i < 9; i++) TTS_EQUAL(eigen_data[i], rotgen_data[i]); +}; + +void process(rotgen::ref r) +{ + rotgen::initialize_with(r, 1, 2, 3, 4, 5, 6, 7, 8, 9); +} + +void process(rotgen::ref r) +{ + rotgen::initialize_with(r, 1, 2, 3, 4, 5, 6, 7, 8, 9); +} + +void process(rotgen::ref> r) +{ + rotgen::initialize_with(r, 1, 2, 3, 4, 5, 6, 7, 8, 9); +} + +void process(rotgen::ref> r) +{ + rotgen::initialize_with(r, 1, 2, 3, 4, 5, 6, 7, 8, 9); +} + +TTS_CASE_TPL("Initialize a ref with a list of scalars", rotgen::tests::types) + +(tts::type>) +{ + using eigen_t = Eigen::Matrix; + using rotgen_t = rotgen::matrix; + + eigen_t reference(3, 3); + reference << 1, 2, 3, 4, 5, 6, 7, 8, 9; + + rotgen_t values(3, 3); + process(values); + + for (int r = 0; r < 3; r++) + for (int c = 0; c < 3; c++) TTS_EQUAL(values(r, c), reference(r, c)); +}; diff --git a/test/integration/specifics.cpp b/test/integration/specifics.cpp new file mode 100644 index 0000000..0bdea4d --- /dev/null +++ b/test/integration/specifics.cpp @@ -0,0 +1,50 @@ +//================================================================================================== +/* + ROTGEN - Runtime Overlay for Eigen + Copyright : CODE RECKONS + SPDX-License-Identifier: BSL-1.0 +*/ +//================================================================================================== +#include + +#include "unit/tests.hpp" + +bool categorize_as_scalar(double) +{ + return true; +} + +bool categorize_as_scalar(rotgen::ref>) +{ + return false; +} + +TTS_CASE("Matrix product of 1xN by Nx1 yields a scalar-convertible object") +{ + rotgen::matrix a = {1, 2}; + rotgen::matrix b = {10, 20}; + + double n = a * b; + + TTS_EQUAL(n, 50); + TTS_EXPECT(categorize_as_scalar(a * b)); +}; + +TTS_CASE_TPL("Static 1x1 matrix-like objects can be assigned as-if", + rotgen::tests::types) + +(tts::type>) +{ + rotgen::matrix a = {1, 2}; + rotgen::matrix b = {10, 20}; + + rotgen::matrix big(10, 10); + auto bk = extract<1, 1>(big, 2, 2); + bk = a * b; + TTS_EQUAL(big(2, 2), 50); + + rotgen::map> big_map(big.data(), 1, 1); + + big_map = a * b; + TTS_EQUAL(big(0, 0), 50); +}; diff --git a/test/unit/functions/svd.cpp b/test/unit/functions/svd.cpp index 23671a1..55a643c 100644 --- a/test/unit/functions/svd.cpp +++ b/test/unit/functions/svd.cpp @@ -9,9 +9,9 @@ #include "unit/tests.hpp" -TTS_CASE_TPL("SVD decomposition - Dynamic case", - rotgen::tests::types)( - tts::type>) +TTS_CASE_TPL("SVD decomposition - Dynamic case", rotgen::tests::types) + +(tts::type>) { int rank, i = 5; auto eps = std::numeric_limits::epsilon(); @@ -24,10 +24,10 @@ TTS_CASE_TPL("SVD decomposition - Dynamic case", { rank = decomp.rank(); - auto u = decomp.U(rank); - auto d = decomp.singular_values(rank); - auto dd = decomp.D(rank); - auto v = decomp.V(rank); + auto u = decomp.matrixU(rank); + auto d = decomp.singularValues(rank); + auto dd = decomp.matrixD(rank); + auto v = decomp.matrixV(rank); TTS_EQUAL(rank, i); @@ -48,9 +48,9 @@ TTS_CASE_TPL("SVD decomposition - Dynamic case", } while (rank != 1); }; -TTS_CASE_TPL("SVD decomposition - Static case", - rotgen::tests::types)( - tts::type>) +TTS_CASE_TPL("SVD decomposition - Static case", rotgen::tests::types) + +(tts::type>) { int rank, i = 5; auto eps = std::numeric_limits::epsilon(); @@ -62,10 +62,10 @@ TTS_CASE_TPL("SVD decomposition - Static case", { rank = decomp.rank(); - auto u = decomp.U(rank); - auto d = decomp.singular_values(rank); - auto dd = decomp.D(rank); - auto v = decomp.V(rank); + auto u = decomp.matrixU(rank); + auto d = decomp.singularValues(rank); + auto dd = decomp.matrixD(rank); + auto v = decomp.matrixV(rank); TTS_EQUAL(rank, i);