Implement a QR solver wrapper

This commit is contained in:
Joel Falcou 2025-09-09 16:27:22 +02:00
parent bb5d739e5d
commit bb47b07422
10 changed files with 90 additions and 6 deletions

View file

@ -241,6 +241,11 @@ namespace rotgen
parent& base() { return static_cast<parent&>(*this); } parent& base() { return static_cast<parent&>(*this); }
parent const& base() const { return static_cast<parent const&>(*this); } parent const& base() const { return static_cast<parent const&>(*this); }
concrete_type qr_solve(map const& rhs) const
{
return concrete_type(base().qr_solve(rhs.base()));
};
}; };
template<typename R1, typename R2, int O1, typename S1, int O2, typename S2> template<typename R1, typename R2, int O1, typename S1, int O2, typename S2>

View file

@ -193,6 +193,11 @@ namespace rotgen
void transposeInPlace() { base().transposeInPlace(); } void transposeInPlace() { base().transposeInPlace(); }
void adjointInPlace() { base().adjointInPlace(); } void adjointInPlace() { base().adjointInPlace(); }
auto qr_solve(map const& rhs) const
{
return concrete_type(base().colPivHouseholderQr().solve(rhs.base()));
};
static auto Zero() requires( requires {Ref::Zero();} ) { return Ref::Zero(); } static auto Zero() requires( requires {Ref::Zero();} ) { return Ref::Zero(); }
static auto Zero(int rows, int cols) { return Ref::Zero(rows,cols); } static auto Zero(int rows, int cols) { return Ref::Zero(rows,cols); }

View file

@ -52,7 +52,13 @@ namespace rotgen
auto sum(concepts::entity auto const& arg) { return arg.sum(); } auto sum(concepts::entity auto const& arg) { return arg.sum(); }
auto prod(concepts::entity auto const& arg) { return arg.prod(); } auto prod(concepts::entity auto const& arg) { return arg.prod(); }
auto mean(concepts::entity auto const& arg) { return arg.mean(); } auto mean(concepts::entity auto const& arg) { return arg.mean(); }
auto maxCoeff(concepts::entity auto const& arg) { return arg.maxCoeff(); }
auto maxCoeff(auto const& arg)
requires( requires{ arg.maxCoeff(); } )
{
return arg.maxCoeff();
}
auto minCoeff(concepts::entity auto const& arg) { return arg.minCoeff(); } auto minCoeff(concepts::entity auto const& arg) { return arg.minCoeff(); }
auto maxCoeff(concepts::entity auto const& arg, Index* row, Index* col) auto maxCoeff(concepts::entity auto const& arg, Index* row, Index* col)
{ {

View file

@ -63,6 +63,8 @@ class ROTGEN_EXPORT CLASSNAME
TYPE norm() const; TYPE norm() const;
TYPE lpNorm(int p) const; TYPE lpNorm(int p) const;
SOURCENAME qr_solve(CLASSNAME const& rhs) const;
#if !defined(USE_CONST) #if !defined(USE_CONST)
TYPE& operator()(Index i, Index j); TYPE& operator()(Index i, Index j);
TYPE& operator()(Index i); TYPE& operator()(Index i);

View file

@ -26,4 +26,5 @@
#include <rotgen/extract.hpp> #include <rotgen/extract.hpp>
#include <rotgen/functions.hpp> #include <rotgen/functions.hpp>
#include <rotgen/operators.hpp> #include <rotgen/operators.hpp>
#include <rotgen/solver.hpp>
#include <rotgen/alias.hpp> #include <rotgen/alias.hpp>

20
include/rotgen/solver.hpp Normal file
View file

@ -0,0 +1,20 @@
//==================================================================================================
/*
ROTGEN - Runtime Overlay for Eigen
Copyright : CODE RECKONS
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once
namespace rotgen::solver
{
template<typename X, typename M, typename RHS>
void qr(X& x, M const& m, RHS const& rhs )
{
auto r_x = generalize_t<X>(x);
auto r_m = generalize_t<M const>(m);
auto r_rhs = generalize_t<RHS const>(rhs);
r_x = r_m.base().qr_solve(r_rhs.base());
}
}

View file

@ -170,6 +170,16 @@
} }
#endif #endif
//==================================================================================================
// Solvers
//==================================================================================================
SOURCENAME CLASSNAME::qr_solve(CLASSNAME const& rhs) const
{
SOURCENAME result;
result.storage()->assign(storage_->data.colPivHouseholderQr().solve(rhs.storage_->data).eval());
return result;
}
//================================================================================================== //==================================================================================================
// Operators // Operators
//================================================================================================== //==================================================================================================

View file

@ -21,6 +21,7 @@ rotgen_glob_unit(QUIET PATTERN "unit/*.cpp" INTERFACE rotgen_test)
rotgen_glob_unit(QUIET PATTERN "unit/matrix/*.cpp" INTERFACE rotgen_test) rotgen_glob_unit(QUIET PATTERN "unit/matrix/*.cpp" INTERFACE rotgen_test)
rotgen_glob_unit(QUIET PATTERN "unit/block/*.cpp" INTERFACE rotgen_test) rotgen_glob_unit(QUIET PATTERN "unit/block/*.cpp" INTERFACE rotgen_test)
rotgen_glob_unit(QUIET PATTERN "unit/map/*.cpp" INTERFACE rotgen_test) rotgen_glob_unit(QUIET PATTERN "unit/map/*.cpp" INTERFACE rotgen_test)
rotgen_glob_unit(QUIET PATTERN "unit/functions/*.cpp" INTERFACE rotgen_test)
##====================================================================================================================== ##======================================================================================================================
## Integrations Tests registration ## Integrations Tests registration

View file

@ -0,0 +1,34 @@
//==================================================================================================
/*
ROTGEN - Runtime Overlay for Eigen
Copyright : CODE RECKONS
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#include "unit/tests.hpp"
#include <rotgen/rotgen.hpp>
TTS_CASE_TPL("System solver using QR", rotgen::tests::types)
<typename T, typename O>( tts::type< tts::types<T,O>> )
{
rotgen::matrix<T,rotgen::Dynamic,rotgen::Dynamic,O::value>
a { { 2.3, -1, 0.1}
, {-1.6, 2.6, -1}
, { 0.3, -1, 2}
};
rotgen::matrix<T,rotgen::Dynamic,1,O::value> b(3,1);
b(0) = b(2) = 1; b(1) = 0;
rotgen::matrix<T,rotgen::Dynamic,rotgen::Dynamic,O::value> r(3,1), error;
auto x = rotgen::extract(r,0,0,3,1);
rotgen::solver::qr(x, a, b);
error = a * r - b;
auto eps = std::numeric_limits<T>::epsilon();
TTS_LESS(rotgen::maxCoeff(rotgen::abs(error)) / eps, 5)
<< "Result:\n" << r << "\n"
<< "Residuals:\n" << error << "\n";
};