rotgen/include/rotgen/fixed/map.hpp
2025-08-15 16:49:33 +02:00

320 lines
11 KiB
C++

//==================================================================================================
/*
ROTGEN - Runtime Overlay for Eigen
Copyright : CODE RECKONS
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once
#include <rotgen/detail/static_info.hpp>
#include <Eigen/Dense>
#include <iostream>
namespace rotgen
{
namespace detail
{
template<typename Ref, int Options, bool isConst>
struct compute_map_type
{
using base = Eigen::Matrix< typename Ref::value_type
, Ref::RowsAtCompileTime, Ref::ColsAtCompileTime
, Ref::storage_order
>;
using ref_t = std::conditional_t<isConst, base const, base>;
using type = Eigen::Map<ref_t, Options, Eigen::Stride<-1,-1>>;
};
template<typename Ref, int Options, bool isConst>
using map_type = typename compute_map_type<Ref,Options,isConst>::type;
}
/// @brief Non-owning, strided view over external memory that behaves like a rotgen matrix.
///
/// @tparam Ref     rotgen matrix type supplying the value type, compile-time shape and storage
///                 order of the view. A const-qualified @c Ref makes the view read-only
///                 (@c is_immutable), which disables every mutating member.
/// @tparam Options forwarded as the @c MapOptions argument of the underlying @c Eigen::Map.
/// @tparam Stride  stride descriptor accepted by the strided constructors.
///
/// The view privately derives from an @c Eigen::Map with fully dynamic inner/outer strides;
/// @c base() exposes that Eigen object.
template<typename Ref, int Options = ColMajor, typename Stride = stride>
class map : private detail::map_type<std::remove_const_t<Ref>, Options, std::is_const_v<Ref>>
{
  public:
  using rotgen_tag    = void;
  using parent        = detail::map_type<std::remove_const_t<Ref>, Options, std::is_const_v<Ref>>;
  using value_type    = typename std::remove_const_t<Ref>::value_type;
  using concrete_type = typename std::remove_const_t<Ref>::concrete_type;

  static constexpr auto  Flags              = Ref::Flags;
  static constexpr Index RowsAtCompileTime  = Ref::RowsAtCompileTime;
  static constexpr Index ColsAtCompileTime  = Ref::ColsAtCompileTime;
  static constexpr bool  has_static_storage = Ref::has_static_storage;
  static constexpr int   storage_order      = Ref::storage_order;
  static constexpr bool  is_immutable       = std::is_const_v<Ref>;
  static constexpr bool  is_defined_static  = Ref::is_defined_static;

  /// Concrete rotgen type used to evaluate the Eigen expression @p ET (see transpose() & co).
  template<typename ET>
  using as_concrete_type = as_concrete_t<ET, matrix>;

  /// Pointer accepted by the constructors: pointer-to-const for read-only views.
  using ptr_type    = std::conditional_t<is_immutable, value_type const*, value_type*>;
  using stride_type = Stride;

  // NOTE(review): copy/move assignment forward to Eigen::Map's assignment, which in Eigen writes
  // coefficients into the mapped memory rather than rebinding the pointer — confirm intended.
  map(const map&)            = default;
  map(map&&)                 = default;
  map& operator=(const map&) = default;
  map& operator=(map&&)      = default;

  /// Views @p r x @p c elements at @p ptr, with strides built from the descriptor @p s.
  map(ptr_type ptr, Index r, Index c, stride_type s) : parent(ptr, r, c, strides<storage_order>(s)) {}

  /// Views @p r x @p c elements at @p ptr, with strides derived from the shape by
  /// @c strides<storage_order>(r,c) (presumably the packed layout — TODO confirm).
  map(ptr_type ptr, Index r, Index c) : parent(ptr, r, c, strides<storage_order>(r,c)) {}

  /// Views a fully fixed-size matrix at @p ptr, with strides built from @p s.
  map(ptr_type ptr, stride_type s) requires(RowsAtCompileTime!=-1 && ColsAtCompileTime!=-1)
    : parent(ptr, strides<storage_order>(s))
  {}

  /// Views a vector of @p sz elements at @p ptr; the fixed unit dimension stays 1.
  map(ptr_type ptr, Index sz) requires(RowsAtCompileTime==1 || ColsAtCompileTime==1)
    : map(ptr, RowsAtCompileTime==1?1:sz, ColsAtCompileTime==1?1:sz)
  {}

  /// Views a fully fixed-size matrix at @p ptr using its compile-time dimensions.
  map(ptr_type ptr) requires(RowsAtCompileTime!=-1 && ColsAtCompileTime!=-1)
    : map( ptr, RowsAtCompileTime, ColsAtCompileTime )
  {}

  /// Access to the underlying Eigen::Map.
  parent&       base()       { return static_cast<parent&>(*this); }
  parent const& base() const { return static_cast<const parent&>(*this); }

  /// 2D element access. The reference-returning overload only exists on writable views, so a
  /// non-const read-only view falls back to the by-value const overload.
  value_type& operator()(Index i, Index j) requires(!is_immutable) { return base()(i, j); }
  value_type  operator()(Index i, Index j) const                   { return base()(i, j); }

  /// Linear element access. Indexes data() directly (strides are not applied), hence the
  /// assertion that the view is contiguous.
  value_type& operator()(Index i) requires(!is_immutable)
  {
    assert(is_contiguous_linear());
    return base().data()[i];
  }

  value_type operator()(Index i) const
  {
    assert(is_contiguous_linear());
    return base().data()[i];
  }

  /// Element-wise in-place sum with another view of the same type.
  map& operator+=(map const& rhs) requires(!is_immutable)
  {
    base() += rhs.base();
    return *this;
  }

  /// Element-wise in-place difference with another view of the same type.
  map& operator-=(map const& rhs) requires(!is_immutable)
  {
    base() -= rhs.base();
    return *this;
  }

  /// In-place matrix product (this = this * rhs).
  map& operator*=(map const& rhs) requires(!is_immutable)
  {
    // Must hand Eigen the underlying Map: `map` derives privately from it, so passing `rhs`
    // itself cannot be converted to an Eigen base type.
    base() *= rhs.base();
    return *this;
  }

  /// In-place scaling by @p rhs.
  map& operator*=(value_type rhs) requires(!is_immutable)
  {
    base() *= rhs;
    return *this;
  }

  /// In-place division by @p rhs.
  map& operator/=(value_type rhs) requires(!is_immutable)
  {
    base() /= rhs;
    return *this;
  }

  /// Transposed view: an Eigen expression when @c use_expression_templates is set,
  /// otherwise evaluated into @c as_concrete_type.
  auto transpose() const
  {
    if constexpr(use_expression_templates) return base().transpose();
    else return as_concrete_type<decltype(base().transpose())>(base().transpose());
  }

  /// Adjoint; same expression/concrete policy as transpose().
  auto adjoint() const
  {
    if constexpr(use_expression_templates) return base().adjoint();
    else return as_concrete_type<decltype(base().adjoint())>(base().adjoint());
  }

  /// Conjugate; same expression/concrete policy as transpose().
  auto conjugate() const
  {
    if constexpr(use_expression_templates) return base().conjugate();
    else return as_concrete_type<decltype(base().conjugate())>(base().conjugate());
  }

  void transposeInPlace() requires(!is_immutable) { base().transposeInPlace(); }
  void adjointInPlace()   requires(!is_immutable) { base().adjointInPlace(); }

  // Static factories forward to Ref; the shape-less overloads exist only when Ref provides them.
  static auto Zero() requires( requires {Ref::Zero();} ) { return Ref::Zero(); }
  static auto Zero(int rows, int cols) { return Ref::Zero(rows,cols); }
  static auto Ones() requires( requires {Ref::Ones();} ) { return Ref::Ones(); }
  static auto Ones(int rows, int cols) { return Ref::Ones(rows,cols); }
  static auto Constant(value_type value) requires( requires {Ref::Constant(value);} )
  { return Ref::Constant(value); }
  static auto Constant(int rows, int cols, value_type value) { return Ref::Constant(rows, cols, value); }
  static auto Random() requires( requires {Ref::Random();} ) { return Ref::Random(); }
  static auto Random(int rows, int cols) { return Ref::Random(rows, cols); }
  static auto Identity() requires( requires {Ref::Identity();} ) { return Ref::Identity(); }
  static auto Identity(int rows, int cols) { return Ref::Identity(rows, cols); }

  /// Overwrites every element of the view with one, keeping its current dimensions.
  map& setOnes() requires(!is_immutable)
  {
    base() = parent::Ones(base().rows(), base().cols());
    return *this;
  }

  /// Overwrites every element of the view with zero, keeping its current dimensions.
  map& setZero() requires(!is_immutable)
  {
    base() = parent::Zero(base().rows(), base().cols());
    return *this;
  }

  /// Overwrites every element of the view with @p value.
  map& setConstant(value_type value) requires(!is_immutable)
  {
    base() = parent::Constant(base().rows(), base().cols(), value);
    return *this;
  }

  /// Overwrites every element of the view with Eigen's Random() values.
  map& setRandom() requires(!is_immutable)
  {
    base() = parent::Random(base().rows(), base().cols());
    return *this;
  }

  /// Overwrites the view with the (possibly rectangular) identity pattern.
  map& setIdentity() requires(!is_immutable)
  {
    base() = parent::Identity(base().rows(), base().cols());
    return *this;
  }

  /// @brief Tells whether element @c i of the view lives at @c data()[i].
  ///
  /// Requires unit inner stride, and either at most one outer slice (the outer stride is then
  /// never applied) or outer slices packed back to back.
  bool is_contiguous_linear() const
  {
    if(base().innerStride() != 1) return false;

    constexpr bool row_major = (storage_order == rotgen::RowMajor);
    Index const inner_size   = row_major ? base().cols() : base().rows();
    Index const outer_size   = row_major ? base().rows() : base().cols();

    // A single outer slice (e.g. a column vector in column-major order) is contiguous whatever
    // its outer stride, since that stride is never used to reach any element.
    if(outer_size <= 1) return true;
    return base().outerStride() == inner_size;
  }

  using parent::innerStride;
  using parent::outerStride;
  using parent::rows;
  using parent::cols;
  using parent::size;
  using parent::data;
  using parent::sum;
  using parent::mean;
  using parent::prod;
  using parent::trace;
  using parent::maxCoeff;
  using parent::minCoeff;
  using parent::norm;
  using parent::squaredNorm;

  /// Lp norm of the view; restricted to P in {1, 2, Infinity}.
  template<int P> value_type lpNorm() const
  {
    static_assert(P == 1 || P == 2 || P == Infinity);
    return parent::template lpNorm<P>();
  }
};
/// @brief Coefficient-wise equality of two maps, delegated to their underlying Eigen::Map objects.
template<typename R1, typename R2, int O1, typename S1, int O2, typename S2>
auto operator==(map<R1, O1, S1> const& a, map<R2, O2, S2> const& b) -> bool
{
  auto const& left  = a.base();
  auto const& right = b.base();
  return left == right;
}
/// @brief Coefficient-wise inequality of two maps, delegated to their underlying Eigen::Map objects.
template<typename R1, typename R2, int O1, typename S1, int O2, typename S2>
auto operator!=(map<R1, O1, S1> const& a, map<R2, O2, S2> const& b) -> bool
{
  auto const& left  = a.base();
  auto const& right = b.base();
  return left != right;
}
#if defined(ROTGEN_ENABLE_EXPRESSION_TEMPLATES)
// Expression-template mode: arithmetic on maps forwards to the underlying Eigen::Map objects
// and hands back the resulting Eigen expression without evaluating it.

/// Element-wise sum of two maps (unevaluated Eigen expression).
template<typename LRef, typename RRef, int LOpt, typename LStr, int ROpt, typename RStr>
auto operator+(map<LRef, LOpt, LStr> const& a, map<RRef, ROpt, RStr> const& b)
{
  auto const& left  = a.base();
  auto const& right = b.base();
  return left + right;
}

/// Element-wise difference of two maps (unevaluated Eigen expression).
template<typename LRef, typename RRef, int LOpt, typename LStr, int ROpt, typename RStr>
auto operator-(map<LRef, LOpt, LStr> const& a, map<RRef, ROpt, RStr> const& b)
{
  auto const& left  = a.base();
  auto const& right = b.base();
  return left - right;
}

/// Matrix product of two maps (unevaluated Eigen expression).
template<typename LRef, typename RRef, int LOpt, typename LStr, int ROpt, typename RStr>
auto operator*(map<LRef, LOpt, LStr> const& a, map<RRef, ROpt, RStr> const& b)
{
  auto const& left  = a.base();
  auto const& right = b.base();
  return left * right;
}

/// Map scaled by a scalar on the right.
template<typename Ref, int Opt, typename Str, std::convertible_to<typename Ref::value_type> Scalar>
auto operator*(map<Ref, Opt, Str> const& m, Scalar factor)
{
  auto const& view = m.base();
  return view * factor;
}

/// Map scaled by a scalar on the left.
template<typename Ref, int Opt, typename Str, std::convertible_to<typename Ref::value_type> Scalar>
auto operator*(Scalar factor, map<Ref, Opt, Str> const& m)
{
  auto const& view = m.base();
  return factor * view;
}

/// Map divided by a scalar.
template<typename Ref, int Opt, typename Str, std::convertible_to<typename Ref::value_type> Scalar>
auto operator/(map<Ref, Opt, Str> const& m, Scalar divisor)
{
  auto const& view = m.base();
  return view / divisor;
}
#else
template<typename R1, typename R2, int O1, typename S1, int O2, typename S2>
typename map<R1,O1,S1>::concrete_type operator+(map<R1, O1, S1> const& lhs, map<R2,O2,S2> const& rhs)
{
using concrete_type = typename map<R1,O1,S1>::concrete_type;
return concrete_type(lhs.base() + rhs.base());
}
template<typename R1, typename R2, int O1, typename S1, int O2, typename S2>
typename map<R1,O1,S1>::concrete_type operator-(map<R1, O1, S1> const& lhs, map<R2,O2,S2> const& rhs)
{
using concrete_type = typename map<R1,O1,S1>::concrete_type;
return concrete_type(lhs.base() - rhs.base());
}
template<typename R1, typename R2, int O1, typename S1, int O2, typename S2>
typename map<R1,O1,S1>::concrete_type operator*(map<R1, O1, S1> const& lhs, map<R2,O2,S2> const& rhs)
{
using concrete_type = typename map<R1,O1,S1>::concrete_type;
return concrete_type(lhs.base() * rhs.base());
}
template<typename R, int O, typename S>
typename map<R,O,S>::concrete_type operator*( map<R, O, S> const& lhs
, std::convertible_to<typename R::value_type> auto s
)
{
using concrete_type = typename map<R,O,S>::concrete_type;
return concrete_type(lhs.base() * s);
}
template<typename R, int O, typename S>
typename map<R,O,S>::concrete_type operator*( std::convertible_to<typename R::value_type> auto s
, map<R, O, S> const& rhs
)
{
using concrete_type = typename map<R,O,S>::concrete_type;
return concrete_type(rhs.base() * s);
}
template<typename R, int O, typename S>
typename map<R,O,S>::concrete_type operator/( map<R, O, S> const& lhs
, std::convertible_to<typename R::value_type> auto s
)
{
using concrete_type = typename map<R,O,S>::concrete_type;
return concrete_type(lhs.base() / s);
}
#endif
}