// rotgen/include/rotgen/detail/accept_as_ref.hpp
// Jules Pénuchot e151e136d6 Resolve "[API-#2] Pseudo-privatization of rotgen entity member functions"
// Closes #18
//
// Co-authored-by: Jules Pénuchot <jules@penuchot.com>
//
// See merge request oss/rotgen!50
// 2025-12-17 20:48:00 +01:00
//
// 167 lines
// 5.9 KiB
// C++

//==================================================================================================
/*
ROTGEN - Runtime Overlay for Eigen
Copyright : CODE RECKONS
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once
#include <rotgen/detail/assert.hpp>
#include <rotgen/functions/functions.hpp>
#include <concepts>
namespace rotgen::detail
{
// Compile-time requirements for binding an Input expression to a ref<> type.
// Each sub-check is its own named concept (rather than one big expression) so
// that a failed binding reports which specific property did not match.
// NOTE(review): these rules look like a port of Eigen's internal Ref matching
// traits; keep them in sync with validate_ref when changing either side.

// Ref and Input must have exactly the same value_type (no conversion).
template<typename Ref, typename Input>
concept same_scalar =
std::same_as<typename Ref::value_type, typename Input::value_type>;
// Satisfied when at least one side is a vector at compile time.
template<typename Ref, typename Input>
concept any_is_vector =
Ref::IsVectorAtCompileTime || Input::IsVectorAtCompileTime;
// Storage orders must agree, except when either side is a vector.
template<typename Ref, typename Input>
concept compatible_storage =
any_is_vector<Ref, Input> || (Ref::storage_order == Input::storage_order);
// Inner strides are compatible when Ref's compile-time inner stride is
// Dynamic, equals Input's, or is 0 while Input's is 1 (a 0 inner stride
// resolves to 1 at runtime, see proper_inner_stride).
template<typename Ref, typename Input>
concept compatible_inner_stride =
(Ref::InnerStrideAtCompileTime == rotgen::Dynamic) ||
(Ref::InnerStrideAtCompileTime == Input::InnerStrideAtCompileTime) ||
(Ref::InnerStrideAtCompileTime == 0 &&
Input::InnerStrideAtCompileTime == 1);
// Outer strides only matter when neither side is a vector; Ref's compile-time
// outer stride must then be Dynamic or equal to Input's.
template<typename Ref, typename Input>
concept compatible_outer_stride =
any_is_vector<Ref, Input> ||
(Ref::OuterStrideAtCompileTime == rotgen::Dynamic) ||
(Ref::OuterStrideAtCompileTime == Input::OuterStrideAtCompileTime);
// Everything that must hold at compile time for Input to be passed to ref<>.
// Runtime size and stride checks are done separately by validate_ref.
template<typename Ref, typename Input>
concept accept_as_ref =
same_scalar<Ref, Input> && compatible_storage<Ref, Input> &&
compatible_inner_stride<Ref, Input> && compatible_outer_stride<Ref, Input>;
// Local helpers that turn stride values into concrete runtime strides.

// Resolve an inner stride: a value of 0 is treated as a unit stride (1);
// any other value is returned unchanged.
constexpr Index proper_inner_stride(Index inner)
{
  if (inner == 0) return 1;
  return inner;
}
// Resolve an outer stride. A non-zero `outer` is returned unchanged. A value
// of 0 is replaced by the stride implied by the shape and the inner stride:
//   - compile-time vector : inner * rows * cols
//   - row-major matrix    : inner * cols
//   - column-major matrix : inner * rows
constexpr Index proper_outer_stride(Index inner,
                                    Index outer,
                                    Index rows,
                                    Index cols,
                                    bool isVectorAtCompileTime,
                                    bool isRowMajor)
{
  if (outer != 0) return outer;
  if (isVectorAtCompileTime) return inner * rows * cols;
  return inner * (isRowMajor ? cols : rows);
}
// Runtime counterpart of accept_as_ref: decide whether `ref` can view the
// memory of `in` directly and, if so, rebind it to that memory.
//
// It derives the rows/cols `ref` should use (a compile-time vector Ref accepts
// any vector input, possibly transposed). It then translates the actual
// inner/outer strides of `in` into Ref's storage order and checks them against
// the compile-time strides of Ref::stride_type.
//
// Returns false, without modifying `ref`, when those strides cannot be
// expressed by Ref::stride_type; this tells ref that a local copy and
// reconstruction are needed. Returns true after swapping a map over
// `in.data()` into the underlying storage of `ref`.
// Size mismatches are reported through ROTGEN_ASSERT, not the return value.
//
// NOTE(review): the logic appears to be a port of Eigen's RefBase::construct
// (Eigen/src/Core/Ref.h); presumably callers only pass inputs that already
// satisfy accept_as_ref — confirm at the call sites.
template<typename Ref, typename Input> bool validate_ref(Ref& ref, Input& in)
{
using stride_type = typename Ref::stride_type;
using parent = typename Ref::parent;
// Grandparent of Ref: the type used at the end to build the rebinding view.
using map_base = typename parent::parent;
// Start from the input's runtime shape.
auto r = rows(in);
auto c = cols(in);
// A compile-time row (resp. column) vector Ref accepts any vector input and
// lays it out as 1 x size (resp. size x 1), i.e. it may implicitly transpose.
if (Ref::RowsAtCompileTime == 1)
{
ROTGEN_ASSERT(rows(in) == 1 || cols(in) == 1,
"Incompatible rows/cols in ref binding");
r = 1;
c = in.size();
}
else if (Ref::ColsAtCompileTime == 1)
{
ROTGEN_ASSERT(rows(in) == 1 || cols(in) == 1,
"Incompatible rows/cols in ref binding");
r = in.size();
c = 1;
}
// Verify that the sizes are valid.
ROTGEN_ASSERT((Ref::RowsAtCompileTime == Dynamic) ||
(Ref::RowsAtCompileTime == r),
"Incompatible static rows/cols in ref binding");
ROTGEN_ASSERT((Ref::ColsAtCompileTime == Dynamic) ||
(Ref::ColsAtCompileTime == c),
"Incompatible static rows/cols in ref binding");
// If Ref is a vector and the chosen row count differs from the input's, the
// input is being viewed transposed, so its inner/outer strides trade places.
bool transpose = Ref::IsVectorAtCompileTime && (r != rows(in));
// A storage-order mismatch between Ref and Input also swaps the strides.
constexpr bool row_major = Ref::IsRowMajor;
constexpr bool input_row_major = Input::IsRowMajor;
constexpr bool storage_differs = (row_major != input_row_major);
// Two swaps cancel out, hence the exclusive-or (!= on bools).
bool swap_stride = (transpose != storage_differs);
// Determine the input's actual strides, resolving zero (default) values with
// the proper_*_stride helpers.
Index inner_actual = proper_inner_stride(innerStride(in));
Index outer_actual =
proper_outer_stride(inner_actual, outerStride(in), rows(in), cols(in),
Input::IsVectorAtCompileTime, input_row_major);
bool row_vector = (r == 1);
bool col_vector = (c == 1);
// When Ref's inner dimension has extent 1 (row vector stored column-major, or
// column vector stored row-major), the inner stride is never stepped. Use
// Ref's compile-time inner stride if positive, else 1. Otherwise take the
// input's (possibly swapped) stride.
Index inner_stride =
((!row_major && row_vector) || (row_major && col_vector))
? (stride_type::InnerStrideAtCompileTime > 0
? stride_type::InnerStrideAtCompileTime
: 1)
: swap_stride ? outer_actual
: inner_actual;
// Symmetrically, when Ref's outer dimension has extent 1 the outer stride is
// arbitrary. Use the compile-time outer stride if positive, else the full
// extent r * c * inner_stride. Otherwise take the input's (possibly swapped)
// stride.
Index outer_stride =
((!row_major && col_vector) || (row_major && row_vector))
? (stride_type::OuterStrideAtCompileTime > 0
? stride_type::OuterStrideAtCompileTime
: r * c * inner_stride)
: swap_stride ? inner_actual
: outer_actual;
// Reject strides that Ref's compile-time stride type cannot express: Dynamic
// components accept any value, fixed ones must match once 0 is resolved.
bool inner_valid =
(stride_type::InnerStrideAtCompileTime == Dynamic) ||
(proper_inner_stride(Index(stride_type::InnerStrideAtCompileTime)) ==
inner_stride);
if (!inner_valid) return false;
bool outer_valid =
(stride_type::OuterStrideAtCompileTime == Dynamic) ||
(proper_outer_stride(
inner_stride, Index(stride_type::OuterStrideAtCompileTime), r, c,
Ref::IsVectorAtCompileTime != 0, row_major) == outer_stride);
if (!outer_valid) return false;
// Runtime stride handed to map_base; argument order appears to be
// (outer, inner).
// NOTE(review): a compile-time stride component of 0 is passed as 1 here,
// whereas Eigen passes 0 in the equivalent spot. For the inner component this
// is the same value (inner_stride was validated to equal 1), but the validated
// outer_stride is generally not 1. Confirm that map_base ignores the runtime
// outer stride when Ref::stride_type fixes it to 0.
dynamic_stride proper_stride(
stride_type::OuterStrideAtCompileTime == 0 ? 1 : outer_stride,
stride_type::InnerStrideAtCompileTime == 0 ? 1 : inner_stride);
// Build a view of in.data() with the computed shape and strides, then swap its
// storage into ref. `actual` ends up holding ref's previous storage and is
// destroyed on return.
auto actual = map_base(in.data(), r, c, proper_stride);
ref.base().base().storage().swap(actual.storage());
return true;
}
}