From 62802273e4f87fb18fd50ffe7862f9f9b3baa919 Mon Sep 17 00:00:00 2001 From: Joel Falcou Date: Sat, 20 Sep 2025 16:44:45 +0200 Subject: [PATCH] Adapt generator static functions as free functions. --- include/rotgen/dynamic/matrix.hpp | 10 +- include/rotgen/functions.hpp | 155 +++++++++++++++++++++++++++--- test/unit/matrix/generators.cpp | 145 ++++++++++++++-------------- 3 files changed, 223 insertions(+), 87 deletions(-) diff --git a/include/rotgen/dynamic/matrix.hpp b/include/rotgen/dynamic/matrix.hpp index 5ae98d0..3274632 100644 --- a/include/rotgen/dynamic/matrix.hpp +++ b/include/rotgen/dynamic/matrix.hpp @@ -282,7 +282,7 @@ namespace rotgen matrix& setOnes() { - *this = parent::Ones(Rows, Cols); + *this = parent::Ones(parent::rows(), parent::cols()); return *this; } @@ -294,7 +294,7 @@ namespace rotgen matrix& setZero() { - *this = parent::Zero(Rows, Cols); + *this = parent::Zero(parent::rows(), parent::cols()); return *this; } @@ -306,7 +306,7 @@ namespace rotgen matrix& setConstant(Scalar value) { - *this = parent::Constant(Rows, Cols, static_cast(value)); + *this = parent::Constant(parent::rows(), parent::cols(), static_cast(value)); return *this; } @@ -318,7 +318,7 @@ namespace rotgen matrix& setRandom() { - *this = parent::Random(Rows, Cols); + *this = parent::Random(parent::rows(), parent::cols()); return *this; } @@ -330,7 +330,7 @@ namespace rotgen matrix& setIdentity() { - *this = parent::Identity(Rows, Cols); + *this = parent::Identity(parent::rows(), parent::cols()); return *this; } diff --git a/include/rotgen/functions.hpp b/include/rotgen/functions.hpp index 0e12ec4..0de28b2 100644 --- a/include/rotgen/functions.hpp +++ b/include/rotgen/functions.hpp @@ -9,6 +9,9 @@ namespace rotgen { + //----------------------------------------------------------------------------------------------- + // Infos & Shape + //----------------------------------------------------------------------------------------------- std::size_t rows(concepts::entity auto const& arg) { return arg.rows(); } std::size_t cols(concepts::entity auto const& arg) { return arg.cols(); } std::size_t size(concepts::entity auto const& arg) { return arg.size(); } @@ -39,6 +42,9 @@ namespace rotgen arg.conservativeResize(new_rows, new_cols); } + //----------------------------------------------------------------------------------------------- + // Global operations + //----------------------------------------------------------------------------------------------- decltype(auto) normalized(concepts::entity auto const& arg) { return arg.normalized(); } decltype(auto) transpose (concepts::entity auto const& arg) { return arg.transpose(); } decltype(auto) conjugate (concepts::entity auto const& arg) { return arg.conjugate(); } @@ -48,18 +54,17 @@ namespace rotgen void transposeInPlace(concepts::entity auto& arg) { arg.transposeInPlace(); } void adjointInPlace(concepts::entity auto& arg) { arg.adjointInPlace(); } - auto abs(concepts::entity auto const& arg) { return arg.cwiseAbs(); } - auto abs2(concepts::entity auto const& arg) { return arg.cwiseAbs2(); } - auto rec(concepts::entity auto const& arg) { return arg.cwiseInverse(); } - auto sqrt(concepts::entity auto const& arg) { return arg.cwiseSqrt(); } - - #if defined(ROTGEN_ENABLE_EXPRESSION_TEMPLATES) - decltype(auto) abs(concepts::eigen_compatible auto const& arg) { return arg.cwiseAbs(); } - decltype(auto) abs2(concepts::eigen_compatible auto const& arg) { return arg.cwiseAbs2(); } - decltype(auto) rec(concepts::eigen_compatible auto const& arg) { return arg.cwiseInverse(); } - decltype(auto) sqrt(concepts::eigen_compatible auto const& arg) { return arg.cwiseSqrt(); } - #endif + //----------------------------------------------------------------------------------------------- + // Component-wise functions + //----------------------------------------------------------------------------------------------- + auto abs (auto const& arg) requires(requires{arg.cwiseAbs();} ) { return arg.cwiseAbs(); } + auto abs2(auto const& arg) requires(requires{arg.cwiseAbs2();} ) { return arg.cwiseAbs2(); } + auto rec (auto const& arg) requires(requires{arg.cwiseInverse();}) { return arg.cwiseInverse(); } + auto sqrt(auto const& arg) requires(requires{arg.cwiseSqrt();} ) { return arg.cwiseSqrt(); } + //----------------------------------------------------------------------------------------------- + // Reductions + //----------------------------------------------------------------------------------------------- auto trace(concepts::entity auto const& arg) { return arg.trace(); } auto squaredNorm(concepts::entity auto const& arg) { return arg.squaredNorm(); } auto norm(concepts::entity auto const& arg) { return arg.norm(); } @@ -94,6 +99,9 @@ namespace rotgen return arg.template lpNorm

(); } + //----------------------------------------------------------------------------------------------- + // Expression handling + //----------------------------------------------------------------------------------------------- template decltype(auto) noalias(T&& t) requires( requires{std::forward(t).noalias();} ) { @@ -111,4 +119,127 @@ namespace rotgen { return std::forward(t).eval(); } -} + + //----------------------------------------------------------------------------------------------- + // Generators + //----------------------------------------------------------------------------------------------- + template + auto setZero(T&& t) requires( requires{std::forward(t).setZero();} ) + { + return std::forward(t).setZero(); + } + + template + auto setZero() requires( requires{T::Zero();} ) + { + return T::Zero(); + } + + template + auto setZero(std::integral auto n) requires( requires{T::Zero(n);} ) + { + return T::Zero(n); + } + + template + auto setZero(std::integral auto r,std::integral auto c) requires( requires{T::Zero(r,c);} ) + { + return T::Zero(r,c); + } + + template + auto setOnes(T&& t) requires( requires{std::forward(t).setOnes();} ) + { + return std::forward(t).setOnes(); + } + + template + auto setOnes() requires( requires{T::Ones();} ) + { + return T::Ones(); + } + + template + auto setOnes(std::integral auto n) requires( requires{T::Ones(n);} ) + { + return T::Ones(n); + } + + template + auto setOnes(std::integral auto r,std::integral auto c) requires( requires{T::Ones(r,c);} ) + { + return T::Ones(r,c); + } + + template + auto setIdentity(T&& t) requires( requires{std::forward(t).setIdentity();} ) + { + return std::forward(t).setIdentity(); + } + + template + auto setIdentity() requires( requires{T::Identity();} ) + { + return T::Identity(); + } + + template + auto setIdentity(std::integral auto n) requires( requires{T::Identity(n);} ) + { + return T::Identity(n); + } + + template + auto setIdentity(std::integral auto r,std::integral auto c) requires( requires{T::Identity(r,c);} ) + { + return T::Identity(r,c); + } + + template + auto setRandom(T&& t) requires( requires{std::forward(t).setRandom();} ) + { + return std::forward(t).setRandom(); + } + + template + auto setRandom() requires( requires{T::Random();} ) + { + return T::Random(); + } + + template + auto setRandom(std::integral auto n) requires( requires{T::Random(n);} ) + { + return T::Random(n); + } + + template + auto setRandom(std::integral auto r,std::integral auto c) requires( requires{T::Random(r,c);} ) + { + return T::Random(r,c); + } + + template + auto setConstant(T&& t, auto v) requires( requires{std::forward(t).setConstant(v);} ) + { + return std::forward(t).setConstant(v); + } + + template + auto setConstant(auto v) requires( requires{T::Constant(v);} ) + { + return T::Constant(v); + } + + template + auto setConstant(std::integral auto n, auto v) requires( requires{T::Constant(n,v);} ) + { + return T::Constant(n,v); + } + + template + auto setConstant(std::integral auto r,std::integral auto c, auto v) requires( requires{T::Constant(r,c,v);} ) + { + return T::Constant(r,c,v); + } +} \ No newline at end of file diff --git a/test/unit/matrix/generators.cpp b/test/unit/matrix/generators.cpp index 06f4827..3b3a758 100644 --- a/test/unit/matrix/generators.cpp +++ b/test/unit/matrix/generators.cpp @@ -38,99 +38,104 @@ void test_identity(const auto& matrix, std::size_t rows, std::size_t cols) TTS_CASE_TPL("Test zero", rotgen::tests::types) ( tts::type< tts::types> ) { - test_value(rotgen::matrix{}.setZero(), 3, 4, 0); - test_value(rotgen::matrix{}.setZero(), 1, 1, 0); - test_value(rotgen::matrix{}.setZero(), 10, 10, 0); - test_value(rotgen::matrix(1,1).setZero(3,4), 3, 4, 0); - test_value(rotgen::matrix().setZero(7, 5), 7, 5, 0); - test_value(rotgen::matrix(9,1).setZero(9, 3), 9, 3, 0); - test_value(rotgen::matrix(1,3).setZero(2, 3), 2, 3, 0); + using namespace rotgen; + test_value(setZero >(), 3, 4, 0); + test_value(setZero >(), 1, 1, 0); + test_value(setZero >(), 10, 10, 0); + test_value(setZero>(3, 4), 3, 4, 0); + test_value(setZero >(7, 5), 7, 5, 0); + test_value(setZero >(9, 3), 9, 3, 0); + test_value(setZero >(2, 3), 2, 3, 0); - test_value(rotgen::matrix::Zero(), 3, 4, 0); - test_value(rotgen::matrix::Zero(), 1, 1, 0); - test_value(rotgen::matrix::Zero(), 10, 10, 0); - test_value(rotgen::matrix::Zero(3, 4), 3, 4, 0); - test_value(rotgen::matrix::Zero(7, 5), 7, 5, 0); - test_value(rotgen::matrix::Zero(9, 3), 9, 3, 0); - test_value(rotgen::matrix::Zero(2, 3), 2, 3, 0); + test_value(setZero(matrix{} ), 3, 4, 0); + test_value(setZero(matrix{} ), 1, 1, 0); + test_value(setZero(matrix{} ), 10, 10, 0); + test_value(setZero(matrix{3, 4}), 3, 4, 0); + test_value(setZero(matrix{7, 5} ), 7, 5, 0); + test_value(setZero(matrix{9, 3} ), 9, 3, 0); + test_value(setZero(matrix{2, 3} ), 2, 3, 0); }; TTS_CASE_TPL("Test ones", rotgen::tests::types) ( tts::type< tts::types> ) { - test_value(rotgen::matrix{}.setOnes(), 3, 4, 1); - test_value(rotgen::matrix{}.setOnes(), 1, 1, 1); - test_value(rotgen::matrix{}.setOnes(), 10, 10, 1); - test_value(rotgen::matrix(1,1).setOnes(3, 4), 3, 4, 1); - test_value(rotgen::matrix{}.setOnes(7, 5), 7, 5, 1); - test_value(rotgen::matrix(9,1).setOnes(9, 3), 9, 3, 1); - test_value(rotgen::matrix(1,3).setOnes(2, 3), 2, 3, 1); + using namespace rotgen; + test_value(setOnes >(), 3, 4, 1); + test_value(setOnes >(), 1, 1, 1); + test_value(setOnes >(), 10, 10, 1); + test_value(setOnes>(3, 4), 3, 4, 1); + test_value(setOnes >(7, 5), 7, 5, 1); + test_value(setOnes >(9, 3), 9, 3, 1); + test_value(setOnes >(2, 3), 2, 3, 1); - test_value(rotgen::matrix::Ones(), 3, 4, 1); - test_value(rotgen::matrix::Ones(), 1, 1, 1); - test_value(rotgen::matrix::Ones(), 10, 10, 1); - test_value(rotgen::matrix::Ones(3, 4), 3, 4, 1); - test_value(rotgen::matrix::Ones(7, 5), 7, 5, 1); - test_value(rotgen::matrix::Ones(9, 3), 9, 3, 1); - test_value(rotgen::matrix::Ones(2, 3), 2, 3, 1); + test_value(setOnes(matrix{} ), 3, 4, 1); + test_value(setOnes(matrix{} ), 1, 1, 1); + test_value(setOnes(matrix{} ), 10, 10, 1); + test_value(setOnes(matrix{3, 4}), 3, 4, 1); + test_value(setOnes(matrix{7, 5} ), 7, 5, 1); + test_value(setOnes(matrix{9, 3} ), 9, 3, 1); + test_value(setOnes(matrix{2, 3} ), 2, 3, 1); }; TTS_CASE_TPL("Test constant", rotgen::tests::types) ( tts::type< tts::types> ) { - test_value(rotgen::matrix{}.setConstant(5.12), 3, 8, T(5.12)); - test_value(rotgen::matrix{}.setConstant(2.2), 1, 1, T(2.2)); - test_value(rotgen::matrix{}.setConstant(13), 11, 12, T(13)); - test_value(rotgen::matrix(1,1).setConstant(2, 7, 5.6), 2, 7, T(5.6)); - test_value(rotgen::matrix{}.setConstant(2, 2, 2.0), 2, 2, T(2.0)); - test_value(rotgen::matrix(9,1).setConstant(9, 3, 1.1), 9, 3, T(1.1)); - test_value(rotgen::matrix(1,9).setConstant(5, 9, 42), 5, 9,T(42)); + using namespace rotgen; + test_value(setConstant >(T(5.12)), 3, 4, T(5.12)); + test_value(setConstant >(T(5.12)), 1, 1, T(5.12)); + test_value(setConstant >(T(5.12)), 10, 10, T(5.12)); + test_value(setConstant>(3, 4, T(5.12)), 3, 4, T(5.12)); + test_value(setConstant >(7, 5, T(5.12)), 7, 5, T(5.12)); + test_value(setConstant >(9, 3, T(5.12)), 9, 3, T(5.12)); + test_value(setConstant >(2, 3, T(5.12)), 2, 3, T(5.12)); - test_value(rotgen::matrix::Constant(5.12), 3, 8, T(5.12)); - test_value(rotgen::matrix::Constant(2.2), 1, 1, T(2.2)); - test_value(rotgen::matrix::Constant(13), 11, 12, T(13)); - test_value(rotgen::matrix::Constant(2, 7, 5.6), 2, 7, T(5.6)); - test_value(rotgen::matrix::Constant(2, 2, 2.0), 2, 2, T(2.0)); - test_value(rotgen::matrix::Constant(9, 3, 1.1), 9, 3, T(1.1)); - test_value(rotgen::matrix::Constant(5, 9, 42), 5, 9,T(42)); + test_value(setConstant(matrix{} , T(5.12)), 3, 4, T(5.12)); + test_value(setConstant(matrix{} , T(5.12)), 1, 1, T(5.12)); + test_value(setConstant(matrix{} , T(5.12)), 10, 10, T(5.12)); + test_value(setConstant(matrix{3, 4}, T(5.12)), 3, 4, T(5.12)); + test_value(setConstant(matrix{7, 5} , T(5.12)), 7, 5, T(5.12)); + test_value(setConstant(matrix{9, 3} , T(5.12)), 9, 3, T(5.12)); + test_value(setConstant(matrix{2, 3} , T(5.12)), 2, 3, T(5.12)); }; TTS_CASE_TPL("Test random", rotgen::tests::types) ( tts::type< tts::types> ) { - test_random(rotgen::matrix{}.setRandom(), 2, 3); - test_random(rotgen::matrix{}.setRandom(), 1, 1); - test_random(rotgen::matrix{}.setRandom(), 11, 17); - test_random(rotgen::matrix{1,1}.setRandom(7, 3), 7, 3); - test_random(rotgen::matrix{}.setRandom(2, 2), 2, 2); - test_random(rotgen::matrix{4,1}.setRandom(4, 3), 4, 3); - test_random(rotgen::matrix{1,5}.setRandom(5, 5), 5, 5); + using namespace rotgen; + test_random(setRandom >(), 3, 4); + test_random(setRandom >(), 1, 1); + test_random(setRandom >(), 10, 10); + test_random(setRandom>(3, 4), 3, 4); + test_random(setRandom >(7, 5), 7, 5); + test_random(setRandom >(9, 3), 9, 3); + test_random(setRandom >(2, 3), 2, 3); - test_random(rotgen::matrix::Random(), 2, 3); - test_random(rotgen::matrix::Random(), 1, 1); - test_random(rotgen::matrix::Random(), 11, 17); - test_random(rotgen::matrix::Random(7, 3), 7, 3); - test_random(rotgen::matrix::Random(2, 2), 2, 2); - test_random(rotgen::matrix::Random(4, 3), 4, 3); - test_random(rotgen::matrix::Random(5, 5), 5, 5); + test_random(setRandom(matrix{} ), 3, 4); + test_random(setRandom(matrix{} ), 1, 1); + test_random(setRandom(matrix{} ), 10, 10); + test_random(setRandom(matrix{3, 4}), 3, 4); + test_random(setRandom(matrix{7, 5} ), 7, 5); + test_random(setRandom(matrix{9, 3} ), 9, 3); + test_random(setRandom(matrix{2, 3} ), 2, 3); }; TTS_CASE_TPL("Test identity", rotgen::tests::types) ( tts::type< tts::types> ) { - test_identity(rotgen::matrix{}.setIdentity(), 4, 5); - test_identity(rotgen::matrix{}.setIdentity(), 1, 1); - test_identity(rotgen::matrix{}.setIdentity(), 21, 3); - test_identity(rotgen::matrix{1,1}.setIdentity(2, 7), 2, 7); - test_identity(rotgen::matrix{}.setIdentity(2, 2), 2, 2); - test_identity(rotgen::matrix{3,1}.setIdentity(3, 3), 3, 3); - test_identity(rotgen::matrix{1,11}.setIdentity(5, 11), 5, 11); + using namespace rotgen; + test_identity(setIdentity >(), 3, 4); + test_identity(setIdentity >(), 1, 1); + test_identity(setIdentity >(), 10, 10); + test_identity(setIdentity>(3, 4), 3, 4); + test_identity(setIdentity >(7, 5), 7, 5); + test_identity(setIdentity >(9, 3), 9, 3); + test_identity(setIdentity >(2, 3), 2, 3); - test_identity(rotgen::matrix::Identity(), 4, 5); - test_identity(rotgen::matrix::Identity(), 1, 1); - test_identity(rotgen::matrix::Identity(), 21, 3); - test_identity(rotgen::matrix::Identity(2, 7), 2, 7); - test_identity(rotgen::matrix::Identity(2, 2), 2, 2); - test_identity(rotgen::matrix::Identity(3, 3), 3, 3); - test_identity(rotgen::matrix::Identity(5, 11), 5, 11); + test_identity(setIdentity(matrix{} ), 3, 4); + test_identity(setIdentity(matrix{} ), 1, 1); + test_identity(setIdentity(matrix{} ), 10, 10); + test_identity(setIdentity(matrix{3, 4}), 3, 4); + test_identity(setIdentity(matrix{7, 5} ), 7, 5); + test_identity(setIdentity(matrix{9, 3} ), 9, 3); + test_identity(setIdentity(matrix{2, 3} ), 2, 3); }; \ No newline at end of file