// rotgen/include/rotgen/dynamic/matrix.hpp
// 2025-09-18 14:31:33 +02:00
//
// 349 lines
// No EOL
// 10 KiB
// C++

//==================================================================================================
/*
ROTGEN - Runtime Overlay for Eigen
Copyright : CODE RECKONS
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once
#include <rotgen/impl/matrix.hpp>
#include <initializer_list>
#include <cassert>
namespace rotgen
{
template< typename Scalar
, int Rows = Dynamic , int Cols = Dynamic
, int Opts = detail::force_order<Rows,Cols>
, int MaxRows = Rows , int MaxCols = Cols
>
class matrix : public find_matrix<Scalar,Opts>
{
public:
using parent = find_matrix<Scalar,Opts>;
using rotgen_tag = void;
using concrete_type = matrix;
using value_type = Scalar;
static constexpr auto storage_order = Opts & 1;
static constexpr Index RowsAtCompileTime = Rows;
static constexpr Index ColsAtCompileTime = Cols;
static constexpr bool IsVectorAtCompileTime = (RowsAtCompileTime == 1) || (ColsAtCompileTime == 1);
static constexpr int Options = Opts;
static constexpr bool IsRowMajor = (Opts & RowMajor) == RowMajor;
static constexpr bool is_defined_static = false;
static constexpr bool has_static_storage = false;
matrix() : parent(Rows==-1?0:Rows,Cols==-1?0:Cols) {}
matrix(Index r, Index c) : parent(r, c)
{
if constexpr(Rows != -1) assert(r == Rows && "Mismatched between dynamic and static row size");
if constexpr(Cols != -1) assert(c == Cols && "Mismatched between dynamic and static column size");
}
matrix(Index n) requires(IsVectorAtCompileTime && (Rows != 1 || Cols != 1))
: parent(Rows != -1 ? 1 : n, Cols != -1 ? 1 : n)
{}
matrix(Scalar v) requires(Rows == 1 && Cols == 1) : parent(1,1,{v}) {}
matrix(parent const& base) : parent(base) {}
matrix(std::initializer_list<std::initializer_list<Scalar>> init) : parent(init)
{
if constexpr(Rows != -1) assert(init.size() == Rows && "Mismatched between dynamic and static row size");
if constexpr(Cols != -1)
{
[[maybe_unused]] std::size_t c = 0;
if(init.size()) c = init.begin()->size();
assert(c == Cols && "Mismatched between dynamic and static column size");
}
}
matrix(std::initializer_list<Scalar> init)
requires(IsVectorAtCompileTime) : parent(Rows, Cols, init)
{}
matrix(concepts::entity auto const& e) : parent(e.rows(),e.cols())
{
if constexpr(Rows != -1) assert(e.rows() == Rows && "Mismatched between dynamic and static row size");
if constexpr(Cols != -1) assert(e.cols() == Cols && "Mismatched between dynamic and static col size");
for (rotgen::Index r = 0; r < e.rows(); ++r)
for (rotgen::Index c = 0; c < e.cols(); ++c)
(*this)(r, c) = e(r, c);
}
matrix& operator=(concepts::entity auto const& e)
{
if constexpr(Rows != -1) assert(e.rows() == Rows && "Mismatched between dynamic and static row size");
if constexpr(Cols != -1) assert(e.cols() == Cols && "Mismatched between dynamic and static col size");
resize(e.rows(), e.cols());
for (rotgen::Index r = 0; r < e.rows(); ++r)
for (rotgen::Index c = 0; c < e.cols(); ++c)
(*this)(r, c) = e(r, c);
return *this;
}
value_type& operator[](Index i) requires(IsVectorAtCompileTime)
{
return (*this)(i);
}
value_type operator[](Index i) const requires(IsVectorAtCompileTime)
{
return (*this)(i);
}
auto evaluate() const { return *this; }
decltype(auto) noalias() const { return *this; }
decltype(auto) noalias() { return *this; }
void resize(int new_rows, int new_cols) requires(Rows == -1 && Cols == -1)
{
parent::resize(new_rows, new_cols);
}
void conservativeResize(int new_rows, int new_cols) requires(Rows == -1 && Cols == -1)
{
parent::conservativeResize(new_rows, new_cols);
}
matrix transpose() const
{
return matrix(base().transpose());
}
matrix conjugate() const
{
return matrix(base().conjugate());
}
matrix adjoint() const
{
return matrix(base().adjoint());
}
void transposeInPlace() { parent::transposeInPlace(); }
void adjointInPlace() { parent::adjointInPlace(); }
matrix cwiseAbs() const { return matrix(base().cwiseAbs()); }
matrix cwiseAbs2() const { return matrix(base().cwiseAbs2()); }
matrix cwiseInverse() const { return matrix(base().cwiseInverse()); }
matrix cwiseSqrt() const { return matrix(base().cwiseSqrt()); }
friend bool operator==(matrix const& lhs, matrix const& rhs)
{
return static_cast<parent const&>(lhs) == static_cast<parent const&>(rhs);
}
matrix& operator+=(matrix const& rhs)
{
base() += static_cast<parent const&>(rhs);
return *this;
}
matrix& operator-=(matrix const& rhs)
{
base() -= static_cast<parent const&>(rhs);
return *this;
}
matrix operator-() const
{
return matrix(base().operator-());
}
matrix& operator*=(matrix const& rhs)
{
base() *= static_cast<parent const&>(rhs);
return *this;
}
matrix& operator*=(Scalar rhs)
{
base() *= rhs;
return *this;
}
matrix& operator/=(Scalar rhs)
{
base() /= rhs;
return *this;
}
static matrix Ones() requires (Rows != -1 && Cols != -1)
{
return parent::Ones(Rows, Cols);
}
static matrix Ones(int rows, int cols)
{
if constexpr(Rows != -1) assert(rows == Rows && "Mismatched between dynamic and static row size");
if constexpr(Cols != -1) assert(cols == Cols && "Mismatched between dynamic and static column size");
return parent::Ones(rows, cols);
}
static matrix Zero() requires (Rows != -1 && Cols != -1)
{
return parent::Zero(Rows, Cols);
}
static matrix Zero(int rows, int cols)
{
if constexpr(Rows != -1) assert(rows == Rows && "Mismatched between dynamic and static row size");
if constexpr(Cols != -1) assert(cols == Cols && "Mismatched between dynamic and static column size");
return parent::Zero(rows, cols);
}
static matrix Constant(Scalar value) requires (Rows != -1 && Cols != -1)
{
return parent::Constant(Rows, Cols, static_cast<Scalar>(value));
}
static matrix Constant(int rows, int cols, Scalar value)
{
if constexpr(Rows != -1) assert(rows == Rows && "Mismatched between dynamic and static row size");
if constexpr(Cols != -1) assert(cols == Cols && "Mismatched between dynamic and static column size");
return parent::Constant(rows, cols, static_cast<Scalar>(value));
}
static matrix Random() requires (Rows != -1 && Cols != -1)
{
return parent::Random(Rows, Cols);
}
static matrix Random(int rows, int cols)
{
if constexpr(Rows != -1) assert(rows == Rows && "Mismatched between dynamic and static row size");
if constexpr(Cols != -1) assert(cols == Cols && "Mismatched between dynamic and static column size");
return parent::Random(rows, cols);
}
static matrix Identity() requires (Rows != -1 && Cols != -1)
{
return parent::Identity(Rows, Cols);
}
static matrix Identity(int rows, int cols)
{
if constexpr(Rows != -1) assert(rows == Rows && "Mismatched between dynamic and static row size");
if constexpr(Cols != -1) assert(cols == Cols && "Mismatched between dynamic and static column size");
return parent::Identity(rows, cols);
}
matrix& setOnes()
{
*this = parent::Ones(Rows, Cols);
return *this;
}
matrix& setOnes(int rows, int cols)
{
*this = parent::Ones(rows, cols);
return *this;
}
matrix& setZero()
{
*this = parent::Zero(Rows, Cols);
return *this;
}
matrix& setZero(int rows, int cols)
{
*this = parent::Zero(rows, cols);
return *this;
}
matrix& setConstant(Scalar value)
{
*this = parent::Constant(Rows, Cols, static_cast<Scalar>(value));
return *this;
}
matrix& setConstant(int rows, int cols, Scalar value)
{
*this = parent::Constant(rows, cols, static_cast<Scalar>(value));
return *this;
}
matrix& setRandom()
{
*this = parent::Random(Rows, Cols);
return *this;
}
matrix& setRandom(int rows, int cols)
{
*this = parent::Random(rows, cols);
return *this;
}
matrix& setIdentity()
{
*this = parent::Identity(Rows, Cols);
return *this;
}
matrix& setIdentity(int rows, int cols)
{
*this = parent::Identity(rows, cols);
return *this;
}
template<int P>
Scalar lpNorm() const
{
static_assert(P == 1 || P == 2 || P == Infinity);
return parent::lp_norm(P);
}
parent& base() { return static_cast<parent&>(*this); }
parent const& base() const { return static_cast<parent const&>(*this);; }
};
template<typename S, int R, int C, int O, int MR, int MC>
matrix<S,R,C,O,MR,MC> operator+(matrix<S,R,C,O,MR,MC> const& lhs, matrix<S,R,C,O,MR,MC> const& rhs)
{
matrix<S,R,C,O,MR,MC> that(lhs);
return that += rhs;
}
template<typename S, int R, int C, int O, int MR, int MC>
matrix<S,R,C,O,MR,MC> operator-(matrix<S,R,C,O,MR,MC> const& lhs, matrix<S,R,C,O,MR,MC> const& rhs)
{
matrix<S,R,C,O,MR,MC> that(lhs);
return that -= rhs;
}
template<typename S, int R, int C, int O, int MR, int MC>
matrix<S,R,C,O,MR,MC> operator*(matrix<S,R,C,O,MR,MC> const& lhs, matrix<S,R,C,O,MR,MC> const& rhs)
{
matrix<S,R,C,O,MR,MC> that(lhs);
return that *= rhs;
}
template<typename S, int R, int C, int O, int MR, int MC>
matrix<S,R,C,O,MR,MC> operator*(matrix<S,R,C,O,MR,MC> const& lhs, double rhs)
{
matrix<S,R,C,O,MR,MC> that(lhs);
return that *= rhs;
}
template<typename S, int R, int C, int O, int MR, int MC>
matrix<S,R,C,O,MR,MC> operator*(double lhs, matrix<S,R,C,O,MR,MC> const& rhs)
{
return rhs * lhs;
}
template<typename S, int R, int C, int O, int MR, int MC>
matrix<S,R,C,O,MR,MC> operator/(matrix<S,R,C,O,MR,MC> const& lhs, double rhs)
{
matrix<S,R,C,O,MR,MC> that(lhs);
return that /= rhs;
}
}