//==================================================================================================
/*
  ROTGEN - Runtime Overlay for Eigen
  Copyright : CODE RECKONS
  SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once

#include
#include

// Only pulled in when ROTGEN_FORCE_DYNAMIC is not defined; the extra generalize/base_of overloads
// at the end of this file are guarded by the same macro.
#if !defined(ROTGEN_FORCE_DYNAMIC)
#include
#endif

namespace rotgen
{
  //------------------------------------------------------------------------------------------------
  // Primary template: mutable ref.
  //
  // A ref privately inherits from a `map` that every constructor builds from the data() pointer,
  // extents and (mostly) strides of an existing container, instead of copying its elements.
  // Presumably this makes ref a non-owning view whose writes land in the source storage
  // (NOTE(review): relies on map not copying - confirm against map's constructor).
  // Because the inheritance is private, only the members re-exported by the using-declarations
  // below are reachable on a ref; compile-time traits are forwarded unchanged from T.
  //------------------------------------------------------------------------------------------------
  template class ref : private map
  {
    public:
    using parent = map;
    using value_type = typename T::value_type;

    // Marker aliases; presumably detected elsewhere to recognise rotgen types and refs - confirm.
    using rotgen_tag = void;
    using rotgen_ref_tag = void;

    // Compile-time traits forwarded from the referenced type T.
    static constexpr int storage_order = T::storage_order;
    static constexpr int RowsAtCompileTime = T::RowsAtCompileTime;
    static constexpr int ColsAtCompileTime = T::ColsAtCompileTime;
    static constexpr bool IsVectorAtCompileTime = T::IsVectorAtCompileTime;

    // Re-exported map interface: evaluation, element access and shape queries.
    using parent::evaluate;
    using parent::noalias;
    using parent::operator();
    using parent::operator[];
    using parent::rows;
    using parent::cols;
    using parent::size;
    using parent::data;

    // Reductions, norms and coefficient-wise helpers.
    using parent::sum;
    using parent::prod;
    using parent::mean;
    using parent::trace;
    using parent::transpose;
    using parent::cwiseAbs;
    using parent::cwiseAbs2;
    using parent::cwiseInverse;
    using parent::cwiseSqrt;
    using parent::maxCoeff;
    using parent::minCoeff;
    using parent::norm;
    using parent::normalize;
    using parent::squaredNorm;
    using parent::lpNorm;

    // In-place arithmetic.
    using parent::operator+=;
    using parent::operator-=;
    using parent::operator*=;
    using parent::operator/=;

    // Zero/Constant/Random/Identity and their set* counterparts.
    using parent::Zero;
    using parent::Constant;
    using parent::Random;
    using parent::Identity;
    using parent::setZero;
    using parent::setConstant;
    using parent::setRandom;
    using parent::setIdentity;

    // Stride queries and assignment.
    using parent::outerStride;
    using parent::innerStride;
    using parent::operator=;
    using stride_type = typename parent::stride_type;

    // Access to the privately-inherited map; the free operators below unwrap refs through this.
    parent const& base() const { return static_cast(*this); }
    parent& base() { return static_cast(*this); }

    // From a matrix. The storage-order check (low bit of O vs. storage_order) is a hard error via
    // static_assert rather than an overload-resolution constraint.
    template S, int R, int C, int O, int MR, int MC>
    ref(matrix& m) : parent(m.data(), m.rows(), m.cols(), strides(m))
    {
      static_assert((O & 1) == storage_order, "ref: Incompatible storage layout");
    }

    // From a block, accepted both as an rvalue (temporary block) and as an lvalue. The block's
    // outer/inner strides are forwarded to the map.
    template
    ref(block&& b)
      : parent(b.data(), b.rows(), b.cols(), stride_type{b.outerStride(),b.innerStride()})
    {
      static_assert((Ref::storage_order & 1) == storage_order, "ref: Incompatible storage layout");
    }

    template
    ref(block& b)
      : parent(b.data(), b.rows(), b.cols(), stride_type{b.outerStride(),b.innerStride()})
    {
      static_assert((Ref::storage_order & 1) == storage_order, "ref: Incompatible storage layout");
    }

    // From an existing map, with the same storage-order check and stride forwarding.
    template
    ref(map& b)
      : parent(b.data(), b.rows(), b.cols(), stride_type{b.outerStride(),b.innerStride()})
    {
      static_assert((Ref::storage_order & 1) == storage_order, "ref: Incompatible storage layout");
    }

    // From another mutable ref. Unlike the constructors above, the layout check is a
    // requires-clause, so a mismatching ref simply drops out of overload resolution.
    template
    ref ( ref& b ) requires(std::same_as && (TT::storage_order & 1) == storage_order)
      : parent(b.data(), b.rows(), b.cols(), stride_type{b.outerStride(),b.innerStride()})
    {}

    // From the underlying map type itself.
    // NOTE(review): unlike every other constructor, m's strides are not forwarded (only
    // data/rows/cols reach the map constructor). If m can carry a non-default stride, the new ref
    // would fall back to map's default stride - confirm this is intended.
    ref(parent& m) : parent(m.data(), m.rows(), m.cols()) {}

    // Streams the underlying map followed by a newline.
    friend std::ostream& operator<<(std::ostream& os, ref const& r)
    {
      return os << r.base() << "\n";
    }
  };

  //------------------------------------------------------------------------------------------------
  // Specialization for const matrix type: read-only ref.
  //
  // Mirrors the primary template but is built from const sources, exposes only the const base()
  // overload and none of the set* members, and expresses every layout check as a requires-clause.
  // It additionally re-exports has_static_storage from the map.
  // NOTE(review): unlike the primary template, no rotgen_ref_tag alias is declared here - confirm
  // whether that is intentional.
  //------------------------------------------------------------------------------------------------
  template class ref : private map
  {
    public:
    using parent = map;
    using value_type = typename T::value_type;
    using rotgen_tag = void;

    // Compile-time traits forwarded from the referenced type T.
    static constexpr int storage_order = T::storage_order;
    static constexpr int RowsAtCompileTime = T::RowsAtCompileTime;
    static constexpr int ColsAtCompileTime = T::ColsAtCompileTime;
    static constexpr bool IsVectorAtCompileTime = T::IsVectorAtCompileTime;

    // Re-exported map interface: evaluation, element access and shape queries.
    using parent::evaluate;
    using parent::noalias;
    using parent::operator();
    using parent::operator[];
    using parent::rows;
    using parent::cols;
    using parent::size;
    using parent::data;

    // Reductions, norms and coefficient-wise helpers.
    // NOTE(review): normalize (and the in-place operators / operator= further down) are re-exported
    // on this read-only view too; presumably the const map rejects them when used - confirm.
    using parent::sum;
    using parent::prod;
    using parent::mean;
    using parent::trace;
    using parent::transpose;
    using parent::cwiseAbs;
    using parent::cwiseAbs2;
    using parent::cwiseInverse;
    using parent::cwiseSqrt;
    using parent::maxCoeff;
    using parent::minCoeff;
    using parent::norm;
    using parent::normalize;
    using parent::squaredNorm;
    using parent::lpNorm;
    using parent::operator+=;
    using parent::operator-=;
    using parent::operator*=;
    using parent::operator/=;

    // Zero/Constant/Random/Identity (no set* counterparts on a const view).
    using parent::Zero;
    using parent::Constant;
    using parent::Random;
    using parent::Identity;

    // Stride queries and assignment.
    using parent::outerStride;
    using parent::innerStride;
    using parent::operator=;
    using stride_type = typename parent::stride_type;
    static constexpr bool has_static_storage = parent::has_static_storage;

    // Read-only access to the privately-inherited map.
    parent const& base() const { return static_cast(*this); }

    // From a const matrix whose storage order (low bit of O) matches.
    template S, int R, int C, int O, int MR, int MC>
    ref(matrix const& m) requires((O & 1) == storage_order)
      : parent(m.data(), m.rows(), m.cols(), strides(m))
    {}

    // From a const block; its outer/inner strides are forwarded to the map.
    template
    ref ( block const& b ) requires(std::same_as && (Ref::storage_order & 1) == storage_order)
      : parent(b.data(), b.rows(), b.cols(), stride_type{b.outerStride(),b.innerStride()})
    {}

    // From a const map; its outer/inner strides are forwarded to the map.
    template
    ref ( map const& b ) requires(std::same_as && (Ref::storage_order & 1) == storage_order)
      : parent(b.data(), b.rows(), b.cols(), stride_type{b.outerStride(),b.innerStride()})
    {}

    // From another const ref with a compatible layout.
    template
    ref ( ref const& b ) requires(std::same_as && (TT::storage_order & 1) == storage_order)
      : parent(b.data(), b.rows(), b.cols(), stride_type{b.outerStride(),b.innerStride()})
    {}

    // From the underlying map type itself.
    // NOTE(review): as in the primary template, m's strides are not forwarded here - confirm.
    ref(parent const& m) : parent(m.data(), m.rows(), m.cols()) {}

    // Streams the underlying map followed by a newline.
    friend std::ostream& operator<<(std::ostream& os, ref const& r)
    {
      return os << r.base() << "\n";
    }
  };

  //------------------------------------------------------------------------------------------------
  // Deduction guides: let `ref(x)` deduce its ref type from a (const or mutable) matrix or block.
  //------------------------------------------------------------------------------------------------
  template ref(matrix&) -> ref>;
  template ref(block& b) -> ref;
  template ref(matrix const&) -> ref const>;
  template ref(block const& b) -> ref;

  //------------------------------------------------------------------------------------------------
  // Free operators and functions on refs. Each one unwraps its ref arguments with base() and
  // forwards to the corresponding map operation. refs are taken by value; every ref constructor
  // builds its map from a data pointer, so copying a ref does not copy elements (assuming map is
  // non-owning - confirm).
  //------------------------------------------------------------------------------------------------
  template bool operator==(ref lhs, ref rhs) { return lhs.base() == rhs.base(); }
  template bool operator!=(ref lhs, ref rhs) { return lhs.base() != rhs.base(); }

  // Trailing decltype return: only participates in overload resolution when the underlying
  // map expression is valid.
  template auto operator+(ref lhs, ref rhs) -> decltype(lhs.base() + rhs.base())
  {
    return lhs.base() + rhs.base();
  }

  // NOTE(review): lhs is taken by value, so the write reaches the referenced storage only through
  // the copied view. If map's operator+= returns a reference to *this, the deduced return type
  // refers into this by-value parameter and dangles once the function returns - confirm map's
  // return type. The same applies to operator-= and operator*= below.
  template auto operator+=(ref lhs, ref rhs) -> decltype(lhs.base() += rhs.base())
  {
    return lhs.base() += rhs.base();
  }

  // NOTE(review): operator-, operator* and the scalar overloads use a plain deduced `auto` return
  // without trailing decltype, so an invalid body is a hard error rather than a substitution
  // failure - inconsistent with operator+.
  template auto operator-(ref lhs, ref rhs) { return lhs.base() - rhs.base(); }

  template auto operator-=(ref lhs, ref rhs) -> decltype(lhs.base() -= rhs.base())
  {
    return lhs.base() -= rhs.base();
  }

  template auto operator*(ref lhs, ref rhs) { return lhs.base() * rhs.base(); }

  template auto operator*=(ref lhs, ref rhs) -> decltype(lhs.base() *= rhs.base())
  {
    return lhs.base() *= rhs.base();
  }

  // Scalar scaling; s is constrained with std::convertible_to (presumably to the ref's
  // value_type - confirm).
  template auto operator*(ref lhs, std::convertible_to auto s) { return lhs.base() * s; }
  template auto operator*(std::convertible_to auto s, ref rhs) { return s * rhs.base(); }
  template auto operator/(ref lhs, std::convertible_to auto s) { return lhs.base() / s; }

  // Forwards to map::dot.
  template auto dot(ref lhs, ref rhs) { return lhs.base().dot(rhs.base()); }

  // Forward to cwiseMin against another ref or a scalar. The scalar-first overload calls cwiseMin
  // on the ref operand, which relies on min being symmetric.
  template auto min(ref lhs, ref rhs) -> decltype(lhs.base().cwiseMin(rhs.base()))
  {
    return lhs.base().cwiseMin(rhs.base());
  }

  template auto min(ref lhs, std::convertible_to auto s) -> decltype(lhs.base().cwiseMin(s))
  {
    return lhs.base().cwiseMin(s);
  }

  template auto min(std::convertible_to auto s,ref rhs) -> decltype(rhs.base().cwiseMin(s))
  {
    return rhs.base().cwiseMin(s);
  }

  // Forward to cwiseMax; same operand-swapping scheme as min.
  template auto max(ref lhs, ref rhs) -> decltype(lhs.base().cwiseMax(rhs.base()))
  {
    return lhs.base().cwiseMax(rhs.base());
  }

  template auto max(ref lhs, std::convertible_to auto s) -> decltype(lhs.base().cwiseMax(s))
  {
    return lhs.base().cwiseMax(s);
  }

  template auto max(std::convertible_to auto s,ref rhs) -> decltype(rhs.base().cwiseMax(s))
  {
    return rhs.base().cwiseMax(s);
  }

  // mul: cwiseProduct for two refs; the scalar forms delegate to operator* above.
  template auto mul(ref lhs, ref rhs) -> decltype(lhs.base().cwiseProduct(rhs.base()))
  {
    return lhs.base().cwiseProduct(rhs.base());
  }

  template auto mul(ref lhs, std::convertible_to auto s) -> decltype(lhs * s)
  {
    return lhs * s;
  }

  template auto mul(std::convertible_to auto s,ref rhs) -> decltype(s * rhs)
  {
    return s * rhs;
  }

  // div: cwiseQuotient for two refs; the scalar form delegates to operator/ above.
  template auto div(ref lhs, ref rhs) -> decltype(lhs.base().cwiseQuotient(rhs.base()))
  {
    return lhs.base().cwiseQuotient(rhs.base());
  }

  template auto div(ref lhs, std::convertible_to auto s) -> decltype(lhs / s)
  {
    return lhs / s;
  }

  // Forward to map::inverse and map::cross.
  template auto inverse(ref lhs) -> decltype(lhs.base().inverse())
  {
    return lhs.base().inverse();
  }

  template auto cross(ref lhs, ref rhs) -> decltype(lhs.base().cross(rhs.base()))
  {
    return lhs.base().cross(rhs.base());
  }

  //-------------------------------------------------------------------------------------------
  // Convert entity/eigen types to a proper ref so we can write fewer function overloads.
  // generalize<X>::type (alias generalize_t) is the normalized form of an argument type X.
  //-------------------------------------------------------------------------------------------
  template struct generalize;

  // Arithmetic scalars normalize to their cv/ref-stripped type.
  template requires(std::is_arithmetic_v>) struct generalize
  {
    using type = std::remove_cvref_t;
  };

  template using generalize_t = typename generalize::type;

  // Presumably rotgen entity types: is_const records the argument's constness and selects between
  // a const and a mutable ref over a matrix `base` - confirm the constraint on this specialization.
  template struct generalize
  {
    static constexpr bool is_const = std::is_const_v;
    using base = matrix;
    using type = std::conditional_t, ref>;
  };

  // Presumably ref arguments themselves (non-const, then const) - confirm.
  template struct generalize> { using type = ref; };
  template struct generalize const> { using type = ref; };

  // base_of: unwrap a type exposing `parent` and base() to its underlying map (by reference), or
  // pass an arithmetic value through by copy.
  template typename T::parent& base_of(T& a) { return a.base(); }
  template typename T::parent const& base_of(T const& a) { return a.base(); }
  template T base_of(T a) requires(std::is_arithmetic_v) { return a; }

#if !defined(ROTGEN_FORCE_DYNAMIC)
  // Compiled only when ROTGEN_FORCE_DYNAMIC is not defined. This generalize specialization names
  // concrete_type as the result of calling .eval() on the argument type (presumably an Eigen
  // expression), and the base_of overloads below return their argument unchanged.
  template struct generalize
  {
    static constexpr bool is_const = std::is_const_v;
    using concrete_type = decltype(std::declval().eval());
    using base = matrix;
    using type = std::conditional_t, ref>;
  };

  template auto const& base_of(T const& a) { return a; }
  template auto& base_of(T& a) { return a; }
#endif
}