From ddf8816c5ba47ce8c4760fa6c764108ba815dcc7 Mon Sep 17 00:00:00 2001 From: Joel Falcou Date: Mon, 29 Sep 2025 18:58:12 +0200 Subject: [PATCH] Implement dot See merge request oss/rotgen!31 --- include/rotgen/common/ref.hpp | 6 ++++ include/rotgen/detail/helpers.hpp | 16 ++++++++-- include/rotgen/dynamic/block.hpp | 2 +- include/rotgen/dynamic/map.hpp | 8 ++++- include/rotgen/fixed/map.hpp | 6 ++++ include/rotgen/functions.hpp | 10 ++++++ include/rotgen/impl/map_model.hpp | 6 ++-- src/map_model.cpp | 10 ++++++ test/unit/block/arithmetic_functions.cpp | 38 ++++++++++++++++++++++ test/unit/common/arithmetic.hpp | 39 ----------------------- test/unit/map/arithmetic_functions.cpp | 36 +++++++++++++++++++++ test/unit/matrix/arithmetic_functions.cpp | 34 +++++++++++++++++++- 12 files changed, 165 insertions(+), 46 deletions(-) diff --git a/include/rotgen/common/ref.hpp b/include/rotgen/common/ref.hpp index 6611d47..72e4822 100644 --- a/include/rotgen/common/ref.hpp +++ b/include/rotgen/common/ref.hpp @@ -280,6 +280,12 @@ namespace rotgen return lhs.base() / s; } + template + auto dot(ref lhs, ref rhs) + { + return lhs.base().dot(rhs.base()); + } + template auto min(ref lhs, ref rhs) -> decltype(lhs.base().cwiseMin(rhs.base())) { diff --git a/include/rotgen/detail/helpers.hpp b/include/rotgen/detail/helpers.hpp index c84339f..6f267d2 100644 --- a/include/rotgen/detail/helpers.hpp +++ b/include/rotgen/detail/helpers.hpp @@ -9,9 +9,21 @@ namespace rotgen::detail { + template + inline constexpr bool has_same_vector_size = []() + { + // No vector = noo size + if(!(M::IsVectorAtCompileTime && N::IsVectorAtCompileTime)) return false; + // Row vectors -> same Cols + if(M::RowsAtCompileTime == 1 && N::RowsAtCompileTime == 1) return M::ColsAtCompileTime == N::ColsAtCompileTime; + // Col vectors -> same Rows + if(M::ColsAtCompileTime == 1 && N::ColsAtCompileTime == 1) return M::RowsAtCompileTime == N::RowsAtCompileTime; + // Mixing 1xN with Mx1 + return false; + }(); + template - inline constexpr auto select_static = (M==rotgen::Dynamic || N==rotgen::Dynamic) - ? rotgen::Dynamic : M; + inline constexpr auto select_static = (M==rotgen::Dynamic || N==rotgen::Dynamic) ? rotgen::Dynamic : M; template using composite_matrix_type = matrix< typename M1::value_type diff --git a/include/rotgen/dynamic/block.hpp b/include/rotgen/dynamic/block.hpp index ac04dd5..74eb41d 100644 --- a/include/rotgen/dynamic/block.hpp +++ b/include/rotgen/dynamic/block.hpp @@ -315,7 +315,7 @@ namespace rotgen } template - double lpNorm() const + value_type lpNorm() const { assert(P == 1 || P == 2 || P == Infinity); return parent::lpNorm(P); diff --git a/include/rotgen/dynamic/map.hpp b/include/rotgen/dynamic/map.hpp index 490e506..80230db 100644 --- a/include/rotgen/dynamic/map.hpp +++ b/include/rotgen/dynamic/map.hpp @@ -313,8 +313,14 @@ namespace rotgen return *this; } + template + value_type dot(map const& rhs) const + { + return base().dot(rhs.base()); + } + template - double lpNorm() const + value_type lpNorm() const { assert(P == 1 || P == 2 || P == Infinity); return parent::lpNorm(P); diff --git a/include/rotgen/fixed/map.hpp b/include/rotgen/fixed/map.hpp index 207b35c..717d5a5 100644 --- a/include/rotgen/fixed/map.hpp +++ b/include/rotgen/fixed/map.hpp @@ -356,6 +356,12 @@ namespace rotgen return result; } + template + value_type dot(map const& rhs) const + { + return base().dot(rhs.base()); + } + template value_type lpNorm() const { static_assert(P == 1 || P == 2 || P == Infinity); diff --git a/include/rotgen/functions.hpp b/include/rotgen/functions.hpp index 54c58d0..80300e0 100644 --- a/include/rotgen/functions.hpp +++ b/include/rotgen/functions.hpp @@ -7,6 +7,8 @@ //================================================================================================== #pragma once +#include + namespace rotgen { //----------------------------------------------------------------------------------------------- @@ -142,6 +144,14 @@ namespace rotgen auto prod(concepts::entity auto const& arg) { return arg.prod(); } auto mean(concepts::entity auto const& arg) { return arg.mean(); } + template + auto dot(A const& a, B const& b) + requires(detail::has_same_vector_size && std::same_as) + { + if constexpr(!use_expression_templates) return dot(generalize_t(a), generalize_t(b)); + else return base_of(a).dot(base_of(b)); + } + auto maxCoeff(auto const& arg) requires( requires{ arg.maxCoeff(); } ) { return arg.maxCoeff(); } auto minCoeff(auto const& arg) requires( requires{ arg.minCoeff(); } ) { return arg.minCoeff(); } diff --git a/include/rotgen/impl/map_model.hpp b/include/rotgen/impl/map_model.hpp index c8d6566..41c4c8c 100644 --- a/include/rotgen/impl/map_model.hpp +++ b/include/rotgen/impl/map_model.hpp @@ -63,8 +63,10 @@ class ROTGEN_EXPORT CLASSNAME TYPE trace() const; TYPE maxCoeff() const; TYPE minCoeff() const; - TYPE maxCoeff(Index* row, Index* col) const; - TYPE minCoeff(Index* row, Index* col) const; + TYPE maxCoeff(Index*, Index*) const; + TYPE minCoeff(Index*, Index*) const; + TYPE dot(CLASSNAME const&) const; + TYPE dot(TRANSCLASSNAME const&) const; TYPE squaredNorm() const; TYPE norm() const; diff --git a/src/map_model.cpp b/src/map_model.cpp index f43b3c8..56e1a7f 100644 --- a/src/map_model.cpp +++ b/src/map_model.cpp @@ -189,6 +189,16 @@ } #endif + TYPE CLASSNAME::dot(CLASSNAME const& rhs) const + { + return storage_->data.reshaped().dot(rhs.storage()->data.reshaped()); + } + + TYPE CLASSNAME::dot(TRANSCLASSNAME const& rhs) const + { + return storage_->data.reshaped().dot(rhs.storage()->data.reshaped()); + } + TYPE CLASSNAME::sum() const { return storage_->data.sum(); } TYPE CLASSNAME::prod() const { return storage_->data.prod(); } TYPE CLASSNAME::mean() const { return storage_->data.mean(); } diff --git a/test/unit/block/arithmetic_functions.cpp b/test/unit/block/arithmetic_functions.cpp index 61a0860..c40bb86 100644 --- a/test/unit/block/arithmetic_functions.cpp +++ b/test/unit/block/arithmetic_functions.cpp @@ -72,3 +72,41 @@ TTS_CASE_TPL("Test static block reduction-like operations", rotgen::tests::types std::apply([&](auto const&... d) { (process(d),...);}, cases); }; + + + +TTS_CASE_TPL("Test dot product", float, double) +( tts::type ) +{ + { + auto v = rotgen::setConstant>(1,16,2); + auto a = rotgen::head(v,8); + auto b = rotgen::tail(v,8); + + TTS_EQUAL(rotgen::dot(a,b), 32); + } + + { + auto v = rotgen::setConstant>(16,1,2); + auto a = rotgen::head(v,8); + auto b = rotgen::tail(v,8); + + TTS_EQUAL(rotgen::dot(a,b), 32); + } + + { + auto v = rotgen::setConstant>(1,16,2); + auto a = rotgen::head<8>(v); + auto b = rotgen::tail<8>(v); + + TTS_EQUAL(rotgen::dot(a,b), 32); + } + + { + auto v = rotgen::setConstant>(16,1,2); + auto a = rotgen::head<8>(v); + auto b = rotgen::tail<8>(v); + + TTS_EQUAL(rotgen::dot(a,b), 32); + } +}; \ No newline at end of file diff --git a/test/unit/common/arithmetic.hpp b/test/unit/common/arithmetic.hpp index 937d3a1..ddd2ccd 100644 --- a/test/unit/common/arithmetic.hpp +++ b/test/unit/common/arithmetic.hpp @@ -20,10 +20,6 @@ namespace rotgen::tests mat_t result(original.cols(), original.rows()); prepare([&](auto r, auto c) { return original(c,r); },result); - TTS_EQUAL(original.transpose(), result ); - TTS_EQUAL(original.conjugate(), original); - TTS_EQUAL(original.adjoint() , result ); - TTS_EQUAL(transpose(original) , result ); TTS_EQUAL(conjugate(original) , original); TTS_EQUAL(adjoint(original) , result ); @@ -33,12 +29,6 @@ namespace rotgen::tests if constexpr(T::RowsAtCompileTime == T::ColsAtCompileTime) { mat_t ref = original; - original.transposeInPlace(); - TTS_EQUAL(original, result); - - original.adjointInPlace(); - TTS_EQUAL(original, ref); - transposeInPlace(original); TTS_EQUAL(original, result); @@ -51,12 +41,6 @@ namespace rotgen::tests if (original.rows() == original.cols()) { mat_t ref = original; - original.transposeInPlace(); - TTS_EQUAL(original, result); - - original.adjointInPlace(); - TTS_EQUAL(original, ref); - transposeInPlace(original); TTS_EQUAL(original, result); @@ -81,36 +65,13 @@ namespace rotgen::tests EigenMatrix ref(input.rows(), input.cols()); prepare([&](auto r, auto c) { return input(r,c); }, ref); - TTS_ULP_EQUAL(input.sum(), ref.sum() , 2); TTS_ULP_EQUAL(sum(input) , ref.sum() , 2); - - TTS_ULP_EQUAL(input.prod(), ref.prod(), 2); TTS_ULP_EQUAL(prod(input) , ref.prod(), 2); - - TTS_ULP_EQUAL(input.mean(), ref.mean(), 2); TTS_ULP_EQUAL(mean(input) , ref.mean(), 2); - - TTS_EQUAL(input.trace(), ref.trace()); TTS_EQUAL(trace(input) , ref.trace()); - - TTS_EQUAL(input.minCoeff(), ref.minCoeff()); TTS_EQUAL(minCoeff(input) , ref.minCoeff()); - - TTS_EQUAL(input.maxCoeff(), ref.maxCoeff()); TTS_EQUAL(maxCoeff(input) , ref.maxCoeff()); - { - int row, col, ref_row, ref_col; - - TTS_EQUAL(input.minCoeff(&row, &col), ref.minCoeff(&ref_row, &ref_col)); - TTS_EQUAL(row, ref_row); - TTS_EQUAL(col, ref_col); - - TTS_EQUAL(input.maxCoeff(&row, &col), ref.maxCoeff(&ref_row, &ref_col)); - TTS_EQUAL(row, ref_row); - TTS_EQUAL(col, ref_col); - } - { int row, col, ref_row, ref_col; diff --git a/test/unit/map/arithmetic_functions.cpp b/test/unit/map/arithmetic_functions.cpp index 1837bef..4257c47 100644 --- a/test/unit/map/arithmetic_functions.cpp +++ b/test/unit/map/arithmetic_functions.cpp @@ -70,4 +70,40 @@ TTS_CASE_TPL("Test static map reduction-like operations", rotgen::tests::types) }; std::apply([&](auto const&... d) { (process(d),...);}, cases); +}; + +TTS_CASE_TPL("Test dot product", float, double) +( tts::type ) +{ + { + auto v = rotgen::setConstant>(1,16,2); + auto a = rotgen::map(v.data(),1,8); + auto b = rotgen::map(v.data()+8,1,8); + + TTS_EQUAL(rotgen::dot(a,b), 32); + } + + { + auto v = rotgen::setConstant>(16,1,2); + auto a = rotgen::map(v.data(),8,1); + auto b = rotgen::map(v.data()+8,8,1); + + TTS_EQUAL(rotgen::dot(a,b), 32); + } + + { + auto v = rotgen::setConstant>(1,16,2); + auto a = rotgen::map>(v.data()); + auto b = rotgen::map>(v.data()+8); + + TTS_EQUAL(rotgen::dot(a,b), 32); + } + + { + auto v = rotgen::setConstant>(1,16,2); + auto a = rotgen::map>(v.data()); + auto b = rotgen::map>(v.data()+8); + + TTS_EQUAL(rotgen::dot(a,b), 32); + } }; \ No newline at end of file diff --git a/test/unit/matrix/arithmetic_functions.cpp b/test/unit/matrix/arithmetic_functions.cpp index 05eb229..9796350 100644 --- a/test/unit/matrix/arithmetic_functions.cpp +++ b/test/unit/matrix/arithmetic_functions.cpp @@ -61,4 +61,36 @@ TTS_CASE_TPL("Test static matrix reduction-like operations", rotgen::tests::type }; std::apply([&](auto const&... d) { (process(d),...);}, cases); -}; \ No newline at end of file +}; + +TTS_CASE_TPL("Test dot product", float, double) +( tts::type ) +{ + { + auto a = rotgen::setConstant>(1,8,2); + auto b = rotgen::setConstant>(1,8,2); + + TTS_EQUAL(rotgen::dot(a,b), 32); + } + + { + auto a = rotgen::setConstant>(8,1,2); + auto b = rotgen::setConstant>(8,1,2); + + TTS_EQUAL(rotgen::dot(a,b), 32); + } + + { + auto a = rotgen::setConstant>(2); + auto b = rotgen::setConstant>(2); + + TTS_EQUAL(rotgen::dot(a,b), 32); + } + + { + auto a = rotgen::setConstant>(2); + auto b = rotgen::setConstant>(2); + + TTS_EQUAL(rotgen::dot(a,b), 32); + } +};