rotgen/include/rotgen/container/ref/fixed.hpp
Joel Falcou 8e80d1d083 More specific fixes
See merge request oss/rotgen!47
2025-12-02 14:40:01 +01:00

543 lines
15 KiB
C++

//==================================================================================================
/*
ROTGEN - Runtime Overlay for Eigen
Copyright : CODE RECKONS
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once
#include <rotgen/detail/helpers.hpp>
#include <rotgen/detail/product.hpp>
#include <Eigen/Dense>
#include <type_traits>
namespace rotgen
{
namespace detail
{
  /// Maps a rotgen matrix type to the Eigen::Ref type that backs rotgen::ref.
  ///   - `base` is the Eigen::Matrix equivalent of the rotgen matrix;
  ///   - `type` is Eigen::Ref<base> (or Eigen::Ref<base const> when the
  ///     referee is const-qualified), using the given Options and Stride.
  /// The primary template is declared only: just the two specializations
  /// below are usable.
  template<typename T, int Options, typename Stride> struct compile_ref;

  // Mutable referee: the Ref can write through to the viewed coefficients.
  template<typename Scalar,
           int NRows,
           int NCols,
           int MatOpts,
           int NMaxRows,
           int NMaxCols,
           int RefOpts,
           typename RefStride>
  struct compile_ref<matrix<Scalar, NRows, NCols, MatOpts, NMaxRows, NMaxCols>,
                     RefOpts,
                     RefStride>
  {
    using base =
      Eigen::Matrix<Scalar, NRows, NCols, MatOpts, NMaxRows, NMaxCols>;
    using type = Eigen::Ref<base, RefOpts, RefStride>;
  };

  // Const referee: same underlying Eigen::Matrix, but a read-only Ref.
  template<typename Scalar,
           int NRows,
           int NCols,
           int MatOpts,
           int NMaxRows,
           int NMaxCols,
           int RefOpts,
           typename RefStride>
  struct compile_ref<
    matrix<Scalar, NRows, NCols, MatOpts, NMaxRows, NMaxCols> const,
    RefOpts,
    RefStride>
  {
    using base = typename compile_ref<
      matrix<Scalar, NRows, NCols, MatOpts, NMaxRows, NMaxCols>,
      RefOpts,
      RefStride>::base;
    using type = Eigen::Ref<base const, RefOpts, RefStride>;
  };

  /// Underlying Eigen::Matrix type for referee T.
  template<typename T, int Options, typename Stride>
  using compile_base_t = typename compile_ref<T, Options, Stride>::base;

  /// Eigen::Ref type used to view referee T.
  template<typename T, int Options, typename Stride>
  using compile_ref_t = typename compile_ref<T, Options, Stride>::type;
}
/// rotgen counterpart of Eigen::Ref: a non-owning view whose storage is the
/// Eigen::Ref selected by detail::compile_ref (privately inherited, exposed
/// through `base()`).
///
/// @tparam T       Referee type, a rotgen::matrix. A const-qualified T makes
///                 the view immutable (`is_immutable`); every mutating member
///                 is then removed from overload resolution through
///                 `requires(!is_immutable)`.
/// @tparam Options Forwarded as the Options argument of Eigen::Ref.
/// @tparam Stride  Forwarded as the StrideType argument of Eigen::Ref.
///
/// Operations producing a new value return a rotgen matrix built through
/// detail::concretize rather than an Eigen expression.
template<typename T, int Options, typename Stride>
class ref : private detail::compile_ref_t<T, Options, Stride>
{
public:
  using parent = detail::compile_ref_t<T, Options, Stride>;
  using exact_base = detail::compile_base_t<T, Options, Stride>;
  using referee = std::remove_const_t<T>;
  using value_type = typename referee::value_type;
  using rotgen_tag = void;
  using rotgen_ref_tag = void;
  using stride_type = Stride;

  // Compile-time shape/layout traits mirrored from the wrapped Eigen::Ref.
  static constexpr int RowsAtCompileTime = parent::RowsAtCompileTime;
  static constexpr int ColsAtCompileTime = parent::ColsAtCompileTime;
  static constexpr int MaxRowsAtCompileTime = parent::MaxRowsAtCompileTime;
  static constexpr int MaxColsAtCompileTime = parent::MaxColsAtCompileTime;
  static constexpr int InnerStrideAtCompileTime =
    parent::InnerStrideAtCompileTime;
  static constexpr int OuterStrideAtCompileTime =
    parent::OuterStrideAtCompileTime;
  static constexpr bool IsRowMajor = parent::IsRowMajor;
  static constexpr bool IsVectorAtCompileTime = parent::IsVectorAtCompileTime;
  static constexpr int storage_order = IsRowMajor ? RowMajor : ColMajor;
  // True when T is const-qualified: the view is read-only.
  static constexpr bool is_immutable = std::is_const_v<T>;

  // Access to values
  using parent::operator();
  using parent::operator[];

  // Size related functions
  using parent::cols;
  using parent::data;
  using parent::innerStride;
  using parent::outerStride;
  using parent::rows;
  using parent::size;

  // Aliasing handling
  /// Evaluates the referenced coefficients into a new T (a rotgen matrix).
  auto evaluate() const { return T(base().eval()); }

  // With expression templates enabled, forwards Eigen's noalias() proxy;
  // otherwise the view itself is returned.
  decltype(auto) noalias() const
  {
    if constexpr (use_expression_templates) return base().noalias();
    else return *this;
  }
  decltype(auto) noalias()
  {
    if constexpr (use_expression_templates) return base().noalias();
    else return *this;
  }

  // Numeric functions (each returns a new rotgen matrix)
  auto operator-() const { return detail::concretize<matrix>(-base()); }
  auto cwiseAbs() const
  {
    return detail::concretize<matrix>(base().cwiseAbs());
  }
  auto cwiseAbs2() const
  {
    return detail::concretize<matrix>(base().cwiseAbs2());
  }
  auto cwiseInverse() const
  {
    return detail::concretize<matrix>(base().cwiseInverse());
  }
  auto cwiseSqrt() const
  {
    return detail::concretize<matrix>(base().cwiseSqrt());
  }

  // Reductions
  using parent::lpNorm;
  using parent::maxCoeff;
  using parent::mean;
  using parent::minCoeff;
  using parent::norm;
  using parent::prod;
  using parent::squaredNorm;
  using parent::sum;
  using parent::trace;

  // Compound Operators (mutable views only)
  template<typename A, int O, typename S>
  ref& operator+=(ref<A, O, S> rhs)
    requires(!is_immutable)
  {
    base() += rhs.base();
    return *this;
  }
  template<typename A, int O, typename S>
  ref& operator-=(ref<A, O, S> rhs)
    requires(!is_immutable)
  {
    base() -= rhs.base();
    return *this;
  }
  template<typename A, int O, typename S>
  ref& operator*=(ref<A, O, S> rhs)
    requires(!is_immutable)
  {
    base() *= rhs.base();
    return *this;
  }
  ref& operator*=(std::convertible_to<value_type> auto s)
    requires(!is_immutable)
  {
    base() *= s;
    return *this;
  }
  ref& operator/=(std::convertible_to<value_type> auto s)
    requires(!is_immutable)
  {
    base() /= s;
    return *this;
  }

  // Shape modifications
  auto normalized() const
    requires(IsVectorAtCompileTime)
  {
    return detail::concretize<matrix>(base().normalized());
  }
  auto transpose() const
  {
    return detail::concretize<matrix>(base().transpose());
  }
  auto adjoint() const
  {
    return detail::concretize<matrix>(base().adjoint());
  }
  auto conjugate() const
  {
    return detail::concretize<matrix>(base().conjugate());
  }

  // In-place Shape modifications
  using parent::adjointInPlace;
  using parent::normalize;
  using parent::transposeInPlace;

  // Generators
  static auto Zero() { return detail::concretize<matrix>(parent::Zero()); }
  static auto Zero(int rows, int cols)
  {
    return detail::concretize<matrix>(parent::Zero(rows, cols));
  }
  static auto Ones() { return detail::concretize<matrix>(parent::Ones()); }
  static auto Ones(int rows, int cols)
  {
    return detail::concretize<matrix>(parent::Ones(rows, cols));
  }
  static auto Constant(value_type value)
  {
    return detail::concretize<matrix>(parent::Constant(value));
  }
  static auto Constant(int rows, int cols, value_type value)
  {
    return detail::concretize<matrix>(parent::Constant(rows, cols, value));
  }
  static auto Random()
  {
    return detail::concretize<matrix>(parent::Random());
  }
  static auto Random(int rows, int cols)
  {
    return detail::concretize<matrix>(parent::Random(rows, cols));
  }
  static auto Identity()
  {
    return detail::concretize<matrix>(parent::Identity());
  }
  static auto Identity(int rows, int cols)
  {
    return detail::concretize<matrix>(parent::Identity(rows, cols));
  }

  // In-place setters: overwrite every referenced coefficient. Like the
  // compound operators above, they only exist on mutable views; on an
  // immutable view they are now rejected at overload resolution instead of
  // failing deep inside Eigen's assignment machinery.
  ref& setOnes()
    requires(!is_immutable)
  {
    base() = parent::Ones(base().rows(), base().cols());
    return *this;
  }
  ref& setZero()
    requires(!is_immutable)
  {
    base() = parent::Zero(base().rows(), base().cols());
    return *this;
  }
  ref& setConstant(value_type value)
    requires(!is_immutable)
  {
    base() = parent::Constant(base().rows(), base().cols(), value);
    return *this;
  }
  ref& setRandom()
    requires(!is_immutable)
  {
    base() = parent::Random(base().rows(), base().cols());
    return *this;
  }
  ref& setIdentity()
    requires(!is_immutable)
  {
    base() = parent::Identity(base().rows(), base().cols());
    return *this;
  }

  /// Solves `(*this) * x = rhs` using a column-pivoting Householder QR.
  /// @param rhs any rotgen entity exposing `base()`.
  /// @return the solution, materialized as a rotgen matrix.
  auto qr_solve(auto const& rhs) const
  {
    // .solve() yields an Eigen::Solve expression that refers to the
    // temporary decomposition, so it must be materialized before returning.
    return detail::as_concrete_t<
      decltype(base().colPivHouseholderQr().solve(rhs.base())), matrix>(
      base().colPivHouseholderQr().solve(rhs.base()));
  }

  /// Copies the coefficients of any rotgen entity into the viewed storage.
  ref& operator=(concepts::entity auto const& e)
    requires(!is_immutable)
  {
    base() = e.base();
    return *this;
  }

  /// Access to the wrapped Eigen::Ref.
  parent const& base() const { return static_cast<parent const&>(*this); }
  parent& base() { return static_cast<parent&>(*this); }

  // Constructors: each is enabled only when the wrapped Eigen::Ref can be
  // built from the source's Eigen base.
  template<typename M>
    requires(is_immutable)
  ref(product<M> const& m) : ref(m.storage_)
  {
  }
  template<typename S, int R, int C, int O, int MR, int MC>
  ref(matrix<S, R, C, O, MR, MC>& m)
    requires(requires { parent(m.base()); })
      : parent(m.base())
  {
  }
  // NOTE(review): std::forward<matrix<...>>(m.base()) forwards the Eigen base
  // as if it were the rotgen matrix itself; this constraint looks like it can
  // never be satisfied (rvalue matrices then bind to the const& overload
  // below). Confirm whether std::move(m).base() was intended.
  template<typename S, int R, int C, int O, int MR, int MC>
  ref(matrix<S, R, C, O, MR, MC>&& m)
    requires(requires {
              parent(std::forward<matrix<S, R, C, O, MR, MC>>(m.base()));
            } && is_immutable)
      : parent(std::forward<matrix<S, R, C, O, MR, MC>>(m.base()))
  {
  }
  template<typename S, int R, int C, int O, int MR, int MC>
  ref(matrix<S, R, C, O, MR, MC> const& m)
    requires(requires { parent(m.base()); })
      : parent(m.base())
  {
  }
  template<typename Ref, int R, int C, bool I>
  ref(block<Ref, R, C, I>& b)
    requires(requires { parent(b.base()); })
      : parent(b.base())
  {
  }
  template<typename Ref, int R, int C, bool I>
  ref(block<Ref, R, C, I> const& b)
    requires(requires { parent(b.base()); })
      : parent(b.base())
  {
  }
  template<typename Ref, int O, typename S>
  ref(map<Ref, O, S>& m)
    requires(requires { parent(m.base()); })
      : parent(m.base())
  {
  }
  template<typename Ref, int O, typename S>
  ref(map<Ref, O, S> const& m)
    requires(requires { parent(m.base()); })
      : parent(m.base())
  {
  }
  template<typename TT, int OO, typename SS>
  ref(ref<TT, OO, SS>& r)
    requires(requires { parent(r.base()); })
      : parent(r.base())
  {
  }
  template<typename TT, int OO, typename SS>
  ref(ref<TT, OO, SS>&& r)
    requires(requires { parent(r.base()); })
      : parent(r.base())
  {
  }
  template<typename TT, int OO, typename SS>
  ref(ref<TT, OO, SS> const& r)
    requires(requires { parent(r.base()); })
      : parent(r.base())
  {
  }
  ref(parent& m) : parent(m) {}
  ref(parent const& m)
    requires(is_immutable)
      : parent(m)
  {
  }

  friend std::ostream& operator<<(std::ostream& os, ref const& r)
  {
    return os << r.base();
  }
  friend std::ostream& operator<<(std::ostream& os, format<ref> const& r)
  {
    return os << format{r.matrix_.base(), r.format_};
  }
};
//============================================================================
// Deduction Guides
//============================================================================
// Non-const lvalue sources deduce a mutable ref over the referee type.
template<typename Sc, int NR, int NC, int Op, int NMR, int NMC>
ref(matrix<Sc, NR, NC, Op, NMR, NMC>&)
  -> ref<matrix<Sc, NR, NC, Op, NMR, NMC>>;
template<typename Referee, int BR, int BC, bool BI>
ref(block<Referee, BR, BC, BI>&) -> ref<Referee>;

// Const lvalue sources deduce an immutable ref (const-qualified referee).
template<typename Sc, int NR, int NC, int Op, int NMR, int NMC>
ref(matrix<Sc, NR, NC, Op, NMR, NMC> const&)
  -> ref<matrix<Sc, NR, NC, Op, NMR, NMC> const>;
template<typename Referee, int BR, int BC, bool BI>
ref(block<Referee, BR, BC, BI> const&) -> ref<Referee const>;
//============================================================================
// Operators
//============================================================================
/// Coefficient-wise sum of two refs, materialized as a rotgen matrix.
template<typename A, int O, typename S, typename B, int P, typename T>
auto operator+(ref<A, O, S> lhs, ref<B, P, T> rhs)
{
  auto const& left = lhs.base();
  auto const& right = rhs.base();
  return detail::concretize<matrix>(left + right);
}
/// Coefficient-wise difference of two refs, materialized as a rotgen matrix.
template<typename A, int O, typename S, typename B, int P, typename T>
auto operator-(ref<A, O, S> lhs, ref<B, P, T> rhs)
{
  auto const& left = lhs.base();
  auto const& right = rhs.base();
  return detail::concretize<matrix>(left - right);
}
/// Matrix product of two refs.
///
/// The lazy Eigen product is converted to its concrete rotgen matrix type:
///   - no coefficient at compile time  -> a default-constructed result;
///   - exactly one coefficient         -> the result wrapped in `product`;
///   - otherwise                       -> the materialized result.
template<typename A, int O, typename S, typename B, int P, typename T>
auto operator*(ref<A, O, S> lhs, ref<B, P, T> rhs)
{
  auto lazy = lhs.base() * rhs.base();
  using result_type = detail::as_concrete_t<decltype(lazy), matrix>;

  if constexpr (result_type::SizeAtCompileTime == 0) { return result_type{}; }
  else if constexpr (result_type::SizeAtCompileTime == 1)
  {
    return product{result_type{lazy}};
  }
  else { return result_type{lazy}; }
}
/// Multiplies every coefficient of `lhs` by the scalar `s`.
template<typename A, int O, typename S>
auto operator*(ref<A, O, S> lhs,
               std::convertible_to<typename A::value_type> auto s)
{
  auto const& view = lhs.base();
  return detail::concretize<matrix>(view * s);
}
/// Divides every coefficient of `lhs` by the scalar `s`.
template<typename A, int O, typename S>
auto operator/(ref<A, O, S> lhs,
               std::convertible_to<typename A::value_type> auto s)
{
  auto const& view = lhs.base();
  return detail::concretize<matrix>(view / s);
}
/// Coefficient-wise (Hadamard) product of two refs.
template<typename A, int O, typename S, typename B, int P, typename T>
auto mul(ref<A, O, S> lhs, ref<B, P, T> rhs)
{
  auto const& left = lhs.base();
  return detail::concretize<matrix>(left.cwiseProduct(rhs.base()));
}
/// Coefficient-wise quotient of two refs.
template<typename A, int O, typename S, typename B, int P, typename T>
auto div(ref<A, O, S> lhs, ref<B, P, T> rhs)
{
  auto const& left = lhs.base();
  return detail::concretize<matrix>(left.cwiseQuotient(rhs.base()));
}
//============================================================================
// Functions
//============================================================================
/// Coefficient-wise minimum of two refs.
template<typename A, int O, typename S, typename B, int P, typename T>
auto min(ref<A, O, S> lhs, ref<B, P, T> rhs)
{
  auto const& left = lhs.base();
  return detail::concretize<matrix>(left.cwiseMin(rhs.base()));
}
/// Coefficient-wise minimum of a ref and a scalar (ref first).
template<typename A, int O, typename S>
auto min(ref<A, O, S> lhs, std::convertible_to<typename A::value_type> auto s)
{
  auto const& view = lhs.base();
  return detail::concretize<matrix>(view.cwiseMin(s));
}
/// Coefficient-wise minimum of a scalar and a ref (scalar first).
template<typename A, int O, typename S>
auto min(std::convertible_to<typename A::value_type> auto s, ref<A, O, S> rhs)
{
  auto const& view = rhs.base();
  return detail::concretize<matrix>(view.cwiseMin(s));
}
/// Coefficient-wise maximum of two refs.
template<typename A, int O, typename S, typename B, int P, typename T>
auto max(ref<A, O, S> lhs, ref<B, P, T> rhs)
{
  auto const& left = lhs.base();
  return detail::concretize<matrix>(left.cwiseMax(rhs.base()));
}
/// Coefficient-wise maximum of a ref and a scalar (ref first).
template<typename A, int O, typename S>
auto max(ref<A, O, S> lhs, std::convertible_to<typename A::value_type> auto s)
{
  auto const& view = lhs.base();
  return detail::concretize<matrix>(view.cwiseMax(s));
}
/// Coefficient-wise maximum of a scalar and a ref (scalar first).
template<typename A, int O, typename S>
auto max(std::convertible_to<typename A::value_type> auto s, ref<A, O, S> rhs)
{
  auto const& view = rhs.base();
  return detail::concretize<matrix>(view.cwiseMax(s));
}
/// Matrix inverse of the viewed matrix, materialized as a rotgen matrix.
template<typename A, int O, typename S>
auto inverse(ref<A, O, S> lhs)
{
  auto const& view = lhs.base();
  return detail::concretize<matrix>(view.inverse());
}
/// Cross product of two viewed vectors, materialized as a rotgen matrix.
template<typename A, int O, typename S, typename B, int P, typename T>
auto cross(ref<A, O, S> lhs, ref<B, P, T> rhs)
{
  auto const& left = lhs.base();
  auto const& right = rhs.base();
  return detail::concretize<matrix>(left.cross(right));
}
//-------------------------------------------------------------------------------------------
// Convert entity/Eigen types to a proper ref so that fewer function overloads
// need to be written
/// Maps a type to the rotgen::ref able to view it. Only the specialization
/// for Eigen-compatible types is visible here.
template<typename T> struct generalize;

/// For an Eigen-compatible T, `type` is ref<matrix<...> const> when T is
/// const-qualified and ref<matrix<...>> otherwise. The matrix has T's scalar
/// type, compile-time rows/cols, and the storage order of T's evaluated form.
template<concepts::eigen_compatible T> struct generalize<T>
{
  // Strip references so the nested-name lookups below (T::Scalar, ...) are
  // also valid when T comes from a forwarding reference, and so a
  // `X const&` is still recognized as const.
  using stripped_type = std::remove_reference_t<T>;
  static constexpr bool is_const = std::is_const_v<stripped_type>;
  // Eigen's eval() returns `const Derived&` for plain objects and a plain
  // object by value for expressions; remove cv/ref so `::Options` is valid
  // in both cases.
  using concrete_type =
    std::remove_cvref_t<decltype(std::declval<T>().eval())>;
  // Bit 0 of Eigen's Options is the RowMajor flag.
  using base = matrix<typename stripped_type::Scalar,
                      stripped_type::RowsAtCompileTime,
                      stripped_type::ColsAtCompileTime,
                      concrete_type::Options & 1>;
  using type = std::conditional_t<is_const, ref<base const>, ref<base>>;
};
/// Identity accessor for Eigen-compatible values: hands its argument back
/// through ROTGEN_FWD, and decltype(auto) keeps the exact type of that
/// expression (references included).
/// NOTE(review): ROTGEN_FWD is presumably a std::forward-style macro, and this
/// overload presumably mirrors `.base()` on rotgen entities so generic code can
/// reach the Eigen object uniformly — confirm both.
template<concepts::eigen_compatible T> decltype(auto) base_of(T&& a)
{
return ROTGEN_FWD(a);
}
}