rotgen/include/rotgen/fixed/map.hpp
Joel Falcou ddf8816c5b Implement dot
See merge request oss/rotgen!31
2025-09-29 18:58:12 +02:00

474 lines
16 KiB
C++

//==================================================================================================
/*
ROTGEN - Runtime Overlay for Eigen
Copyright : CODE RECKONS
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once
#include <Eigen/Dense>
#include <iostream>
namespace rotgen
{
namespace detail
{
template<typename Ref, int Options, bool isConst>
struct compute_map_type
{
using base = Eigen::Matrix< typename Ref::value_type
, Ref::RowsAtCompileTime, Ref::ColsAtCompileTime
, Ref::storage_order
>;
using ref_t = std::conditional_t<isConst, base const, base>;
using type = Eigen::Map<ref_t, Options, Eigen::Stride<-1,-1>>;
};
template<typename Ref, int Options, bool isConst>
using map_type = typename compute_map_type<Ref,Options,isConst>::type;
}
//--------------------------------------------------------------------------------------------------
// rotgen::map - non-owning view of caller-provided storage exposing the rotgen entity interface.
//
// Template parameters:
//   Ref     : rotgen entity type being viewed. A const-qualified Ref produces a read-only view
//             (is_immutable == true); every mutating member below is constrained on !is_immutable.
//   Options : alignment option forwarded to the underlying Eigen::Map (defaults to Unaligned).
//   Stride  : rotgen stride description, converted through strides<storage_order>(...) for Eigen.
//
// The class privately derives from detail::map_type, i.e. an Eigen::Map with runtime inner and
// outer strides. When use_expression_templates is false, members that would otherwise return an
// Eigen expression materialize their result into a concrete rotgen type instead.
//--------------------------------------------------------------------------------------------------
template<typename Ref, int Options = Unaligned, typename Stride = stride>
class map : private detail::map_type<std::remove_const_t<Ref>, Options, std::is_const_v<Ref>>
{
  public:
  using rotgen_tag    = void;
  using parent        = detail::map_type<std::remove_const_t<Ref>, Options, std::is_const_v<Ref>>;
  using value_type    = typename std::remove_const_t<Ref>::value_type;
  using concrete_type = typename std::remove_const_t<Ref>::concrete_type;

  static constexpr Index RowsAtCompileTime     = Ref::RowsAtCompileTime;
  static constexpr Index ColsAtCompileTime     = Ref::ColsAtCompileTime;
  static constexpr bool  IsVectorAtCompileTime = Ref::IsVectorAtCompileTime;
  static constexpr bool  has_static_storage    = Ref::has_static_storage;
  static constexpr int   storage_order         = Ref::storage_order;
  static constexpr bool  is_immutable          = std::is_const_v<Ref>;
  static constexpr bool  is_defined_static     = Ref::is_defined_static;

  template<typename ET>
  using as_concrete_type = as_concrete_t<ET, matrix>;

  // Pointer accepted by the constructors; pointer-to-const for read-only views.
  using ptr_type    = std::conditional_t<is_immutable, value_type const*, value_type*>;
  using stride_type = Stride;

  //------------------------------------------------------------------------------------------------
  // Construction & assignment
  //------------------------------------------------------------------------------------------------
  // Defaulted special members forward to the Eigen::Map base.
  map(const map&)            = default;
  map(map&&)                 = default;
  map& operator=(const map&) = default;
  map& operator=(map&&)      = default;

  // Views an r x c block starting at ptr, laid out according to the explicit stride s.
  map(ptr_type ptr, Index r, Index c, stride_type s)
    : parent(ptr, r, c, strides<storage_order>(s,r,c))
  {}

  // Views an r x c block starting at ptr. A non-default Stride type is default-constructed and
  // used; with the default `stride` type the layout is derived from (r, c) only.
  map(ptr_type ptr, Index r, Index c)
    : parent( ptr, r, c
            , [&]()
              {
                if constexpr(!std::same_as<Stride,stride>) return strides<storage_order>(Stride{},r,c);
                else                                       return strides<storage_order>(r,c);
              }()
            )
  {}

  // Fixed-size only: views RowsAtCompileTime x ColsAtCompileTime elements using stride s.
  map(ptr_type ptr, stride_type s) requires(RowsAtCompileTime!=-1 && ColsAtCompileTime!=-1)
    : parent(ptr, strides<storage_order>(s,RowsAtCompileTime,ColsAtCompileTime))
  {}

  // Compile-time vectors only: views sz elements along the non-unit dimension.
  map(ptr_type ptr, Index sz) requires(RowsAtCompileTime==1 || ColsAtCompileTime==1)
    : map(ptr, RowsAtCompileTime==1?1:sz, ColsAtCompileTime==1?1:sz)
  {}

  // Fixed-size only: views RowsAtCompileTime x ColsAtCompileTime elements with the default layout.
  map(ptr_type ptr) requires(RowsAtCompileTime!=-1 && ColsAtCompileTime!=-1)
    : map( ptr, RowsAtCompileTime, ColsAtCompileTime )
  {}

  // Builds the underlying Eigen::Map from another rotgen entity's base().
  map(concepts::entity auto const& other) : parent(other.base())
  {}

  // Assigns another rotgen entity through the Eigen::Map base (writes into the viewed storage,
  // hence only available on mutable views).
  map& operator=(concepts::entity auto const& other) requires(!is_immutable)
  {
    parent::operator=(other.base());
    return *this;
  }

  //------------------------------------------------------------------------------------------------
  // Access to the underlying Eigen::Map
  //------------------------------------------------------------------------------------------------
  parent&       base()       { return static_cast<parent&>(*this); }
  parent const& base() const { return static_cast<parent const&>(*this); }

  // Evaluates the view with Eigen's eval() and wraps the result with as_concrete_type.
  auto evaluate() const
  {
    auto res = base().eval();
    return as_concrete_type<decltype(res)>(res);
  }

  // With expression templates, returns Eigen's noalias() proxy; otherwise returns *this.
  decltype(auto) noalias() const
  {
    if constexpr(use_expression_templates) return base().noalias();
    else return *this;
  }

  decltype(auto) noalias()
  {
    if constexpr(use_expression_templates) return base().noalias();
    else return *this;
  }

  //------------------------------------------------------------------------------------------------
  // Element access
  // All accessors go through Eigen's coefficient accessors so the map's (runtime) strides are
  // honoured; reading data()[i] directly would ignore a non-unit inner stride.
  //------------------------------------------------------------------------------------------------
  value_type& operator()(Index i, Index j) requires(!is_immutable)
  {
    return base()(i,j);
  }

  value_type& operator()(Index i) requires(!is_immutable && IsVectorAtCompileTime)
  {
    return base()(i);
  }

  value_type& operator[](Index i) requires(!is_immutable && IsVectorAtCompileTime)
  {
    return (*this)(i);
  }

  value_type operator()(Index i, Index j) const { return base()(i,j); }
  value_type operator()(Index i) const requires(IsVectorAtCompileTime) { return base()(i); }
  value_type operator[](Index i) const requires(IsVectorAtCompileTime) { return (*this)(i); }

  //------------------------------------------------------------------------------------------------
  // Compound assignment (mutable views only)
  //------------------------------------------------------------------------------------------------
  template<typename R2, int O2, typename S2>
  map& operator+=(map<R2,O2,S2> const& rhs) requires(!is_immutable)
  {
    base() += rhs.base();
    return *this;
  }

  template<typename R2, int O2, typename S2>
  map& operator-=(map<R2,O2,S2> const& rhs) requires(!is_immutable)
  {
    base() -= rhs.base();
    return *this;
  }

  // Matrix product in place: *this = *this * rhs (Eigen semantics).
  template<typename R2, int O2, typename S2>
  map& operator*=(map<R2,O2,S2> const& rhs) requires(!is_immutable)
  {
    base() *= rhs.base();
    return *this;
  }

  map& operator*=(value_type rhs) requires(!is_immutable)
  {
    base() *= rhs;
    return *this;
  }

  map& operator/=(value_type rhs) requires(!is_immutable)
  {
    base() /= rhs;
    return *this;
  }

  //------------------------------------------------------------------------------------------------
  // Coefficient-wise operations: Eigen expression with expression templates, concrete_type otherwise.
  //------------------------------------------------------------------------------------------------
  auto cwiseMin(map const& rhs) const
  {
    if constexpr(!use_expression_templates) return concrete_type{parent::cwiseMin(rhs.base())};
    else return base().cwiseMin(rhs.base());
  }

  auto cwiseMin(value_type rhs) const
  {
    if constexpr(!use_expression_templates) return concrete_type{parent::cwiseMin(rhs)};
    else return base().cwiseMin(rhs);
  }

  auto cwiseMax(map const& rhs) const
  {
    if constexpr(!use_expression_templates) return concrete_type{parent::cwiseMax(rhs.base())};
    else return base().cwiseMax(rhs.base());
  }

  auto cwiseMax(value_type rhs) const
  {
    if constexpr(!use_expression_templates) return concrete_type{parent::cwiseMax(rhs)};
    else return base().cwiseMax(rhs);
  }

  auto cwiseProduct(map const& rhs) const
  {
    if constexpr(!use_expression_templates) return concrete_type{parent::cwiseProduct(rhs.base())};
    else return base().cwiseProduct(rhs.base());
  }

  auto cwiseQuotient(map const& rhs) const
  {
    if constexpr(!use_expression_templates) return concrete_type{parent::cwiseQuotient(rhs.base())};
    else return base().cwiseQuotient(rhs.base());
  }

  auto cwiseAbs() const
  {
    if constexpr(!use_expression_templates) return concrete_type{parent::cwiseAbs()};
    else return base().cwiseAbs();
  }

  auto cwiseAbs2() const
  {
    if constexpr(!use_expression_templates) return concrete_type{parent::cwiseAbs2()};
    else return base().cwiseAbs2();
  }

  auto cwiseInverse() const
  {
    if constexpr(!use_expression_templates) return concrete_type{parent::cwiseInverse()};
    else return base().cwiseInverse();
  }

  auto cwiseSqrt() const
  {
    if constexpr(!use_expression_templates) return concrete_type{parent::cwiseSqrt()};
    else return base().cwiseSqrt();
  }

  auto cross(map const& rhs) const
  {
    if constexpr(!use_expression_templates) return concrete_type{parent::cross(rhs.base())};
    else return base().cross(rhs.base());
  }

  //------------------------------------------------------------------------------------------------
  // Derived views / evaluations
  //------------------------------------------------------------------------------------------------
  auto inverse() const
  {
    if constexpr(use_expression_templates) return base().inverse();
    else return as_concrete_type<decltype(base().inverse())>(base().inverse());
  }

  auto normalized() const requires(IsVectorAtCompileTime)
  {
    if constexpr(use_expression_templates) return base().normalized();
    else return as_concrete_type<decltype(base().normalized())>(base().normalized());
  }

  auto transpose() const
  {
    if constexpr(use_expression_templates) return base().transpose();
    else return as_concrete_type<decltype(base().transpose())>(base().transpose());
  }

  auto adjoint() const
  {
    if constexpr(use_expression_templates) return base().adjoint();
    else return as_concrete_type<decltype(base().adjoint())>(base().adjoint());
  }

  auto conjugate() const
  {
    if constexpr(use_expression_templates) return base().conjugate();
    else return as_concrete_type<decltype(base().conjugate())>(base().conjugate());
  }

  //------------------------------------------------------------------------------------------------
  // In-place transformations (mutable views only)
  //------------------------------------------------------------------------------------------------
  void normalize() requires(IsVectorAtCompileTime && !is_immutable)
  {
    base().normalize();
  }

  void transposeInPlace() requires(!is_immutable) { base().transposeInPlace(); }
  void adjointInPlace()   requires(!is_immutable) { base().adjointInPlace(); }

  // Solves base() * x = rhs with Eigen's column-pivoting Householder QR; rhs must expose base().
  auto qr_solve(auto const& rhs) const
  {
    return concrete_type(base().colPivHouseholderQr().solve(rhs.base()));
  }

  //------------------------------------------------------------------------------------------------
  // Static factories, forwarded to Ref. Sizes use Index so large extents are not truncated.
  //------------------------------------------------------------------------------------------------
  static auto Zero() requires( requires {Ref::Zero();} ) { return Ref::Zero(); }
  static auto Zero(Index rows, Index cols) { return Ref::Zero(rows,cols); }
  static auto Ones() requires( requires {Ref::Ones();} ) { return Ref::Ones(); }
  static auto Ones(Index rows, Index cols) { return Ref::Ones(rows,cols); }
  static auto Constant(value_type value) requires( requires {Ref::Constant(value);} )
  { return Ref::Constant(value); }
  static auto Constant(Index rows, Index cols, value_type value) { return Ref::Constant(rows, cols, value); }
  static auto Random() requires( requires {Ref::Random();} ) { return Ref::Random(); }
  static auto Random(Index rows, Index cols) { return Ref::Random(rows, cols); }
  static auto Identity() requires( requires {Ref::Identity();} ) { return Ref::Identity(); }
  static auto Identity(Index rows, Index cols) { return Ref::Identity(rows, cols); }

  //------------------------------------------------------------------------------------------------
  // In-place fill of the viewed storage (mutable views only), using Eigen's setX() helpers.
  //------------------------------------------------------------------------------------------------
  map& setOnes() requires(!is_immutable)
  {
    base().setOnes();
    return *this;
  }

  map& setZero() requires(!is_immutable)
  {
    base().setZero();
    return *this;
  }

  map& setConstant(value_type value) requires(!is_immutable)
  {
    base().setConstant(value);
    return *this;
  }

  map& setRandom() requires(!is_immutable)
  {
    base().setRandom();
    return *this;
  }

  map& setIdentity() requires(!is_immutable)
  {
    base().setIdentity();
    return *this;
  }

  //------------------------------------------------------------------------------------------------
  // Queries & reductions
  //------------------------------------------------------------------------------------------------
  using parent::innerStride;
  using parent::outerStride;
  using parent::rows;
  using parent::cols;
  using parent::size;
  using parent::data;
  using parent::sum;
  using parent::mean;
  using parent::prod;
  using parent::trace;
  using parent::norm;
  using parent::squaredNorm;

  auto minCoeff() const { return parent::minCoeff(); }
  auto maxCoeff() const { return parent::maxCoeff(); }

  // Minimum coefficient; its position is written to *row / *col (converted to IndexType).
  template<std::integral IndexType>
  auto minCoeff(IndexType* row, IndexType* col) const
  {
    Index r{}, c{};
    auto result = parent::minCoeff(&r, &c);
    *row = static_cast<IndexType>(r);
    *col = static_cast<IndexType>(c);
    return result;
  }

  // Maximum coefficient; its position is written to *row / *col (converted to IndexType).
  template<std::integral IndexType>
  auto maxCoeff(IndexType* row, IndexType* col) const
  {
    Index r{}, c{};
    auto result = parent::maxCoeff(&r, &c);
    *row = static_cast<IndexType>(r);
    *col = static_cast<IndexType>(c);
    return result;
  }

  template<typename R2, int O2, typename S2>
  value_type dot(map<R2,O2,S2> const& rhs) const
  {
    return base().dot(rhs.base());
  }

  // L^P norm; only P = 1, 2 and Infinity are supported.
  template<int P> value_type lpNorm() const
  {
    static_assert(P == 1 || P == 2 || P == Infinity);
    return parent::template lpNorm<P>();
  }
};
// Coefficient-wise equality of two maps.
// Maps whose runtime shapes differ compare unequal up front. Handing them straight to Eigen would
// trip its size assertion, and with assertions disabled (NDEBUG) that is an out-of-range read.
template<typename R1, typename R2, int O1, typename S1, int O2, typename S2>
bool operator==(map<R1, O1, S1> const& lhs, map<R2,O2,S2> const& rhs)
{
  if(lhs.rows() != rhs.rows() || lhs.cols() != rhs.cols()) return false;
  return lhs.base() == rhs.base();
}
// Coefficient-wise inequality of two maps.
// Maps whose runtime shapes differ are reported unequal up front. Handing them straight to Eigen
// would trip its size assertion, and with assertions disabled (NDEBUG) that is an out-of-range read.
template<typename R1, typename R2, int O1, typename S1, int O2, typename S2>
bool operator!=(map<R1, O1, S1> const& lhs, map<R2,O2,S2> const& rhs)
{
  if(lhs.rows() != rhs.rows() || lhs.cols() != rhs.cols()) return true;
  return lhs.base() != rhs.base();
}
#if defined(ROTGEN_ENABLE_EXPRESSION_TEMPLATES)
// Lazy sum of two maps: yields the Eigen expression for lhs + rhs without evaluating it.
template<typename R1, typename R2, int O1, typename S1, int O2, typename S2>
auto operator+(map<R1, O1, S1> const& lhs, map<R2,O2,S2> const& rhs)
{
  auto const& left  = lhs.base();
  auto const& right = rhs.base();
  return left + right;
}
// Lazy difference of two maps: yields the Eigen expression for lhs - rhs without evaluating it.
template<typename R1, typename R2, int O1, typename S1, int O2, typename S2>
auto operator-(map<R1, O1, S1> const& lhs, map<R2,O2,S2> const& rhs)
{
  auto const& left  = lhs.base();
  auto const& right = rhs.base();
  return left - right;
}
// Lazy matrix product of two maps: yields the Eigen product expression lhs * rhs.
template<typename R1, typename R2, int O1, typename S1, int O2, typename S2>
auto operator*(map<R1, O1, S1> const& lhs, map<R2,O2,S2> const& rhs)
{
  auto const& left  = lhs.base();
  auto const& right = rhs.base();
  return left * right;
}
// Lazy scaling of a map by a scalar on the right: yields the Eigen expression lhs * s.
template<typename R, int O, typename S>
auto operator*( map<R, O, S> const& lhs, std::convertible_to<typename R::value_type> auto s)
{
  auto const& mapped = lhs.base();
  return mapped * s;
}
// Lazy scaling of a map by a scalar on the left: yields the Eigen expression s * rhs.
template<typename R, int O, typename S>
auto operator*(std::convertible_to<typename R::value_type> auto s, map<R, O, S> const& rhs)
{
  auto const& mapped = rhs.base();
  return s * mapped;
}
// Lazy division of a map by a scalar: yields the Eigen expression lhs / s.
template<typename R, int O, typename S>
auto operator/(map<R, O, S> const& lhs, std::convertible_to<typename R::value_type> auto s)
{
  auto const& mapped = lhs.base();
  return mapped / s;
}
#else
template<typename R1, typename R2, int O1, typename S1, int O2, typename S2>
typename map<R1,O1,S1>::concrete_type operator+(map<R1, O1, S1> const& lhs, map<R2,O2,S2> const& rhs)
{
using concrete_type = typename map<R1,O1,S1>::concrete_type;
return concrete_type(lhs.base() + rhs.base());
}
template<typename R1, typename R2, int O1, typename S1, int O2, typename S2>
typename map<R1,O1,S1>::concrete_type operator-(map<R1, O1, S1> const& lhs, map<R2,O2,S2> const& rhs)
{
using concrete_type = typename map<R1,O1,S1>::concrete_type;
return concrete_type(lhs.base() - rhs.base());
}
template<typename R1, typename R2, int O1, typename S1, int O2, typename S2>
matrix<typename R1::value_type,R1::RowsAtCompileTime,R2::ColsAtCompileTime>
operator*(map<R1,O1,S1> const& lhs, map<R2,O2,S2> const& rhs)
{
using concrete_type = matrix< typename R1::value_type
, R1::RowsAtCompileTime,R2::ColsAtCompileTime
>;
return concrete_type(lhs.base() * rhs.base());
}
template<typename R, int O, typename S>
typename map<R,O,S>::concrete_type operator*( map<R, O, S> const& lhs
, std::convertible_to<typename R::value_type> auto s
)
{
using concrete_type = typename map<R,O,S>::concrete_type;
return concrete_type(lhs.base() * s);
}
template<typename R, int O, typename S>
typename map<R,O,S>::concrete_type operator*( std::convertible_to<typename R::value_type> auto s
, map<R, O, S> const& rhs
)
{
using concrete_type = typename map<R,O,S>::concrete_type;
return concrete_type(rhs.base() * s);
}
template<typename R, int O, typename S>
typename map<R,O,S>::concrete_type operator/( map<R, O, S> const& lhs
, std::convertible_to<typename R::value_type> auto s
)
{
using concrete_type = typename map<R,O,S>::concrete_type;
return concrete_type(lhs.base() / s);
}
#endif
}