Add support for custom strides on map

See merge request oss/rotgen!14
This commit is contained in:
Joel Falcou 2025-08-15 16:49:33 +02:00
parent 87d4bc0585
commit c8fb0f476c
6 changed files with 220 additions and 69 deletions

View file

@ -17,11 +17,11 @@
namespace rotgen
{
// Primary template: mutable ref
template<typename T, int Options = T::storage_order>
class ref : private map<T, Options>
template<typename T, int Options = T::storage_order, typename Stride = stride>
class ref : private map<T, Options, Stride>
{
public:
using parent = map<T, Options>;
using parent = map<T, Options, Stride>;
using value_type = typename T::value_type;
using rotgen_tag = void;
@ -74,18 +74,6 @@ namespace rotgen
static_assert((Ref::Options & 1) == Options, "ref: Incompatible storage layout");
}
// #if !defined(ROTGEN_FORCE_DYNAMIC)
// template<typename OtherDerived>
// ref(const Eigen::MatrixBase<OtherDerived>& b)
// : parent(b.data(), b.rows(), b.cols()/*, stride_type{b.outerStride(),b.innerStride()}*/)
// {
// using Str = typename parent::stride_type;
// std::cerr << "Stride compile-time: Outer=" << Str::OuterStrideAtCompileTime
// << " Inner=" << Str::InnerStrideAtCompileTime <<std::endl;
// std::cerr << "runtime: outer=" << b.outerStride() << " inner=" << b.innerStride() << std::endl;
// }
// #endif
ref(parent& m) : parent(m.data(), m.rows(), m.cols()) {}
friend std::ostream& operator<<(std::ostream& os, ref const& r)
@ -95,11 +83,11 @@ namespace rotgen
};
// Specialization for const matrix type
template<typename T, int Options>
class ref<const T, Options> : private map<const T, Options>
template<typename T, int Options, typename Stride>
class ref<const T, Options,Stride> : private map<const T, Options,Stride>
{
public:
using parent = map<const T, Options>;
using parent = map<const T, Options,Stride>;
using value_type = typename T::value_type;
using rotgen_tag = void;
@ -152,23 +140,8 @@ namespace rotgen
static_assert((Ref::Options & 1) == Options, "ref: Incompatible storage layout");
}
// #if !defined(ROTGEN_FORCE_DYNAMIC)
// template<typename OtherDerived>
// ref(const Eigen::MatrixBase<OtherDerived>& b)
// : parent(b.data(), b.rows(), b.cols())//, stride_type{b.outerStride(),b.innerStride()})
// {
// using Str = typename parent::stride_type;
// std::cerr << "Stride compile-time: Outer=" << Str::OuterStrideAtCompileTime
// << " Inner=" << Str::InnerStrideAtCompileTime <<std::endl;
// std::cerr << "runtime: outer=" << b.outerStride() << " inner=" << b.innerStride() << std::endl;
// }
// #endif
ref(parent const& m) : parent(m.data(), m.rows(), m.cols()) {}
// // From raw const buffer
// ref(value_type const* ptr, int r, int c) : parent(ptr, r, c) {}
friend std::ostream& operator<<(std::ostream& os, ref const& r)
{
return os << r.base() << "\n";
@ -187,50 +160,50 @@ namespace rotgen
template<typename Ref, int R, int C, bool I, int FS>
ref(block<Ref,R,C,I,FS> const& b) -> ref<Ref const>;
template<typename A, typename B>
bool operator==(ref<A> lhs, ref<B> rhs)
template<typename A, int O, typename S, typename B, int P, typename T>
bool operator==(ref<A,O,S> lhs, ref<B,P,T> rhs)
{
return lhs.base() == rhs.base();
}
template<typename A, typename B>
bool operator!=(ref<A> lhs, ref<B> rhs)
template<typename A, int O, typename S, typename B, int P, typename T>
bool operator!=(ref<A,O,S> lhs, ref<B,P,T> rhs)
{
return lhs.base() != rhs.base();
}
template<typename A, typename B>
auto operator+(ref<A> lhs, ref<B> rhs) -> decltype(lhs.base() + rhs.base())
template<typename A, int O, typename S, typename B, int P, typename T>
auto operator+(ref<A,O,S> lhs, ref<B,P,T> rhs) -> decltype(lhs.base() + rhs.base())
{
return lhs.base() + rhs.base();
}
template<typename A, typename B>
auto operator-(ref<A> lhs, ref<B> rhs)
template<typename A, int O, typename S, typename B, int P, typename T>
auto operator-(ref<A,O,S> lhs, ref<B,P,T> rhs)
{
return lhs.base() - rhs.base();
}
template<typename A, typename B>
auto operator*(ref<A> lhs, ref<B> rhs)
template<typename A, int O, typename S, typename B, int P, typename T>
auto operator*(ref<A,O,S> lhs, ref<B,P,T> rhs)
{
return lhs.base() * rhs.base();
}
template<typename A>
auto operator*(ref<A> lhs, std::convertible_to<typename A::value_type> auto s)
template<typename A, int O, typename S>
auto operator*(ref<A,O,S> lhs, std::convertible_to<typename A::value_type> auto s)
{
return lhs.base() * s;
}
template<typename A>
auto operator*(std::convertible_to<typename A::value_type> auto s, ref<A> rhs)
template<typename A, int O, typename S>
auto operator*(std::convertible_to<typename A::value_type> auto s, ref<A,O,S> rhs)
{
return s * rhs.base();
}
template<typename A>
auto operator/(ref<A> lhs, std::convertible_to<typename A::value_type> auto s)
template<typename A, int O, typename S>
auto operator/(ref<A,O,S> lhs, std::convertible_to<typename A::value_type> auto s)
{
return lhs.base() / s;
}

View file

@ -19,9 +19,37 @@ namespace rotgen
#if !defined(ROTGEN_FORCE_DYNAMIC)
using stride = Eigen::Stride<-1,-1>;
#else
struct stride { Index outer, inner; };
struct stride
{
stride() : outer_(-1), inner_(1) {}
stride(Index s, Index i) : outer_(s), inner_(i) {}
Index inner() const { return inner_; }
Index outer() const { return outer_; }
private:
Index outer_;
Index inner_;
};
#endif
template<Index Value= Dynamic>
struct inner_stride : stride
{
inner_stride() : stride(-1,Value) {}
inner_stride(Index v) : stride(0, v) {}
};
template<Index Value = Dynamic>
struct outer_stride : stride
{
outer_stride() : stride(Value,0) {}
outer_stride(Index v) : stride(v,0) {}
};
inner_stride(Index) -> inner_stride<Dynamic>;
outer_stride(Index) -> outer_stride<Dynamic>;
template<int Order>
stride strides(Index r, Index c)
{
@ -29,6 +57,18 @@ namespace rotgen
else return {r,1};
}
template<int Order>
stride strides(stride original)
{
return original;
}
template<int, Index N>
stride strides(outer_stride<N> const& original)
{
return {original.outer(),1};
}
template<concepts::entity E>
auto strides(const E& e)
{

View file

@ -13,7 +13,7 @@
namespace rotgen
{
template<typename Ref, int Options = ColMajor, typename = void>
template<typename Ref, int Options = ColMajor, typename Stride = rotgen::stride>
class map : public find_map<Ref>
{
public:
@ -33,16 +33,16 @@ namespace rotgen
static constexpr bool is_defined_static = false;
using ptr_type = std::conditional_t<is_immutable, value_type const*, value_type*>;
using stride_type = stride;
using stride_type = Stride;
static constexpr Index RowsAtCompileTime = Ref::RowsAtCompileTime;
static constexpr Index ColsAtCompileTime = Ref::ColsAtCompileTime;
map(ptr_type ptr, Index r, Index c, stride_type s) : parent(ptr, r, c,s) {}
map(ptr_type ptr, Index r, Index c) : map(ptr, r, c, strides<storage_order>(r,c)) {}
map(ptr_type ptr, Index r, Index c, stride_type s) : parent(ptr, r, c, strides<storage_order>(s)) {}
map(ptr_type ptr, Index r, Index c) : parent(ptr, r, c, strides<storage_order>(r,c)) {}
map(ptr_type ptr, stride_type s) requires(RowsAtCompileTime!=-1 && ColsAtCompileTime!=-1)
: parent(ptr,RowsAtCompileTime,ColsAtCompileTime, s)
: parent(ptr,RowsAtCompileTime,ColsAtCompileTime, strides<storage_order>(s))
{}
map(ptr_type ptr, Index size) requires(RowsAtCompileTime==1 || ColsAtCompileTime==1)

View file

@ -29,17 +29,9 @@ namespace rotgen
template<typename Ref, int Options, bool isConst>
using map_type = typename compute_map_type<Ref,Options,isConst>::type;
template<typename T> struct map_stride;
template<typename PlainObjectType, int MapOptions, typename Stride>
struct map_stride<Eigen::Map<PlainObjectType, MapOptions, Stride>>
{
using type = Stride;
};
}
template<typename Ref, int Options = ColMajor, typename = void>
template<typename Ref, int Options = ColMajor, typename Stride = stride>
class map : private detail::map_type<std::remove_const_t<Ref>, Options, std::is_const_v<Ref>>
{
public:
@ -60,18 +52,18 @@ namespace rotgen
using as_concrete_type = as_concrete_t<ET, matrix>;
using ptr_type = std::conditional_t<is_immutable, value_type const*, value_type*>;
using stride_type = typename detail::map_stride<parent>::type;
using stride_type = Stride;
map(const map&) = default;
map(map&&) = default;
map& operator=(const map&) = default;
map& operator=(map&&) = default;
map(ptr_type ptr, Index r, Index c, stride_type s) : parent(ptr, r, c, s) {}
map(ptr_type ptr, Index r, Index c) : map(ptr, r, c, strides<storage_order>(r,c)) {}
map(ptr_type ptr, Index r, Index c, stride_type s) : parent(ptr, r, c, strides<storage_order>(s)) {}
map(ptr_type ptr, Index r, Index c) : parent(ptr, r, c, strides<storage_order>(r,c)) {}
map(ptr_type ptr, stride_type s) requires(RowsAtCompileTime!=-1 && ColsAtCompileTime!=-1)
: parent(ptr, s)
: parent(ptr, strides<storage_order>(s))
{}
map(ptr_type ptr, Index sz) requires(RowsAtCompileTime==1 || ColsAtCompileTime==1)