Implement dot

See merge request oss/rotgen!31
This commit is contained in:
Joel Falcou 2025-09-29 18:58:12 +02:00
parent 3313e257c8
commit ddf8816c5b
12 changed files with 165 additions and 46 deletions

View file

@ -280,6 +280,12 @@ namespace rotgen
return lhs.base() / s; return lhs.base() / s;
} }
template<typename A, int O, typename S, typename B, int P, typename T>
auto dot(ref<A,O,S> lhs, ref<B,P,T> rhs)
{
return lhs.base().dot(rhs.base());
}
template<typename A, int O, typename S, typename B, int P, typename T> template<typename A, int O, typename S, typename B, int P, typename T>
auto min(ref<A,O,S> lhs, ref<B,P,T> rhs) -> decltype(lhs.base().cwiseMin(rhs.base())) auto min(ref<A,O,S> lhs, ref<B,P,T> rhs) -> decltype(lhs.base().cwiseMin(rhs.base()))
{ {

View file

@ -9,9 +9,21 @@
namespace rotgen::detail namespace rotgen::detail
{ {
template<typename M, typename N>
inline constexpr bool has_same_vector_size = []()
{
// No vector = noo size
if(!(M::IsVectorAtCompileTime && N::IsVectorAtCompileTime)) return false;
// Row vectors -> same Cols
if(M::RowsAtCompileTime == 1 && N::RowsAtCompileTime == 1) return M::ColsAtCompileTime == N::ColsAtCompileTime;
// Col vectors -> same Rows
if(M::ColsAtCompileTime == 1 && N::ColsAtCompileTime == 1) return M::RowsAtCompileTime == N::RowsAtCompileTime;
// Mixing 1xN with Mx1
return false;
}();
template<auto M, auto N> template<auto M, auto N>
inline constexpr auto select_static = (M==rotgen::Dynamic || N==rotgen::Dynamic) inline constexpr auto select_static = (M==rotgen::Dynamic || N==rotgen::Dynamic) ? rotgen::Dynamic : M;
? rotgen::Dynamic : M;
template<typename M1, typename M2> template<typename M1, typename M2>
using composite_matrix_type = matrix< typename M1::value_type using composite_matrix_type = matrix< typename M1::value_type

View file

@ -315,7 +315,7 @@ namespace rotgen
} }
template<int P> template<int P>
double lpNorm() const value_type lpNorm() const
{ {
assert(P == 1 || P == 2 || P == Infinity); assert(P == 1 || P == 2 || P == Infinity);
return parent::lpNorm(P); return parent::lpNorm(P);

View file

@ -313,8 +313,14 @@ namespace rotgen
return *this; return *this;
} }
template<typename R2, int O2, typename S2>
value_type dot(map<R2,O2,S2> const& rhs) const
{
return base().dot(rhs.base());
}
template<int P> template<int P>
double lpNorm() const value_type lpNorm() const
{ {
assert(P == 1 || P == 2 || P == Infinity); assert(P == 1 || P == 2 || P == Infinity);
return parent::lpNorm(P); return parent::lpNorm(P);

View file

@ -356,6 +356,12 @@ namespace rotgen
return result; return result;
} }
template<typename R2, int O2, typename S2>
value_type dot(map<R2,O2,S2> const& rhs) const
{
return base().dot(rhs.base());
}
template<int P> value_type lpNorm() const template<int P> value_type lpNorm() const
{ {
static_assert(P == 1 || P == 2 || P == Infinity); static_assert(P == 1 || P == 2 || P == Infinity);

View file

@ -7,6 +7,8 @@
//================================================================================================== //==================================================================================================
#pragma once #pragma once
#include <rotgen/detail/helpers.hpp>
namespace rotgen namespace rotgen
{ {
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
@ -142,6 +144,14 @@ namespace rotgen
auto prod(concepts::entity auto const& arg) { return arg.prod(); } auto prod(concepts::entity auto const& arg) { return arg.prod(); }
auto mean(concepts::entity auto const& arg) { return arg.mean(); } auto mean(concepts::entity auto const& arg) { return arg.mean(); }
template<concepts::entity A, concepts::entity B>
auto dot(A const& a, B const& b)
requires(detail::has_same_vector_size<A,B> && std::same_as<typename A::value_type, typename B::value_type>)
{
if constexpr(!use_expression_templates) return dot(generalize_t<A const>(a), generalize_t<B const>(b));
else return base_of(a).dot(base_of(b));
}
auto maxCoeff(auto const& arg) requires( requires{ arg.maxCoeff(); } ) { return arg.maxCoeff(); } auto maxCoeff(auto const& arg) requires( requires{ arg.maxCoeff(); } ) { return arg.maxCoeff(); }
auto minCoeff(auto const& arg) requires( requires{ arg.minCoeff(); } ) { return arg.minCoeff(); } auto minCoeff(auto const& arg) requires( requires{ arg.minCoeff(); } ) { return arg.minCoeff(); }

View file

@ -63,8 +63,10 @@ class ROTGEN_EXPORT CLASSNAME
TYPE trace() const; TYPE trace() const;
TYPE maxCoeff() const; TYPE maxCoeff() const;
TYPE minCoeff() const; TYPE minCoeff() const;
TYPE maxCoeff(Index* row, Index* col) const; TYPE maxCoeff(Index*, Index*) const;
TYPE minCoeff(Index* row, Index* col) const; TYPE minCoeff(Index*, Index*) const;
TYPE dot(CLASSNAME const&) const;
TYPE dot(TRANSCLASSNAME const&) const;
TYPE squaredNorm() const; TYPE squaredNorm() const;
TYPE norm() const; TYPE norm() const;

View file

@ -189,6 +189,16 @@
} }
#endif #endif
TYPE CLASSNAME::dot(CLASSNAME const& rhs) const
{
return storage_->data.reshaped().dot(rhs.storage()->data.reshaped());
}
TYPE CLASSNAME::dot(TRANSCLASSNAME const& rhs) const
{
return storage_->data.reshaped().dot(rhs.storage()->data.reshaped());
}
TYPE CLASSNAME::sum() const { return storage_->data.sum(); } TYPE CLASSNAME::sum() const { return storage_->data.sum(); }
TYPE CLASSNAME::prod() const { return storage_->data.prod(); } TYPE CLASSNAME::prod() const { return storage_->data.prod(); }
TYPE CLASSNAME::mean() const { return storage_->data.mean(); } TYPE CLASSNAME::mean() const { return storage_->data.mean(); }

View file

@ -72,3 +72,41 @@ TTS_CASE_TPL("Test static block reduction-like operations", rotgen::tests::types
std::apply([&](auto const&... d) { (process(d),...);}, cases); std::apply([&](auto const&... d) { (process(d),...);}, cases);
}; };
TTS_CASE_TPL("Test dot product", float, double)
<typename T>( tts::type<T> )
{
{
auto v = rotgen::setConstant<rotgen::matrix<T,1,rotgen::Dynamic>>(1,16,2);
auto a = rotgen::head(v,8);
auto b = rotgen::tail(v,8);
TTS_EQUAL(rotgen::dot(a,b), 32);
}
{
auto v = rotgen::setConstant<rotgen::matrix<T,rotgen::Dynamic,1>>(16,1,2);
auto a = rotgen::head(v,8);
auto b = rotgen::tail(v,8);
TTS_EQUAL(rotgen::dot(a,b), 32);
}
{
auto v = rotgen::setConstant<rotgen::matrix<T,1,rotgen::Dynamic>>(1,16,2);
auto a = rotgen::head<8>(v);
auto b = rotgen::tail<8>(v);
TTS_EQUAL(rotgen::dot(a,b), 32);
}
{
auto v = rotgen::setConstant<rotgen::matrix<T,rotgen::Dynamic,1>>(16,1,2);
auto a = rotgen::head<8>(v);
auto b = rotgen::tail<8>(v);
TTS_EQUAL(rotgen::dot(a,b), 32);
}
};

View file

@ -20,10 +20,6 @@ namespace rotgen::tests
mat_t result(original.cols(), original.rows()); mat_t result(original.cols(), original.rows());
prepare([&](auto r, auto c) { return original(c,r); },result); prepare([&](auto r, auto c) { return original(c,r); },result);
TTS_EQUAL(original.transpose(), result );
TTS_EQUAL(original.conjugate(), original);
TTS_EQUAL(original.adjoint() , result );
TTS_EQUAL(transpose(original) , result ); TTS_EQUAL(transpose(original) , result );
TTS_EQUAL(conjugate(original) , original); TTS_EQUAL(conjugate(original) , original);
TTS_EQUAL(adjoint(original) , result ); TTS_EQUAL(adjoint(original) , result );
@ -33,12 +29,6 @@ namespace rotgen::tests
if constexpr(T::RowsAtCompileTime == T::ColsAtCompileTime) if constexpr(T::RowsAtCompileTime == T::ColsAtCompileTime)
{ {
mat_t ref = original; mat_t ref = original;
original.transposeInPlace();
TTS_EQUAL(original, result);
original.adjointInPlace();
TTS_EQUAL(original, ref);
transposeInPlace(original); transposeInPlace(original);
TTS_EQUAL(original, result); TTS_EQUAL(original, result);
@ -51,12 +41,6 @@ namespace rotgen::tests
if (original.rows() == original.cols()) if (original.rows() == original.cols())
{ {
mat_t ref = original; mat_t ref = original;
original.transposeInPlace();
TTS_EQUAL(original, result);
original.adjointInPlace();
TTS_EQUAL(original, ref);
transposeInPlace(original); transposeInPlace(original);
TTS_EQUAL(original, result); TTS_EQUAL(original, result);
@ -81,36 +65,13 @@ namespace rotgen::tests
EigenMatrix ref(input.rows(), input.cols()); EigenMatrix ref(input.rows(), input.cols());
prepare([&](auto r, auto c) { return input(r,c); }, ref); prepare([&](auto r, auto c) { return input(r,c); }, ref);
TTS_ULP_EQUAL(input.sum(), ref.sum() , 2);
TTS_ULP_EQUAL(sum(input) , ref.sum() , 2); TTS_ULP_EQUAL(sum(input) , ref.sum() , 2);
TTS_ULP_EQUAL(input.prod(), ref.prod(), 2);
TTS_ULP_EQUAL(prod(input) , ref.prod(), 2); TTS_ULP_EQUAL(prod(input) , ref.prod(), 2);
TTS_ULP_EQUAL(input.mean(), ref.mean(), 2);
TTS_ULP_EQUAL(mean(input) , ref.mean(), 2); TTS_ULP_EQUAL(mean(input) , ref.mean(), 2);
TTS_EQUAL(input.trace(), ref.trace());
TTS_EQUAL(trace(input) , ref.trace()); TTS_EQUAL(trace(input) , ref.trace());
TTS_EQUAL(input.minCoeff(), ref.minCoeff());
TTS_EQUAL(minCoeff(input) , ref.minCoeff()); TTS_EQUAL(minCoeff(input) , ref.minCoeff());
TTS_EQUAL(input.maxCoeff(), ref.maxCoeff());
TTS_EQUAL(maxCoeff(input) , ref.maxCoeff()); TTS_EQUAL(maxCoeff(input) , ref.maxCoeff());
{
int row, col, ref_row, ref_col;
TTS_EQUAL(input.minCoeff(&row, &col), ref.minCoeff(&ref_row, &ref_col));
TTS_EQUAL(row, ref_row);
TTS_EQUAL(col, ref_col);
TTS_EQUAL(input.maxCoeff(&row, &col), ref.maxCoeff(&ref_row, &ref_col));
TTS_EQUAL(row, ref_row);
TTS_EQUAL(col, ref_col);
}
{ {
int row, col, ref_row, ref_col; int row, col, ref_row, ref_col;

View file

@ -71,3 +71,39 @@ TTS_CASE_TPL("Test static map reduction-like operations", rotgen::tests::types)
std::apply([&](auto const&... d) { (process(d),...);}, cases); std::apply([&](auto const&... d) { (process(d),...);}, cases);
}; };
TTS_CASE_TPL("Test dot product", float, double)
<typename T>( tts::type<T> )
{
{
auto v = rotgen::setConstant<rotgen::matrix<T,1,rotgen::Dynamic>>(1,16,2);
auto a = rotgen::map<decltype(v)>(v.data(),1,8);
auto b = rotgen::map<decltype(v)>(v.data()+8,1,8);
TTS_EQUAL(rotgen::dot(a,b), 32);
}
{
auto v = rotgen::setConstant<rotgen::matrix<T,rotgen::Dynamic,1>>(16,1,2);
auto a = rotgen::map<decltype(v)>(v.data(),8,1);
auto b = rotgen::map<decltype(v)>(v.data()+8,8,1);
TTS_EQUAL(rotgen::dot(a,b), 32);
}
{
auto v = rotgen::setConstant<rotgen::matrix<T,1,rotgen::Dynamic>>(1,16,2);
auto a = rotgen::map<rotgen::matrix<T,1,8>>(v.data());
auto b = rotgen::map<rotgen::matrix<T,1,8>>(v.data()+8);
TTS_EQUAL(rotgen::dot(a,b), 32);
}
{
auto v = rotgen::setConstant<rotgen::matrix<T,1,rotgen::Dynamic>>(1,16,2);
auto a = rotgen::map<rotgen::matrix<T,8,1>>(v.data());
auto b = rotgen::map<rotgen::matrix<T,8,1>>(v.data()+8);
TTS_EQUAL(rotgen::dot(a,b), 32);
}
};

View file

@ -62,3 +62,35 @@ TTS_CASE_TPL("Test static matrix reduction-like operations", rotgen::tests::type
std::apply([&](auto const&... d) { (process(d),...);}, cases); std::apply([&](auto const&... d) { (process(d),...);}, cases);
}; };
TTS_CASE_TPL("Test dot product", float, double)
<typename T>( tts::type<T> )
{
{
auto a = rotgen::setConstant<rotgen::matrix<T,1,rotgen::Dynamic>>(1,8,2);
auto b = rotgen::setConstant<rotgen::matrix<T,1,rotgen::Dynamic>>(1,8,2);
TTS_EQUAL(rotgen::dot(a,b), 32);
}
{
auto a = rotgen::setConstant<rotgen::matrix<T,rotgen::Dynamic,1>>(8,1,2);
auto b = rotgen::setConstant<rotgen::matrix<T,rotgen::Dynamic,1>>(8,1,2);
TTS_EQUAL(rotgen::dot(a,b), 32);
}
{
auto a = rotgen::setConstant<rotgen::matrix<T,1,8>>(2);
auto b = rotgen::setConstant<rotgen::matrix<T,1,8>>(2);
TTS_EQUAL(rotgen::dot(a,b), 32);
}
{
auto a = rotgen::setConstant<rotgen::matrix<T,8,1>>(2);
auto b = rotgen::setConstant<rotgen::matrix<T,8,1>>(2);
TTS_EQUAL(rotgen::dot(a,b), 32);
}
};