From 3e2e6f253ca4e69c7a13bbd6e9e7c722ddc148a6 Mon Sep 17 00:00:00 2001 From: Joel Falcou Date: Thu, 18 Sep 2025 16:25:51 +0200 Subject: [PATCH] Implement normalize and normalized See merge request oss/rotgen!24 --- include/rotgen/dynamic/block.hpp | 25 ++++++++++----------- include/rotgen/dynamic/map.hpp | 20 ++++++++--------- include/rotgen/dynamic/matrix.hpp | 10 +++++++++ include/rotgen/fixed/block.hpp | 31 +++++++++++++++++--------- include/rotgen/fixed/map.hpp | 11 ++++++++++ include/rotgen/fixed/matrix.hpp | 33 ++++++++++++++++++++-------- include/rotgen/functions.hpp | 10 +++++---- include/rotgen/impl/block_model.hpp | 8 ++++--- include/rotgen/impl/map_model.hpp | 8 ++++--- include/rotgen/impl/matrix_model.hpp | 8 ++++--- src/block_model.cpp | 22 ++++++++++++++----- src/map_model.cpp | 12 ++++++++++ src/matrix_model.cpp | 13 ++++++++++- test/unit/common/norms.hpp | 23 +++++++++++++++++-- 14 files changed, 169 insertions(+), 65 deletions(-) diff --git a/include/rotgen/dynamic/block.hpp b/include/rotgen/dynamic/block.hpp index 89145e8..24d7146 100644 --- a/include/rotgen/dynamic/block.hpp +++ b/include/rotgen/dynamic/block.hpp @@ -132,28 +132,25 @@ namespace rotgen decltype(auto) noalias() const { return *this; } decltype(auto) noalias() { return *this; } - concrete_type transpose() const + concrete_type normalized() const requires(IsVectorAtCompileTime) { - return concrete_type(static_cast(*this).transpose()); - } - - concrete_type conjugate() const - { - return concrete_type(static_cast(*this).conjugate()); - } - - concrete_type adjoint() const - { - return concrete_type(static_cast(*this).adjoint()); + return concrete_type(base().normalized()); } + concrete_type transpose() const { return concrete_type(base().transpose());} + concrete_type conjugate() const { return concrete_type(base().conjugate());} + concrete_type adjoint() const { return concrete_type(base().adjoint());} concrete_type cwiseAbs() const { return concrete_type(base().cwiseAbs()); } concrete_type cwiseAbs2() const { return concrete_type(base().cwiseAbs2()); } concrete_type cwiseInverse() const { return concrete_type(base().cwiseInverse()); } concrete_type cwiseSqrt() const { return concrete_type(base().cwiseSqrt()); } - void transposeInPlace() requires(!is_immutable) { parent::transposeInPlace(); } - void adjointInPlace() requires(!is_immutable) { parent::adjointInPlace(); } + void normalize() requires(!is_immutable && IsVectorAtCompileTime) + { + parent::normalize(); + } + void transposeInPlace() requires(!is_immutable) { parent::transposeInPlace(); } + void adjointInPlace() requires(!is_immutable) { parent::adjointInPlace(); } friend bool operator==(block const& lhs, block const& rhs) { diff --git a/include/rotgen/dynamic/map.hpp b/include/rotgen/dynamic/map.hpp index d176476..b7df712 100644 --- a/include/rotgen/dynamic/map.hpp +++ b/include/rotgen/dynamic/map.hpp @@ -123,26 +123,24 @@ namespace rotgen decltype(auto) noalias() const { return *this; } decltype(auto) noalias() { return *this; } - concrete_type transpose() const + concrete_type normalized() const requires(IsVectorAtCompileTime) { - return concrete_type(static_cast(*this).transpose()); + return concrete_type(base().normalized()); } - concrete_type conjugate() const - { - return concrete_type(static_cast(*this).conjugate()); - } - - concrete_type adjoint() const - { - return concrete_type(static_cast(*this).adjoint()); - } + concrete_type transpose() const { return concrete_type(base().transpose()); } + concrete_type conjugate() const { return concrete_type(base().conjugate()); } + concrete_type adjoint() const { return concrete_type(base().adjoint()); } concrete_type cwiseAbs() const { return concrete_type(base().cwiseAbs()); } concrete_type cwiseAbs2() const { return concrete_type(base().cwiseAbs2()); } concrete_type cwiseInverse() const { return concrete_type(base().cwiseInverse()); } concrete_type cwiseSqrt() const { return concrete_type(base().cwiseSqrt()); } + void normalize() requires(!is_immutable && IsVectorAtCompileTime) + { + parent::normalize(); + } void transposeInPlace() requires(!is_immutable) { parent::transposeInPlace(); } void adjointInPlace() requires(!is_immutable) { parent::adjointInPlace(); } diff --git a/include/rotgen/dynamic/matrix.hpp b/include/rotgen/dynamic/matrix.hpp index 6827284..dd31adf 100644 --- a/include/rotgen/dynamic/matrix.hpp +++ b/include/rotgen/dynamic/matrix.hpp @@ -112,6 +112,11 @@ namespace rotgen parent::conservativeResize(new_rows, new_cols); } + matrix normalized() const requires(IsVectorAtCompileTime) + { + return matrix(base().normalized()); + } + matrix transpose() const { return matrix(base().transpose()); @@ -127,6 +132,11 @@ namespace rotgen return matrix(base().adjoint()); } + void normalize() requires(IsVectorAtCompileTime) + { + parent::normalize(); + } + void transposeInPlace() { parent::transposeInPlace(); } void adjointInPlace() { parent::adjointInPlace(); } diff --git a/include/rotgen/fixed/block.hpp b/include/rotgen/fixed/block.hpp index 3d86509..8293af0 100644 --- a/include/rotgen/fixed/block.hpp +++ b/include/rotgen/fixed/block.hpp @@ -126,11 +126,11 @@ namespace rotgen } parent& base() { return static_cast(*this); } - parent const& base() const { return static_cast(*this); } + parent const& base() const { return static_cast(*this); } auto evaluate() const { - auto res = static_cast(*this).eval(); + auto res = base().eval(); return as_concrete_type(res); } @@ -146,22 +146,33 @@ namespace rotgen else return *this; } - auto transpose() const + auto normalized() const requires(IsVectorAtCompileTime) { - auto res = static_cast(*this).transpose(); - return as_concrete_type(res); + if constexpr(use_expression_templates) return base().normalized(); + else return as_concrete_type(base().normalized()); } - auto conjugate() const + auto transpose() const { - auto res = static_cast(*this).conjugate(); - return as_concrete_type(res); + if constexpr(use_expression_templates) return base().transpose(); + else return as_concrete_type(base().transpose()); } auto adjoint() const { - auto res = static_cast(*this).adjoint(); - return as_concrete_type(res); + if constexpr(use_expression_templates) return base().adjoint(); + else return as_concrete_type(base().adjoint()); + } + + auto conjugate() const + { + if constexpr(use_expression_templates) return base().conjugate(); + else return as_concrete_type(base().conjugate()); + } + + void normalize() requires(!is_immutable && IsVectorAtCompileTime) + { + parent::normalize(); } void transposeInPlace() requires(!is_immutable) { parent::transposeInPlace(); } diff --git a/include/rotgen/fixed/map.hpp b/include/rotgen/fixed/map.hpp index 68443cf..071ce07 100644 --- a/include/rotgen/fixed/map.hpp +++ b/include/rotgen/fixed/map.hpp @@ -185,6 +185,12 @@ namespace rotgen else return base().cwiseSqrt(); } + auto normalized() const requires(IsVectorAtCompileTime) + { + if constexpr(use_expression_templates) return base().normalized(); + else return as_concrete_type(base().normalized()); + } + auto transpose() const { if constexpr(use_expression_templates) return base().transpose(); @@ -203,6 +209,11 @@ namespace rotgen else return as_concrete_type(base().conjugate()); } + void normalize() requires(IsVectorAtCompileTime) + { + base().normalize(); + } + void transposeInPlace() { base().transposeInPlace(); } void adjointInPlace() { base().adjointInPlace(); } diff --git a/include/rotgen/fixed/matrix.hpp b/include/rotgen/fixed/matrix.hpp index 3ed4997..5b75e7f 100644 --- a/include/rotgen/fixed/matrix.hpp +++ b/include/rotgen/fixed/matrix.hpp @@ -126,11 +126,11 @@ namespace rotgen } parent& base() { return static_cast(*this); } - parent const& base() const { return static_cast(*this); } + parent const& base() const { return static_cast(*this); } auto evaluate() const { - auto res = static_cast(*this).eval(); + auto res = base().eval(); return as_concrete_type(res); } @@ -146,28 +146,43 @@ namespace rotgen else return *this; } + auto normalized() const requires(IsVectorAtCompileTime) + { + if constexpr(use_expression_templates) return base().normalized(); + else + { + auto res = base().normalized(); + return as_concrete_type(res); + } + } + auto transpose() const { if constexpr(use_expression_templates) return base().transpose(); else { - auto res = static_cast(*this).transpose(); + auto res = base().transpose(); return as_concrete_type(res); } } auto conjugate() const { - auto res = static_cast(*this).conjugate(); + auto res = base().conjugate(); return as_concrete_type(res); } auto adjoint() const { - auto res = static_cast(*this).adjoint(); + auto res = base().adjoint(); return as_concrete_type(res); } + void normalize() requires(IsVectorAtCompileTime) + { + parent::normalize(); + } + void transposeInPlace() { parent::transposeInPlace(); } void adjointInPlace() { parent::adjointInPlace(); } @@ -380,24 +395,24 @@ namespace rotgen matrix& operator+=(matrix const& rhs) { - static_cast(*this) += static_cast(rhs); + static_cast(*this) += rhs.base(); return *this; } matrix& operator-=(matrix const& rhs) { - static_cast(*this) -= static_cast(rhs); + static_cast(*this) -= rhs.base(); return *this; } matrix operator-() const { - return matrix(static_cast(*this).operator-()); + return matrix(base()(*this).operator-()); } matrix& operator*=(matrix const& rhs) { - static_cast(*this) *= static_cast(rhs); + static_cast(*this) *= rhs.base(); return *this; } diff --git a/include/rotgen/functions.hpp b/include/rotgen/functions.hpp index 9854f3f..b55ab21 100644 --- a/include/rotgen/functions.hpp +++ b/include/rotgen/functions.hpp @@ -27,12 +27,14 @@ namespace rotgen arg.conservativeResize(new_rows, new_cols); } - decltype(auto) transpose(concepts::entity auto const& arg) { return arg.transpose(); } - decltype(auto) conjugate(concepts::entity auto const& arg) { return arg.conjugate(); } - decltype(auto) adjoint (concepts::entity auto const& arg) { return arg.adjoint(); } + decltype(auto) normalized(concepts::entity auto const& arg) { return arg.normalized(); } + decltype(auto) transpose (concepts::entity auto const& arg) { return arg.transpose(); } + decltype(auto) conjugate (concepts::entity auto const& arg) { return arg.conjugate(); } + decltype(auto) adjoint (concepts::entity auto const& arg) { return arg.adjoint(); } + void normalize(concepts::entity auto& arg) { arg.normalize(); } void transposeInPlace(concepts::entity auto& arg) { arg.transposeInPlace(); } - void adjointInPlace(concepts::entity auto& arg) { arg.adjointInPlace(); } + void adjointInPlace(concepts::entity auto& arg) { arg.adjointInPlace(); } auto abs(concepts::entity auto const& arg) { return arg.cwiseAbs(); } auto abs2(concepts::entity auto const& arg) { return arg.cwiseAbs2(); } diff --git a/include/rotgen/impl/block_model.hpp b/include/rotgen/impl/block_model.hpp index e07222c..81a8930 100644 --- a/include/rotgen/impl/block_model.hpp +++ b/include/rotgen/impl/block_model.hpp @@ -46,9 +46,10 @@ class ROTGEN_EXPORT CLASSNAME Index startRow() const; Index startCol() const; - SOURCENAME transpose() const; - SOURCENAME conjugate() const; - SOURCENAME adjoint() const; + SOURCENAME normalized() const; + SOURCENAME transpose() const; + SOURCENAME conjugate() const; + SOURCENAME adjoint() const; SOURCENAME cwiseAbs() const; SOURCENAME cwiseAbs2() const; @@ -56,6 +57,7 @@ class ROTGEN_EXPORT CLASSNAME SOURCENAME cwiseSqrt() const; #if !defined(USE_CONST) + void normalize(); void transposeInPlace(); void adjointInPlace(); #endif diff --git a/include/rotgen/impl/map_model.hpp b/include/rotgen/impl/map_model.hpp index 11341f3..0819776 100644 --- a/include/rotgen/impl/map_model.hpp +++ b/include/rotgen/impl/map_model.hpp @@ -36,9 +36,10 @@ class ROTGEN_EXPORT CLASSNAME Index innerStride() const; Index outerStride() const; - SOURCENAME transpose() const; - SOURCENAME conjugate() const; - SOURCENAME adjoint() const; + SOURCENAME normalized() const; + SOURCENAME transpose() const; + SOURCENAME conjugate() const; + SOURCENAME adjoint() const; SOURCENAME cwiseAbs() const; SOURCENAME cwiseAbs2() const; @@ -46,6 +47,7 @@ class ROTGEN_EXPORT CLASSNAME SOURCENAME cwiseSqrt() const; #if !defined(USE_CONST) + void normalize(); void transposeInPlace(); void adjointInPlace(); #endif diff --git a/include/rotgen/impl/matrix_model.hpp b/include/rotgen/impl/matrix_model.hpp index d841e7c..102a3d8 100644 --- a/include/rotgen/impl/matrix_model.hpp +++ b/include/rotgen/impl/matrix_model.hpp @@ -35,15 +35,17 @@ class ROTGEN_EXPORT CLASSNAME void resize(std::size_t new_rows, std::size_t new_cols); void conservativeResize(std::size_t new_rows, std::size_t new_cols); - CLASSNAME transpose() const; - CLASSNAME conjugate() const; - CLASSNAME adjoint() const; + CLASSNAME normalized() const; + CLASSNAME transpose() const; + CLASSNAME conjugate() const; + CLASSNAME adjoint() const; CLASSNAME cwiseAbs() const; CLASSNAME cwiseAbs2() const; CLASSNAME cwiseInverse() const; CLASSNAME cwiseSqrt() const; + void normalize(); void transposeInPlace(); void adjointInPlace(); diff --git a/src/block_model.cpp b/src/block_model.cpp index 5def410..7bbf3a2 100644 --- a/src/block_model.cpp +++ b/src/block_model.cpp @@ -225,6 +225,13 @@ struct CLASSNAME::payload //================================================================================================== // Matrix operations //================================================================================================== + SOURCENAME CLASSNAME::normalized() const + { + SOURCENAME result; + storage_->apply([&](const auto& blk) { result.storage()->assign(blk.normalized().eval()); }); + return result; + } + SOURCENAME CLASSNAME::transpose() const { SOURCENAME result; @@ -239,6 +246,13 @@ struct CLASSNAME::payload return result; } + SOURCENAME CLASSNAME::adjoint() const + { + SOURCENAME result; + storage_->apply([&](const auto& blk) { result.storage()->assign(blk.adjoint().eval()); }); + return result; + } + SOURCENAME CLASSNAME::cwiseAbs() const { SOURCENAME result; @@ -267,14 +281,12 @@ struct CLASSNAME::payload return result; } - SOURCENAME CLASSNAME::adjoint() const + #if !defined(USE_CONST) + void CLASSNAME::normalize() { - SOURCENAME result; - storage_->apply([&](const auto& blk) { result.storage()->assign(blk.adjoint().eval()); }); - return result; + storage_->apply([](auto& blk) { blk.normalize(); }); } - #if !defined(USE_CONST) void CLASSNAME::transposeInPlace() { storage_->apply([](auto& blk) { blk.transposeInPlace(); }); diff --git a/src/map_model.cpp b/src/map_model.cpp index 504df59..8e78245 100644 --- a/src/map_model.cpp +++ b/src/map_model.cpp @@ -62,6 +62,13 @@ TYPE CLASSNAME::operator()(Index i, Index j) const { return storage_->data(i,j); } TYPE CLASSNAME::operator()(Index i) const { return storage_->data.data()[i]; } + SOURCENAME CLASSNAME::normalized() const + { + SOURCENAME result; + result.storage()->assign(storage_->data.normalized().eval()); + return result; + } + SOURCENAME CLASSNAME::transpose() const { SOURCENAME result; @@ -112,6 +119,11 @@ } #if !defined(USE_CONST) + void CLASSNAME::normalize() + { + storage_->data.normalize(); + } + void CLASSNAME::transposeInPlace() { storage_->data.transposeInPlace(); diff --git a/src/matrix_model.cpp b/src/matrix_model.cpp index c3fd6c7..d172bcf 100644 --- a/src/matrix_model.cpp +++ b/src/matrix_model.cpp @@ -75,12 +75,18 @@ TYPE const& CLASSNAME::operator()(std::size_t index) const { return storage_->da const TYPE* CLASSNAME::data() const { return storage_->data.data(); } TYPE* CLASSNAME::data() { return storage_->data.data(); } +CLASSNAME CLASSNAME::normalized() const +{ + CLASSNAME result(*this); + result.storage_->data.normalize(); + return result; +} + CLASSNAME CLASSNAME::transpose() const { CLASSNAME result(*this); result.storage_->data.transposeInPlace(); return result; - } CLASSNAME CLASSNAME::conjugate() const @@ -97,6 +103,11 @@ CLASSNAME CLASSNAME::adjoint() const return result; } +void CLASSNAME::normalize() +{ + storage_->data.normalize(); +} + void CLASSNAME::transposeInPlace() { storage_->data.transposeInPlace(); diff --git a/test/unit/common/norms.hpp b/test/unit/common/norms.hpp index c476ada..46cb158 100644 --- a/test/unit/common/norms.hpp +++ b/test/unit/common/norms.hpp @@ -18,6 +18,7 @@ namespace rotgen::tests using EigenMatrix = Eigen::Matrix; EigenMatrix ref(input.rows(), input.cols()); + prepare([&](auto r, auto c) { return input(r,c); }, ref); TTS_EQUAL(input.norm() , ref.norm()); @@ -26,10 +27,28 @@ namespace rotgen::tests TTS_EQUAL(input.template lpNorm<2>() , ref.template lpNorm<2>()); TTS_EQUAL(input.template lpNorm() , ref.template lpNorm()); - TTS_EQUAL(norm(input) , ref.norm()); - TTS_EQUAL(squaredNorm(input) , ref.squaredNorm()); + TTS_EQUAL(norm(input) , ref.norm()); + TTS_EQUAL(squaredNorm(input) , ref.squaredNorm()); TTS_EQUAL(lpNorm<1>(input) , ref.template lpNorm<1>()); TTS_EQUAL(lpNorm<2>(input) , ref.template lpNorm<2>()); TTS_EQUAL(lpNorm(input) , ref.template lpNorm()); + + if constexpr(T::IsVectorAtCompileTime) + { + EigenMatrix e_norm = ref.normalized(); + using mat_t = rotgen::matrix; + mat_t norm_ref(input.rows(), input.cols()); + prepare([&](auto r, auto c) { return e_norm(r,c); }, norm_ref); + + TTS_EQUAL(input.normalized(), norm_ref); + TTS_EQUAL(normalized(input), norm_ref); + auto m_norm = input; + m_norm.normalize(); + TTS_EQUAL(m_norm, norm_ref); + + auto f_norm = input; + normalize(f_norm); + TTS_EQUAL(f_norm, norm_ref); + } } }