Fix some ref/map/block non-trivial interactions.

This commit is contained in:
Joel Falcou 2025-09-17 09:51:46 +02:00
parent f285251a52
commit 399f17af57
7 changed files with 127 additions and 22 deletions

View file

@ -139,22 +139,21 @@ namespace rotgen
template<std::same_as<value_type> S, int R, int C, int O, int MR, int MC> template<std::same_as<value_type> S, int R, int C, int O, int MR, int MC>
ref(matrix<S, R, C, O, MR, MC> const& m) ref(matrix<S, R, C, O, MR, MC> const& m)
requires((O & 1) == storage_order)
: parent(m.data(), m.rows(), m.cols(), strides(m)) : parent(m.data(), m.rows(), m.cols(), strides(m))
{ {}
static_assert((O & 1) == storage_order, "ref: Incompatible storage layout");
}
template<typename Ref, int R, int C, bool I> template<typename Ref, int R, int C, bool I>
ref(block<Ref,R,C,I> const& b) : parent(b.data(), b.rows(), b.cols(), stride_type{b.outerStride(),b.innerStride()}) ref ( block<Ref,R,C,I> const& b )
{ requires(std::same_as<value_type, typename Ref::value_type> && (Ref::storage_order & 1) == storage_order)
static_assert((Ref::storage_order & 1) == storage_order, "ref: Incompatible storage layout"); : parent(b.data(), b.rows(), b.cols(), stride_type{b.outerStride(),b.innerStride()})
} {}
template<typename Ref, int O, typename S> template<typename Ref, int O, typename S>
ref(map<Ref,O,S> const& b) : parent(b.data(), b.rows(), b.cols(), stride_type{b.outerStride(),b.innerStride()}) ref ( map<Ref,O,S> const& b )
{ requires(std::same_as<value_type, typename Ref::value_type> && (Ref::storage_order & 1) == storage_order)
static_assert((Ref::storage_order & 1) == storage_order, "ref: Incompatible storage layout"); : parent(b.data(), b.rows(), b.cols(), stride_type{b.outerStride(),b.innerStride()})
} {}
ref(parent const& m) : parent(m.data(), m.rows(), m.cols()) {} ref(parent const& m) : parent(m.data(), m.rows(), m.cols()) {}

View file

@ -20,6 +20,8 @@ namespace rotgen
#else #else
struct stride struct stride
{ {
static constexpr bool is_dynamic = true;
stride() : outer_(-1), inner_(1) {} stride() : outer_(-1), inner_(1) {}
stride(Index s, Index i) : outer_(s), inner_(i) {} stride(Index s, Index i) : outer_(s), inner_(i) {}
@ -35,6 +37,7 @@ namespace rotgen
template<Index Value= Dynamic> template<Index Value= Dynamic>
struct inner_stride : stride struct inner_stride : stride
{ {
static constexpr bool is_dynamic = Value == Dynamic;
inner_stride() : stride(-1,Value) {} inner_stride() : stride(-1,Value) {}
inner_stride(Index v) : stride(0, v) {} inner_stride(Index v) : stride(0, v) {}
}; };
@ -42,6 +45,7 @@ namespace rotgen
template<Index Value = Dynamic> template<Index Value = Dynamic>
struct outer_stride : stride struct outer_stride : stride
{ {
static constexpr bool is_dynamic = Value == Dynamic;
outer_stride() : stride(Value,0) {} outer_stride() : stride(Value,0) {}
outer_stride(Index v) : stride(v,0) {} outer_stride(Index v) : stride(v,0) {}
}; };
@ -57,15 +61,16 @@ namespace rotgen
} }
template<int Order> template<int Order>
stride strides(stride original) stride strides(stride const& original,Index, Index)
{ {
return original; return original;
} }
template<int, Index N> template<int Order, Index N>
stride strides(outer_stride<N> const& original) stride strides(outer_stride<N> const& original,Index r, Index c)
{ {
return {original.outer(),1}; if constexpr(N==0) return stride{ Order==ColMajor ? r : c, 1};
else return {original.outer(),1};
} }
template<concepts::entity E> template<concepts::entity E>

View file

@ -40,8 +40,16 @@ namespace rotgen
static constexpr Index RowsAtCompileTime = Ref::RowsAtCompileTime; static constexpr Index RowsAtCompileTime = Ref::RowsAtCompileTime;
static constexpr Index ColsAtCompileTime = Ref::ColsAtCompileTime; static constexpr Index ColsAtCompileTime = Ref::ColsAtCompileTime;
map(ptr_type ptr, Index r, Index c, stride_type s) : parent(ptr, r, c, strides<storage_order>(s)) {} map(ptr_type ptr, Index r, Index c, stride_type s) : parent(ptr, r, c, strides<storage_order>(s,r,c)) {}
map(ptr_type ptr, Index r, Index c) : parent(ptr, r, c, strides<storage_order>(r,c)) {} map(ptr_type ptr, Index r, Index c)
: parent( ptr, r, c
, [&]()
{
if constexpr(!std::same_as<Stride,stride>) return strides<storage_order>(Stride{},r,c);
else return strides<storage_order>(r,c);
}()
)
{}
map(ptr_type ptr, stride_type s) requires(RowsAtCompileTime!=-1 && ColsAtCompileTime!=-1) map(ptr_type ptr, stride_type s) requires(RowsAtCompileTime!=-1 && ColsAtCompileTime!=-1)
: parent(ptr,RowsAtCompileTime,ColsAtCompileTime, strides<storage_order>(s)) : parent(ptr,RowsAtCompileTime,ColsAtCompileTime, strides<storage_order>(s))

View file

@ -58,11 +58,19 @@ namespace rotgen
map& operator=(const map&) = default; map& operator=(const map&) = default;
map& operator=(map&&) = default; map& operator=(map&&) = default;
map(ptr_type ptr, Index r, Index c, stride_type s) : parent(ptr, r, c, strides<storage_order>(s)) {} map(ptr_type ptr, Index r, Index c, stride_type s) : parent(ptr, r, c, strides<storage_order>(s,r,c)) {}
map(ptr_type ptr, Index r, Index c) : parent(ptr, r, c, strides<storage_order>(r,c)) {} map(ptr_type ptr, Index r, Index c)
: parent( ptr, r, c
, [&]()
{
if constexpr(!std::same_as<Stride,stride>) return strides<storage_order>(Stride{},r,c);
else return strides<storage_order>(r,c);
}()
)
{}
map(ptr_type ptr, stride_type s) requires(RowsAtCompileTime!=-1 && ColsAtCompileTime!=-1) map(ptr_type ptr, stride_type s) requires(RowsAtCompileTime!=-1 && ColsAtCompileTime!=-1)
: parent(ptr, strides<storage_order>(s)) : parent(ptr, strides<storage_order>(s,RowsAtCompileTime,ColsAtCompileTime))
{} {}
map(ptr_type ptr, Index sz) requires(RowsAtCompileTime==1 || ColsAtCompileTime==1) map(ptr_type ptr, Index sz) requires(RowsAtCompileTime==1 || ColsAtCompileTime==1)

View file

@ -93,6 +93,7 @@ class ROTGEN_EXPORT CLASSNAME
const TYPE* data() const; const TYPE* data() const;
#if !defined(USE_CONST) #if !defined(USE_CONST)
TYPE* data();
void setZero(); void setZero();
void setOnes(); void setOnes();
void setRandom(); void setRandom();

View file

@ -51,7 +51,10 @@
rotgen::Index CLASSNAME::innerStride() const { return storage_->data.innerStride(); } rotgen::Index CLASSNAME::innerStride() const { return storage_->data.innerStride(); }
rotgen::Index CLASSNAME::outerStride() const { return storage_->data.outerStride(); } rotgen::Index CLASSNAME::outerStride() const { return storage_->data.outerStride(); }
const TYPE* CLASSNAME::data() const { return storage_->data.data(); }
#if !defined(USE_CONST) #if !defined(USE_CONST)
TYPE* CLASSNAME::data() { return storage_->data.data(); }
TYPE& CLASSNAME::operator()(Index i, Index j) { return storage_->data(i,j); } TYPE& CLASSNAME::operator()(Index i, Index j) { return storage_->data(i,j); }
TYPE& CLASSNAME::operator()(Index i) { return storage_->data.data()[i]; } TYPE& CLASSNAME::operator()(Index i) { return storage_->data.data()[i]; }
#endif #endif
@ -59,8 +62,6 @@
TYPE CLASSNAME::operator()(Index i, Index j) const { return storage_->data(i,j); } TYPE CLASSNAME::operator()(Index i, Index j) const { return storage_->data(i,j); }
TYPE CLASSNAME::operator()(Index i) const { return storage_->data.data()[i]; } TYPE CLASSNAME::operator()(Index i) const { return storage_->data.data()[i]; }
const TYPE* CLASSNAME::data() const { return storage_->data.data(); }
SOURCENAME CLASSNAME::transpose() const SOURCENAME CLASSNAME::transpose() const
{ {
SOURCENAME result; SOURCENAME result;

View file

@ -0,0 +1,83 @@
//==================================================================================================
/*
ROTGEN - Runtime Overlay for Eigen
Copyright : CODE RECKONS
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#include "unit/tests.hpp"
#include <rotgen/rotgen.hpp>
#include <iostream>
TTS_CASE_TPL("outer_stride<0> interactions", rotgen::tests::types)
<typename T, typename O>( tts::type< tts::types<T,O>> )
{
using mat_t = rotgen::matrix<T, rotgen::Dynamic, rotgen::Dynamic, O::value>;
T contiguous[] = {1,2,3, 4,5,6, 7,8,9, 10,11,12};
rotgen::map<mat_t, 0, rotgen::outer_stride<0>> m(&contiguous[0], 4, 3);
TTS_EQUAL(m.innerStride(), 1);
TTS_EQUAL(m.outerStride(), O::value == rotgen::ColMajor ? 4 : 3);
if constexpr(O::value == rotgen::ColMajor)
{
T padded[] = {1,2,3,4, 99, 5,6,7,8, 99,9,10,11,12};
rotgen::map<mat_t, 0, rotgen::outer_stride<5>> sp(&padded[0], 4, 3);
TTS_EQUAL(sp.innerStride(), 1);
TTS_EQUAL(sp.outerStride(), 5);
rotgen::map<mat_t, 0, rotgen::outer_stride<>> dp(&padded[0], 4, 3,rotgen::outer_stride(5));
TTS_EQUAL(dp.innerStride(), 1);
TTS_EQUAL(dp.outerStride(), 5);
TTS_EQUAL(m , sp);
TTS_EQUAL(m , dp);
TTS_EQUAL(dp, sp);
}
else
{
T padded[] = {1,2,3, 99, 4,5,6, 99, 7,8,9, 99, 10,11,12};
rotgen::map<mat_t, 0, rotgen::outer_stride<4>> sp(&padded[0], 4, 3);
TTS_EQUAL(sp.innerStride(), 1);
TTS_EQUAL(sp.outerStride(), 4);
rotgen::map<mat_t, 0, rotgen::outer_stride<>> dp(&padded[0], 4, 3,rotgen::outer_stride(4));
TTS_EQUAL(dp.innerStride(), 1);
TTS_EQUAL(dp.outerStride(), 4);
TTS_EQUAL(m , sp);
TTS_EQUAL(m , dp);
TTS_EQUAL(dp, sp);
}
};
void process_ref(rotgen::ref<const rotgen::matrix<float>> ) {}
void process_ref(rotgen::ref<const rotgen::matrix<double>>) {}
void process_ref(rotgen::ref<const rotgen::matrix<float , rotgen::Dynamic, rotgen::Dynamic, rotgen::RowMajor>>) {}
void process_ref(rotgen::ref<const rotgen::matrix<double, rotgen::Dynamic, rotgen::Dynamic, rotgen::RowMajor>>) {}
TTS_CASE_TPL("Extraction of outer_stride<?> blocks", rotgen::tests::types)
<typename T, typename O>( tts::type< tts::types<T,O>> )
{
using mat_t = rotgen::matrix<T, rotgen::Dynamic, rotgen::Dynamic, O::value>;
if constexpr(O::value == rotgen::ColMajor)
{
T padded[] = {1,2,3,4, 99, 5,6,7,8, 99,9,10,11,12};
rotgen::map<mat_t, 0, rotgen::outer_stride<5>> sp(&padded[0], 4, 3);
rotgen::map<mat_t, 0, rotgen::outer_stride<>> dp(&padded[0], 4, 3,rotgen::outer_stride(5));
TTS_EXPECT_COMPILES(sp, { process_ref(extract(sp,0, 0, 3, 2)); } );
TTS_EXPECT_COMPILES(dp, { process_ref(extract(dp,0, 0, 3, 2)); } );
}
else
{
T padded[] = {1,2,3, 99, 4,5,6, 99, 7,8,9, 99, 10,11,12};
rotgen::map<mat_t, 0, rotgen::outer_stride<4>> sp(&padded[0], 4, 3);
rotgen::map<mat_t, 0, rotgen::outer_stride<>> dp(&padded[0], 4, 3,rotgen::outer_stride(4));
TTS_EXPECT_COMPILES(sp, { process_ref(extract(sp,0, 0, 3, 2)); } );
TTS_EXPECT_COMPILES(dp, { process_ref(extract(dp,0, 0, 3, 2)); } );
}
};