rotgen/include/rotgen/matrix.hpp
Karen Kaspar 114bc27901 Feat/non member functions
Co-authored-by: Karen <kkaspar@codereckons.com>
Co-authored-by: Joel FALCOU <jfalcou@codereckons.com>

See merge request oss/rotgen!7
2025-06-11 14:41:49 +02:00

220 lines
No EOL
6.6 KiB
C++

//==================================================================================================
/*
ROTGEN - Runtime Overlay for Eigen
Copyright : CODE RECKONS
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once
#include <rotgen/impl/matrix.hpp>
#include <initializer_list>
#include <cassert>
namespace rotgen
{
/// @brief Dense matrix with optional compile-time row/column counts.
///
/// Storage and all numerical work are delegated to the backend type
/// `find_matrix<Scalar,Options>`. This wrapper adds the static shape
/// (`Rows`, `Cols`) on top of it and, in builds where `assert` is enabled,
/// checks that runtime sizes agree with every dimension fixed at compile time.
///
/// @tparam Scalar  Coefficient type.
/// @tparam Rows    Compile-time row count, or `Dynamic` for a runtime size.
/// @tparam Cols    Compile-time column count, or `Dynamic` for a runtime size.
/// @tparam Options Option flags forwarded to the backend (`ColMajor` by default).
/// @tparam MaxRows Carried in the type only; not read by this wrapper.
/// @tparam MaxCols Carried in the type only; not read by this wrapper.
template< typename Scalar
        , int Rows = Dynamic , int Cols = Dynamic
        , int Options = ColMajor
        , int MaxRows = Rows , int MaxCols = Cols
        >
class matrix : public find_matrix<Scalar,Options>
{
  using parent = find_matrix<Scalar,Options>;

  /// Asserts that runtime sizes match any compile-time fixed dimension.
  static void check_dimensions([[maybe_unused]] int rows, [[maybe_unused]] int cols)
  {
    if constexpr(Rows != Dynamic) assert(rows == Rows && "Mismatched between dynamic and static row size");
    if constexpr(Cols != Dynamic) assert(cols == Cols && "Mismatched between dynamic and static column size");
  }

  /// True when both dimensions are fixed and differ: such a shape cannot survive
  /// an in-place transpose or a product with a matrix of the same type.
  static constexpr bool is_fixed_non_square = Rows != Dynamic && Cols != Dynamic && Rows != Cols;

public:
  using value_type = Scalar;

  /// Type produced by transpose()/adjoint(): row and column extents are swapped.
  using transposed_type = matrix<Scalar, Cols, Rows, Options, MaxCols, MaxRows>;

  /// Fixed dimensions take their static size; dynamic ones start at 0.
  matrix() : parent(Rows == Dynamic ? 0 : Rows, Cols == Dynamic ? 0 : Cols) {}

  /// @param r Row count; must equal `Rows` when that is fixed.
  /// @param c Column count; must equal `Cols` when that is fixed.
  matrix(int r, int c) : parent(r, c)
  {
    check_dimensions(r, c);
  }

  /// Wraps a backend object. Kept implicit because the static factories below
  /// return backend results directly. The shape is not re-checked here.
  matrix(parent const& base) : parent(base) {}

  /// Row-by-row nested initializer list. Only the first row's length is
  /// checked against `Cols` here.
  matrix(std::initializer_list<std::initializer_list<Scalar>> init) : parent(init)
  {
    if constexpr(Rows != Dynamic)
      assert(init.size() == static_cast<std::size_t>(Rows) && "Mismatched between dynamic and static row size");
    if constexpr(Cols != Dynamic)
    {
      [[maybe_unused]] std::size_t const c = init.size() ? init.begin()->size() : 0;
      assert(c == static_cast<std::size_t>(Cols) && "Mismatched between dynamic and static column size");
    }
  }

  /// Coefficient-list constructor for fixed-size row or column vectors.
  template<std::convertible_to<Scalar>... S>
  matrix(Scalar s0, S... init)
  requires(  (Rows == 1 && Cols == static_cast<int>(1 + sizeof...(S)))
          || (Cols == 1 && Rows == static_cast<int>(1 + sizeof...(S)))
          )
  : parent(Rows, Cols, {s0, static_cast<Scalar>(init)...})
  {}

  /// Resizes a fully dynamic matrix (delegates to the backend).
  void resize(int new_rows, int new_cols) requires(Rows == Dynamic && Cols == Dynamic)
  {
    parent::resize(new_rows, new_cols);
  }

  /// Resizes a fully dynamic matrix via the backend's conservativeResize.
  void conservativeResize(int new_rows, int new_cols) requires(Rows == Dynamic && Cols == Dynamic)
  {
    parent::conservativeResize(new_rows, new_cols);
  }

  /// @return The transpose, typed with swapped static extents.
  transposed_type transpose() const
  {
    return transposed_type(static_cast<parent const&>(*this).transpose());
  }

  /// @return The backend's conjugate; the shape is unchanged.
  matrix conjugate() const
  {
    return matrix(static_cast<parent const&>(*this).conjugate());
  }

  /// @return The adjoint, typed with swapped static extents.
  transposed_type adjoint() const
  {
    return transposed_type(static_cast<parent const&>(*this).adjoint());
  }

  /// In-place transpose. Rejected for fixed non-square shapes, whose static type
  /// could not describe the result.
  void transposeInPlace()
  {
    static_assert(!is_fixed_non_square, "transposeInPlace requires a square or dynamic shape");
    parent::transposeInPlace();
  }

  /// In-place adjoint. Same shape restriction as transposeInPlace().
  void adjointInPlace()
  {
    static_assert(!is_fixed_non_square, "adjointInPlace requires a square or dynamic shape");
    parent::adjointInPlace();
  }

  /// Equality as defined by the backend's operator==.
  friend bool operator==(matrix const& lhs, matrix const& rhs)
  {
    return static_cast<parent const&>(lhs) == static_cast<parent const&>(rhs);
  }

  matrix& operator+=(matrix const& rhs)
  {
    static_cast<parent&>(*this) += static_cast<parent const&>(rhs);
    return *this;
  }

  matrix& operator-=(matrix const& rhs)
  {
    static_cast<parent&>(*this) -= static_cast<parent const&>(rhs);
    return *this;
  }

  /// Unary minus, delegated to the backend.
  matrix operator-() const
  {
    return matrix(static_cast<parent const&>(*this).operator-());
  }

  /// In-place product with a matrix of the same type. The inner dimensions can
  /// only agree when the shape is square (or dynamic), hence the static check.
  matrix& operator*=(matrix const& rhs)
  {
    static_assert(!is_fixed_non_square, "same-type product requires a square or dynamic shape");
    static_cast<parent&>(*this) *= static_cast<parent const&>(rhs);
    return *this;
  }

  matrix& operator*=(Scalar rhs)
  {
    static_cast<parent&>(*this) *= rhs;
    return *this;
  }

  matrix& operator/=(Scalar rhs)
  {
    static_cast<parent&>(*this) /= rhs;
    return *this;
  }

  /// Factories: the nullary forms need a fully fixed shape; the sized forms
  /// assert that the requested sizes match any fixed dimension.
  static matrix Zero() requires(Rows != Dynamic && Cols != Dynamic)
  {
    return parent::Zero(Rows, Cols);
  }

  static matrix Zero(int rows, int cols)
  {
    check_dimensions(rows, cols);
    return parent::Zero(rows, cols);
  }

  static matrix Constant(Scalar value) requires(Rows != Dynamic && Cols != Dynamic)
  {
    return parent::Constant(Rows, Cols, value);
  }

  static matrix Constant(int rows, int cols, Scalar value)
  {
    check_dimensions(rows, cols);
    return parent::Constant(rows, cols, value);
  }

  static matrix Random() requires(Rows != Dynamic && Cols != Dynamic)
  {
    return parent::Random(Rows, Cols);
  }

  static matrix Random(int rows, int cols)
  {
    check_dimensions(rows, cols);
    return parent::Random(rows, cols);
  }

  static matrix Identity() requires(Rows != Dynamic && Cols != Dynamic)
  {
    return parent::Identity(Rows, Cols);
  }

  static matrix Identity(int rows, int cols)
  {
    check_dimensions(rows, cols);
    return parent::Identity(rows, cols);
  }

  /// @tparam P Norm order; restricted at compile time to 1, 2 or `Infinity`.
  /// @return The backend's lp_norm(P).
  template<int P>
  Scalar lp_norm() const
  {
    static_assert(P == 1 || P == 2 || P == Infinity);
    return parent::lp_norm(P);
  }
};
/// @brief Sum of two matrices of identical type.
/// @return A new matrix computed as a copy of `lhs` updated with `+= rhs`.
template<typename S, int R, int C, int O, int MR, int MC>
matrix<S,R,C,O,MR,MC> operator+(matrix<S,R,C,O,MR,MC> const& lhs, matrix<S,R,C,O,MR,MC> const& rhs)
{
  matrix<S,R,C,O,MR,MC> that(lhs);
  that += rhs;
  // Return the named local (not the `matrix&` yielded by `+=`) so NRVO or an
  // implicit move applies instead of a second full copy.
  return that;
}
/// @brief Difference of two matrices of identical type.
/// @return A new matrix computed as a copy of `lhs` updated with `-= rhs`.
template<typename S, int R, int C, int O, int MR, int MC>
matrix<S,R,C,O,MR,MC> operator-(matrix<S,R,C,O,MR,MC> const& lhs, matrix<S,R,C,O,MR,MC> const& rhs)
{
  matrix<S,R,C,O,MR,MC> that(lhs);
  that -= rhs;
  // Returning the local itself (rather than the reference from `-=`) avoids
  // an extra copy.
  return that;
}
/// @brief Product of two matrices of identical type.
/// @return A new matrix computed as a copy of `lhs` updated with `*= rhs`.
template<typename S, int R, int C, int O, int MR, int MC>
matrix<S,R,C,O,MR,MC> operator*(matrix<S,R,C,O,MR,MC> const& lhs, matrix<S,R,C,O,MR,MC> const& rhs)
{
  matrix<S,R,C,O,MR,MC> that(lhs);
  that *= rhs;
  // Return the named local so the result is elided or moved, not copied.
  return that;
}
/// @brief Scales a matrix by a scalar on the right.
/// @return A new matrix computed as a copy of `lhs` updated with `*= rhs`.
template<typename S, int R, int C, int O, int MR, int MC>
matrix<S,R,C,O,MR,MC> operator*(matrix<S,R,C,O,MR,MC> const& lhs, S rhs)
{
  matrix<S,R,C,O,MR,MC> that(lhs);
  that *= rhs;
  // Return the named local so the result is elided or moved, not copied.
  return that;
}
/// @brief Scales a matrix by a scalar on the left.
/// @return A copy of `rhs` updated with `*= lhs`, i.e. the same result as `rhs * lhs`.
template<typename S, int R, int C, int O, int MR, int MC>
matrix<S,R,C,O,MR,MC> operator*(S lhs, matrix<S,R,C,O,MR,MC> const& rhs)
{
  // Scaling is applied exactly as in the (matrix, scalar) overload:
  // copy the matrix operand, then scale it in place.
  matrix<S,R,C,O,MR,MC> scaled(rhs);
  scaled *= lhs;
  return scaled;
}
/// @brief Divides a matrix by a scalar.
/// @return A new matrix computed as a copy of `lhs` updated with `/= rhs`.
template<typename S, int R, int C, int O, int MR, int MC>
matrix<S,R,C,O,MR,MC> operator/(matrix<S,R,C,O,MR,MC> const& lhs, S rhs)
{
  matrix<S,R,C,O,MR,MC> that(lhs);
  that /= rhs;
  // Return the named local so the result is elided or moved, not copied.
  return that;
}
}