// File: rotgen/include/rotgen/common/ref.hpp
// Timestamp: 2025-09-28 16:15:15 +02:00
//
// Size: 417 lines
// Size: 13 KiB
// Language: C++

//==================================================================================================
/*
ROTGEN - Runtime Overlay for Eigen
Copyright : CODE RECKONS
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once
#include <cassert>
#include <concepts>
#include <ostream>
#include <type_traits>
#include <utility>
#if !defined(ROTGEN_FORCE_DYNAMIC)
#include <Eigen/Dense>
#endif
namespace rotgen
{
// Primary template: mutable ref
//
// ref<T, Options, Stride> exposes the interface of the rotgen matrix type T over the elements of
// another object. It privately inherits map<T, Options, Stride>, is constructed from the source's
// data(), rows(), cols() and strides, and re-exports the subset of the map interface listed below.
// NOTE(review): being map-based, it presumably does not own its storage, so the source object
// must outlive the ref -- confirm against map's semantics.
template<typename T, int Options, typename Stride>
class ref : private map<T, Options, Stride>
{
  public:
  using parent = map<T, Options, Stride>;
  using value_type = typename T::value_type;
  using rotgen_tag = void;
  using rotgen_ref_tag = void;
  static constexpr int storage_order = T::storage_order;
  static constexpr int RowsAtCompileTime = T::RowsAtCompileTime;
  static constexpr int ColsAtCompileTime = T::ColsAtCompileTime;
  static constexpr bool IsVectorAtCompileTime = T::IsVectorAtCompileTime;

  // Subset of the map interface made public on ref.
  using parent::evaluate;
  using parent::noalias;
  using parent::operator();
  using parent::operator[];
  using parent::rows;
  using parent::cols;
  using parent::size;
  using parent::data;
  using parent::sum;
  using parent::prod;
  using parent::mean;
  using parent::trace;
  using parent::transpose;
  using parent::cwiseAbs;
  using parent::cwiseAbs2;
  using parent::cwiseInverse;
  using parent::cwiseSqrt;
  using parent::maxCoeff;
  using parent::minCoeff;
  using parent::norm;
  using parent::normalize;
  using parent::squaredNorm;
  using parent::lpNorm;
  using parent::operator+=;
  using parent::operator-=;
  using parent::operator*=;
  using parent::operator/=;
  using parent::Zero;
  using parent::Constant;
  using parent::Random;
  using parent::Identity;
  using parent::setZero;
  using parent::setConstant;
  using parent::setRandom;
  using parent::setIdentity;
  using parent::operator=;
  using stride_type = typename parent::stride_type;

  // Access to the underlying map.
  parent const& base() const { return static_cast<parent const&>(*this); }
  parent& base() { return static_cast<parent&>(*this); }

  // Views a whole rotgen matrix of the same value type; its storage order (O & 1) must match.
  template<std::same_as<value_type> S, int R, int C, int O, int MR, int MC>
  ref(matrix<S, R, C, O, MR, MC>& m)
    : parent(m.data(), m.rows(), m.cols(), strides(m))
  {
    static_assert((O & 1) == storage_order, "ref: Incompatible storage layout");
  }

  // Views a block (rvalue overload), forwarding its outer/inner strides.
  template<typename Ref, int R, int C, bool I>
  ref(block<Ref,R,C,I>&& b)
    : parent(b.data(), b.rows(), b.cols(), stride_type{b.outerStride(),b.innerStride()})
  {
    static_assert((Ref::storage_order & 1) == storage_order, "ref: Incompatible storage layout");
  }

  // Views a block (lvalue overload), forwarding its outer/inner strides.
  template<typename Ref, int R, int C, bool I>
  ref(block<Ref,R,C,I>& b)
    : parent(b.data(), b.rows(), b.cols(), stride_type{b.outerStride(),b.innerStride()})
  {
    static_assert((Ref::storage_order & 1) == storage_order, "ref: Incompatible storage layout");
  }

  // Views another map, forwarding its outer/inner strides.
  template<typename Ref, int O, typename S>
  ref(map<Ref,O,S>& b)
    : parent(b.data(), b.rows(), b.cols(), stride_type{b.outerStride(),b.innerStride()})
  {
    static_assert((Ref::storage_order & 1) == storage_order, "ref: Incompatible storage layout");
  }

  // Views a map of exactly the parent type. This non-template overload is preferred over the
  // map<Ref,O,S>& one for that type, so it must forward the source strides as well; the
  // three-argument map constructor never sees them and a strided map would not be viewed
  // with its actual layout.
  ref(parent& m)
    : parent(m.data(), m.rows(), m.cols(), stride_type{m.outerStride(), m.innerStride()})
  {}

  // Prints the underlying map followed by a newline.
  friend std::ostream& operator<<(std::ostream& os, ref const& r)
  {
    return os << r.base() << "\n";
  }
};
// Specialization for const matrix type
//
// Read-only counterpart of the primary ref: it privately inherits map<const T, Options, Stride>
// and only offers const access to the underlying map through base(). The converting
// constructors are constrained (requires-clauses) on value type and storage order instead of
// using static_asserts.
template<typename T, int Options, typename Stride>
class ref<const T, Options,Stride> : private map<const T, Options,Stride>
{
  public:
  using parent = map<const T, Options,Stride>;
  using value_type = typename T::value_type;
  using rotgen_tag = void;
  static constexpr int storage_order = T::storage_order;
  static constexpr int RowsAtCompileTime = T::RowsAtCompileTime;
  static constexpr int ColsAtCompileTime = T::ColsAtCompileTime;
  static constexpr bool IsVectorAtCompileTime = T::IsVectorAtCompileTime;

  // Subset of the map interface made public on ref.
  using parent::evaluate;
  using parent::noalias;
  using parent::operator();
  using parent::operator[];
  using parent::rows;
  using parent::cols;
  using parent::size;
  using parent::data;
  using parent::sum;
  using parent::prod;
  using parent::mean;
  using parent::trace;
  using parent::transpose;
  using parent::cwiseAbs;
  using parent::cwiseAbs2;
  using parent::cwiseInverse;
  using parent::cwiseSqrt;
  using parent::maxCoeff;
  using parent::minCoeff;
  using parent::norm;
  using parent::normalize;
  using parent::squaredNorm;
  using parent::lpNorm;
  using parent::operator+=;
  using parent::operator-=;
  using parent::operator*=;
  using parent::operator/=;
  using parent::Zero;
  using parent::Constant;
  using parent::Random;
  using parent::Identity;
  using parent::operator=;
  using stride_type = typename parent::stride_type;
  static constexpr bool has_static_storage = parent::has_static_storage;

  // Read-only access to the underlying map.
  parent const& base() const { return static_cast<parent const&>(*this); }

  // Views a whole rotgen matrix of the same value type and storage order (O & 1).
  template<std::same_as<value_type> S, int R, int C, int O, int MR, int MC>
  ref(matrix<S, R, C, O, MR, MC> const& m)
    requires((O & 1) == storage_order)
    : parent(m.data(), m.rows(), m.cols(), strides(m))
  {}

  // Views a block with the same value type and storage order, forwarding its strides.
  template<typename Ref, int R, int C, bool I>
  ref ( block<Ref,R,C,I> const& b )
    requires(std::same_as<value_type, typename Ref::value_type> && (Ref::storage_order & 1) == storage_order)
    : parent(b.data(), b.rows(), b.cols(), stride_type{b.outerStride(),b.innerStride()})
  {}

  // Views a map with the same value type and storage order, forwarding its strides.
  template<typename Ref, int O, typename S>
  ref ( map<Ref,O,S> const& b )
    requires(std::same_as<value_type, typename Ref::value_type> && (Ref::storage_order & 1) == storage_order)
    : parent(b.data(), b.rows(), b.cols(), stride_type{b.outerStride(),b.innerStride()})
  {}

  // Views a map of exactly the parent type. This non-template overload is preferred over the
  // map<Ref,O,S> const& one for that type, so it must forward the source strides as well; the
  // three-argument map constructor never sees them and a strided map would not be viewed
  // with its actual layout.
  ref(parent const& m)
    : parent(m.data(), m.rows(), m.cols(), stride_type{m.outerStride(), m.innerStride()})
  {}

  // Prints the underlying map followed by a newline.
  friend std::ostream& operator<<(std::ostream& os, ref const& r)
  {
    return os << r.base() << "\n";
  }
};
template<typename S, int R, int C, int O, int MR, int MC>
ref(matrix<S, R, C, O, MR, MC>&) -> ref<matrix<S>>;
template<typename Ref, int R, int C, bool I>
ref(block<Ref,R,C,I>& b) -> ref<Ref>;
template<typename S, int R, int C, int O, int MR, int MC>
ref(matrix<S, R, C, O, MR, MC> const&) -> ref<matrix<S> const>;
template<typename Ref, int R, int C, bool I>
ref(block<Ref,R,C,I> const& b) -> ref<Ref const>;
template<typename A, int O, typename S, typename B, int P, typename T>
bool operator==(ref<A,O,S> lhs, ref<B,P,T> rhs)
{
return lhs.base() == rhs.base();
}
template<typename A, int O, typename S, typename B, int P, typename T>
bool operator!=(ref<A,O,S> lhs, ref<B,P,T> rhs)
{
return lhs.base() != rhs.base();
}
// Arithmetic operators between two refs, each forwarded to the same operator on the
// underlying maps (base()).
//
// Every operator names its forwarded expression in a trailing return type, so it is removed
// from overload resolution (SFINAE) when the underlying maps do not support the operation.
// operator- and operator* wrap it in std::decay_t, which is exactly the type a deduced `auto`
// return type produces, so the value they return is unaffected.
//
// NOTE(review): the compound assignments take lhs by value. ref is a view, so the update
// presumably still reaches the viewed elements, but if map's compound assignment returns a
// reference to *this, the returned reference designates the by-value parameter and dangles
// once the call returns. Do not bind or chain the result until this is confirmed.
template<typename A, int O, typename S, typename B, int P, typename T>
auto operator+(ref<A,O,S> lhs, ref<B,P,T> rhs) -> decltype(lhs.base() + rhs.base())
{
  return lhs.base() + rhs.base();
}

template<typename A, int O, typename S, typename B, int P, typename T>
auto operator+=(ref<A,O,S> lhs, ref<B,P,T> rhs) -> decltype(lhs.base() += rhs.base())
{
  return lhs.base() += rhs.base();
}

template<typename A, int O, typename S, typename B, int P, typename T>
auto operator-(ref<A,O,S> lhs, ref<B,P,T> rhs) -> std::decay_t<decltype(lhs.base() - rhs.base())>
{
  return lhs.base() - rhs.base();
}

template<typename A, int O, typename S, typename B, int P, typename T>
auto operator-=(ref<A,O,S> lhs, ref<B,P,T> rhs) -> decltype(lhs.base() -= rhs.base())
{
  return lhs.base() -= rhs.base();
}

template<typename A, int O, typename S, typename B, int P, typename T>
auto operator*(ref<A,O,S> lhs, ref<B,P,T> rhs) -> std::decay_t<decltype(lhs.base() * rhs.base())>
{
  return lhs.base() * rhs.base();
}

template<typename A, int O, typename S, typename B, int P, typename T>
auto operator*=(ref<A,O,S> lhs, ref<B,P,T> rhs) -> decltype(lhs.base() *= rhs.base())
{
  return lhs.base() *= rhs.base();
}
template<typename A, int O, typename S>
auto operator*(ref<A,O,S> lhs, std::convertible_to<typename A::value_type> auto s)
{
return lhs.base() * s;
}
template<typename A, int O, typename S>
auto operator*(std::convertible_to<typename A::value_type> auto s, ref<A,O,S> rhs)
{
return s * rhs.base();
}
template<typename A, int O, typename S>
auto operator/(ref<A,O,S> lhs, std::convertible_to<typename A::value_type> auto s)
{
return lhs.base() / s;
}
template<typename A, int O, typename S, typename B, int P, typename T>
auto min(ref<A,O,S> lhs, ref<B,P,T> rhs) -> decltype(lhs.base().cwiseMin(rhs.base()))
{
return lhs.base().cwiseMin(rhs.base());
}
template<typename A, int O, typename S>
auto min(ref<A,O,S> lhs, std::convertible_to<typename A::value_type> auto s) -> decltype(lhs.base().cwiseMin(s))
{
return lhs.base().cwiseMin(s);
}
template<typename A, int O, typename S>
auto min(std::convertible_to<typename A::value_type> auto s,ref<A,O,S> rhs) -> decltype(rhs.base().cwiseMin(s))
{
return rhs.base().cwiseMin(s);
}
template<typename A, int O, typename S, typename B, int P, typename T>
auto max(ref<A,O,S> lhs, ref<B,P,T> rhs) -> decltype(lhs.base().cwiseMax(rhs.base()))
{
return lhs.base().cwiseMax(rhs.base());
}
template<typename A, int O, typename S>
auto max(ref<A,O,S> lhs, std::convertible_to<typename A::value_type> auto s) -> decltype(lhs.base().cwiseMax(s))
{
return lhs.base().cwiseMax(s);
}
template<typename A, int O, typename S>
auto max(std::convertible_to<typename A::value_type> auto s,ref<A,O,S> rhs) -> decltype(rhs.base().cwiseMax(s))
{
return rhs.base().cwiseMax(s);
}
template<typename A, int O, typename S, typename B, int P, typename T>
auto mul(ref<A,O,S> lhs, ref<B,P,T> rhs) -> decltype(lhs.base().cwiseProduct(rhs.base()))
{
return lhs.base().cwiseProduct(rhs.base());
}
template<typename A, int O, typename S>
auto mul(ref<A,O,S> lhs, std::convertible_to<typename A::value_type> auto s) -> decltype(lhs * s)
{
return lhs * s;
}
template<typename A, int O, typename S>
auto mul(std::convertible_to<typename A::value_type> auto s,ref<A,O,S> rhs) -> decltype(s * rhs)
{
return s * rhs;
}
template<typename A, int O, typename S, typename B, int P, typename T>
auto div(ref<A,O,S> lhs, ref<B,P,T> rhs) -> decltype(lhs.base().cwiseQuotient(rhs.base()))
{
return lhs.base().cwiseQuotient(rhs.base());
}
template<typename A, int O, typename S>
auto div(ref<A,O,S> lhs, std::convertible_to<typename A::value_type> auto s) -> decltype(lhs / s)
{
return lhs / s;
}
template<typename A, int O, typename S>
auto inverse(ref<A,O,S> lhs) -> decltype(lhs.base().inverse())
{
return lhs.base().inverse();
}
template<typename A, int O, typename S, typename B, int P, typename T>
auto cross(ref<A,O,S> lhs, ref<B,P,T> rhs) -> decltype(lhs.base().cross(rhs.base()))
{
return lhs.base().cross(rhs.base());
}
//-------------------------------------------------------------------------------------------
// Convert entity/eigen types to a proper ref so we can write fewer function overloads.
//
// generalize<T>::type is the canonical type used in place of T. The primary template is only
// declared; each supported family of types provides a specialization.
template<typename T> struct generalize;

// Arithmetic types (possibly cv/ref-qualified) generalize to their unqualified type.
template<typename T>
requires(std::is_arithmetic_v<std::remove_cvref_t<T>>)
struct generalize<T> : std::type_identity<std::remove_cvref_t<T>>
{};

// Shorthand for generalize<T>::type.
template<typename T>
using generalize_t = typename generalize<T>::type;
template<concepts::entity T> struct generalize<T>
{
static constexpr bool is_const = std::is_const_v<T>;
using base = matrix<typename T::value_type,T::RowsAtCompileTime,T::ColsAtCompileTime,T::storage_order>;
using type = std::conditional_t<is_const,ref<base const>, ref<base>>;
};
template<typename T, int O, typename S>
struct generalize<ref<T,O,S>>
{
using type = ref<T,O,S>;
};
template<typename T, int O, typename S>
struct generalize<ref<T,O,S> const>
{
using type = ref<T,O,S>;
};
template<concepts::entity T>
typename T::parent& base_of(T& a)
{
return a.base();
}
template<concepts::entity T>
typename T::parent const& base_of(T const& a)
{
return a.base();
}
template<typename T>
T base_of(T a) requires(std::is_arithmetic_v<T>)
{
return a;
}
#if !defined(ROTGEN_FORCE_DYNAMIC)
template<concepts::eigen_compatible T> struct generalize<T>
{
static constexpr bool is_const = std::is_const_v<T>;
using concrete_type = decltype(std::declval<T>().eval());
using base = matrix<typename T::Scalar,T::RowsAtCompileTime,T::ColsAtCompileTime,concrete_type::Options&1>;
using type = std::conditional_t<is_const,ref<base const>, ref<base>>;
};
template<concepts::eigen_compatible T>
auto const& base_of(T const& a)
{
return a;
}
template<concepts::eigen_compatible T>
auto& base_of(T& a)
{
return a;
}
#endif
}