diff --git a/include/rotgen/common/ref.hpp b/include/rotgen/common/ref.hpp index 34ea890..656acf3 100644 --- a/include/rotgen/common/ref.hpp +++ b/include/rotgen/common/ref.hpp @@ -17,11 +17,11 @@ namespace rotgen { // Primary template: mutable ref - template - class ref : private map + template + class ref : private map { public: - using parent = map; + using parent = map; using value_type = typename T::value_type; using rotgen_tag = void; @@ -74,18 +74,6 @@ namespace rotgen static_assert((Ref::Options & 1) == Options, "ref: Incompatible storage layout"); } -// #if !defined(ROTGEN_FORCE_DYNAMIC) -// template -// ref(const Eigen::MatrixBase& b) -// : parent(b.data(), b.rows(), b.cols()/*, stride_type{b.outerStride(),b.innerStride()}*/) -// { -// using Str = typename parent::stride_type; -// std::cerr << "Stride compile-time: Outer=" << Str::OuterStrideAtCompileTime -// << " Inner=" << Str::InnerStrideAtCompileTime < -// ref(const Eigen::MatrixBase& b) -// : parent(b.data(), b.rows(), b.cols())//, stride_type{b.outerStride(),b.innerStride()}) -// { -// using Str = typename parent::stride_type; -// std::cerr << "Stride compile-time: Outer=" << Str::OuterStrideAtCompileTime -// << " Inner=" << Str::InnerStrideAtCompileTime < ref(block const& b) -> ref; - template - bool operator==(ref lhs, ref rhs) + template + bool operator==(ref lhs, ref rhs) { return lhs.base() == rhs.base(); } - template - bool operator!=(ref lhs, ref rhs) + template + bool operator!=(ref lhs, ref rhs) { return lhs.base() != rhs.base(); } - template - auto operator+(ref lhs, ref rhs) -> decltype(lhs.base() + rhs.base()) + template + auto operator+(ref lhs, ref rhs) -> decltype(lhs.base() + rhs.base()) { return lhs.base() + rhs.base(); } - template - auto operator-(ref lhs, ref rhs) + template + auto operator-(ref lhs, ref rhs) { return lhs.base() - rhs.base(); } - template - auto operator*(ref lhs, ref rhs) + template + auto operator*(ref lhs, ref rhs) { return lhs.base() * rhs.base(); } - template - auto operator*(ref lhs, std::convertible_to auto s) + template + auto operator*(ref lhs, std::convertible_to auto s) { return lhs.base() * s; } - template - auto operator*(std::convertible_to auto s, ref rhs) + template + auto operator*(std::convertible_to auto s, ref rhs) { return s * rhs.base(); } - template - auto operator/(ref lhs, std::convertible_to auto s) + template + auto operator/(ref lhs, std::convertible_to auto s) { return lhs.base() / s; } diff --git a/include/rotgen/common/strides.hpp b/include/rotgen/common/strides.hpp index 9a24878..381548d 100644 --- a/include/rotgen/common/strides.hpp +++ b/include/rotgen/common/strides.hpp @@ -19,9 +19,37 @@ namespace rotgen #if !defined(ROTGEN_FORCE_DYNAMIC) using stride = Eigen::Stride<-1,-1>; #else - struct stride { Index outer, inner; }; + struct stride + { + stride() : outer_(-1), inner_(1) {} + stride(Index s, Index i) : outer_(s), inner_(i) {} + + Index inner() const { return inner_; } + Index outer() const { return outer_; } + + private: + Index outer_; + Index inner_; + }; #endif + template + struct inner_stride : stride + { + inner_stride() : stride(-1,Value) {} + inner_stride(Index v) : stride(0, v) {} + }; + + template + struct outer_stride : stride + { + outer_stride() : stride(Value,0) {} + outer_stride(Index v) : stride(v,0) {} + }; + + inner_stride(Index) -> inner_stride; + outer_stride(Index) -> outer_stride; + template stride strides(Index r, Index c) { @@ -29,6 +57,18 @@ namespace rotgen else return {r,1}; } + template + stride strides(stride original) + { + return original; + } + + template + stride strides(outer_stride const& original) + { + return {original.outer(),1}; + } + template auto strides(const E& e) { diff --git a/include/rotgen/dynamic/map.hpp b/include/rotgen/dynamic/map.hpp index 3968db2..6f2818d 100644 --- a/include/rotgen/dynamic/map.hpp +++ b/include/rotgen/dynamic/map.hpp @@ -13,7 +13,7 @@ namespace rotgen { - template + template class map : public find_map { public: @@ -33,16 +33,16 @@ namespace rotgen static constexpr bool is_defined_static = false; using ptr_type = std::conditional_t; - using stride_type = stride; + using stride_type = Stride; static constexpr Index RowsAtCompileTime = Ref::RowsAtCompileTime; static constexpr Index ColsAtCompileTime = Ref::ColsAtCompileTime; - map(ptr_type ptr, Index r, Index c, stride_type s) : parent(ptr, r, c,s) {} - map(ptr_type ptr, Index r, Index c) : map(ptr, r, c, strides(r,c)) {} + map(ptr_type ptr, Index r, Index c, stride_type s) : parent(ptr, r, c, strides(s)) {} + map(ptr_type ptr, Index r, Index c) : parent(ptr, r, c, strides(r,c)) {} map(ptr_type ptr, stride_type s) requires(RowsAtCompileTime!=-1 && ColsAtCompileTime!=-1) - : parent(ptr,RowsAtCompileTime,ColsAtCompileTime, s) + : parent(ptr,RowsAtCompileTime,ColsAtCompileTime, strides(s)) {} map(ptr_type ptr, Index size) requires(RowsAtCompileTime==1 || ColsAtCompileTime==1) diff --git a/include/rotgen/fixed/map.hpp b/include/rotgen/fixed/map.hpp index 87b1d63..0f2e8fd 100644 --- a/include/rotgen/fixed/map.hpp +++ b/include/rotgen/fixed/map.hpp @@ -29,17 +29,9 @@ namespace rotgen template using map_type = typename compute_map_type::type; - - template struct map_stride; - - template - struct map_stride> - { - using type = Stride; - }; } - template + template class map : private detail::map_type, Options, std::is_const_v> { public: @@ -60,18 +52,18 @@ namespace rotgen using as_concrete_type = as_concrete_t; using ptr_type = std::conditional_t; - using stride_type = typename detail::map_stride::type; + using stride_type = Stride; map(const map&) = default; map(map&&) = default; map& operator=(const map&) = default; map& operator=(map&&) = default; - map(ptr_type ptr, Index r, Index c, stride_type s) : parent(ptr, r, c, s) {} - map(ptr_type ptr, Index r, Index c) : map(ptr, r, c, strides(r,c)) {} + map(ptr_type ptr, Index r, Index c, stride_type s) : parent(ptr, r, c, strides(s)) {} + map(ptr_type ptr, Index r, Index c) : parent(ptr, r, c, strides(r,c)) {} map(ptr_type ptr, stride_type s) requires(RowsAtCompileTime!=-1 && ColsAtCompileTime!=-1) - : parent(ptr, s) + : parent(ptr, strides(s)) {} map(ptr_type ptr, Index sz) requires(RowsAtCompileTime==1 || ColsAtCompileTime==1) diff --git a/src/map_model.cpp b/src/map_model.cpp index 0e9418a..2f4d834 100644 --- a/src/map_model.cpp +++ b/src/map_model.cpp @@ -23,7 +23,7 @@ {} CLASSNAME::CLASSNAME(TYPE CONST* ptr, Index r, Index c, stride s) - : storage_(std::make_unique(ptr,r,c,payload::stride_type{s.outer,s.inner})) + : storage_(std::make_unique(ptr,r,c,payload::stride_type{s.outer(),s.inner()})) {} CLASSNAME::CLASSNAME(CLASSNAME const& o) : storage_(std::make_unique(o.storage_->data)) diff --git a/test/unit/map/strides.cpp b/test/unit/map/strides.cpp new file mode 100644 index 0000000..c6900c5 --- /dev/null +++ b/test/unit/map/strides.cpp @@ -0,0 +1,146 @@ +//================================================================================================== +/* + ROTGEN - Runtime Overlay for Eigen + Copyright : CODE RECKONS + SPDX-License-Identifier: BSL-1.0 +*/ +//================================================================================================== +#include "unit/tests.hpp" +#include +#include +#include + +auto generate_data(int rows, int cols) +{ + std::vector buffer(rows * cols * 3); + + for (size_t i = 0; i < buffer.size(); ++i) + buffer[i] = static_cast(1+i); + + return buffer; +} + +template +using r_mat_t = rotgen::matrix; + +template +using e_mat_t = Eigen::Matrix; + +template using r_map_t = rotgen::map; +template> using e_map_t = Eigen::Map; + +TTS_CASE("Validate Column Major Map with regular stride behavior") +{ + auto rows = 3; + auto cols = 4; + auto buffer = generate_data(rows,cols); + + r_map_t> r_map(buffer.data(), rows, cols); + + TTS_EQUAL(r_map.innerStride(), 1); + TTS_EQUAL(r_map.outerStride(), 3); + + e_map_t> e_map(buffer.data(), rows, cols); + + for(std::ptrdiff_t r=0;r, rotgen::outer_stride<>> + r_map(buffer.data(), rows, cols, rotgen::outer_stride(rows + 1)); + + TTS_EQUAL(r_map.innerStride(), 1); + TTS_EQUAL(r_map.outerStride() , 4); + + e_map_t, Eigen::OuterStride<>> + e_map(buffer.data(), rows, cols, Eigen::OuterStride<>(rows + 1)); + + for(std::ptrdiff_t r=0;r> + r_map(buffer.data(), rows, cols, rotgen::stride(rows, 2)); + + TTS_EQUAL(r_map.innerStride(), 2); + TTS_EQUAL(r_map.outerStride() , 3); + + e_map_t,Eigen::Stride> + e_map(buffer.data(), rows, cols, Eigen::Stride(rows,2)); + + for(std::ptrdiff_t r=0;r> r_map(buffer.data(), rows, cols); + + TTS_EQUAL(r_map.innerStride(), 1); + TTS_EQUAL(r_map.outerStride(), 4); + + e_map_t> e_map(buffer.data(), rows, cols); + + for(std::ptrdiff_t r=0;r, rotgen::outer_stride<>> + r_map(buffer.data(), rows, cols, rotgen::outer_stride(cols + 1)); + + TTS_EQUAL(r_map.innerStride(), 1); + TTS_EQUAL(r_map.outerStride() , 5); + + e_map_t, Eigen::OuterStride<>> + e_map(buffer.data(), rows, cols, Eigen::OuterStride<>(cols + 1)); + + for(std::ptrdiff_t r=0;r,rotgen::stride> + r_map(buffer.data(), rows, cols, rotgen::stride(2, cols)); + + TTS_EQUAL(r_map.innerStride(), 4); + TTS_EQUAL(r_map.outerStride() , 2); + + e_map_t,Eigen::Stride> + e_map(buffer.data(), rows, cols, Eigen::Stride(2,cols)); + + for(std::ptrdiff_t r=0;r