Closes #18 Co-authored-by: Jules Pénuchot <jules@penuchot.com> See merge request oss/rotgen!50
167 lines
5.9 KiB
C++
167 lines
5.9 KiB
C++
//==================================================================================================
|
|
/*
|
|
ROTGEN - Runtime Overlay for Eigen
|
|
Copyright : CODE RECKONS
|
|
SPDX-License-Identifier: BSL-1.0
|
|
*/
|
|
//==================================================================================================
|
|
#pragma once
|
|
|
|
#include <rotgen/detail/assert.hpp>
|
|
|
|
#include <rotgen/functions/functions.hpp>
|
|
|
|
#include <concepts>
|
|
|
|
namespace rotgen::detail
|
|
{
|
|
// We split each sub-check into its own concept to improve
|
|
// diagnostic messages
|
|
template<typename Ref, typename Input>
|
|
concept same_scalar =
|
|
std::same_as<typename Ref::value_type, typename Input::value_type>;
|
|
|
|
template<typename Ref, typename Input>
|
|
concept any_is_vector =
|
|
Ref::IsVectorAtCompileTime || Input::IsVectorAtCompileTime;
|
|
|
|
template<typename Ref, typename Input>
|
|
concept compatible_storage =
|
|
any_is_vector<Ref, Input> || (Ref::storage_order == Input::storage_order);
|
|
|
|
template<typename Ref, typename Input>
|
|
concept compatible_inner_stride =
|
|
(Ref::InnerStrideAtCompileTime == rotgen::Dynamic) ||
|
|
(Ref::InnerStrideAtCompileTime == Input::InnerStrideAtCompileTime) ||
|
|
(Ref::InnerStrideAtCompileTime == 0 &&
|
|
Input::InnerStrideAtCompileTime == 1);
|
|
|
|
template<typename Ref, typename Input>
|
|
concept compatible_outer_stride =
|
|
any_is_vector<Ref, Input> ||
|
|
(Ref::OuterStrideAtCompileTime == rotgen::Dynamic) ||
|
|
(Ref::OuterStrideAtCompileTime == Input::OuterStrideAtCompileTime);
|
|
|
|
// Check what we can actually pass to ref<>
|
|
template<typename Ref, typename Input>
|
|
concept accept_as_ref =
|
|
same_scalar<Ref, Input> && compatible_storage<Ref, Input> &&
|
|
compatible_inner_stride<Ref, Input> && compatible_outer_stride<Ref, Input>;
|
|
|
|
// Local helpersto compute some strides related runtime values
|
|
constexpr Index proper_inner_stride(Index inner)
|
|
{
|
|
return inner == 0 ? 1 : inner;
|
|
}
|
|
|
|
constexpr Index proper_outer_stride(Index inner,
|
|
Index outer,
|
|
Index rows,
|
|
Index cols,
|
|
bool isVectorAtCompileTime,
|
|
bool isRowMajor)
|
|
{
|
|
return outer == 0 ? isVectorAtCompileTime ? inner * rows * cols
|
|
: isRowMajor ? inner * cols
|
|
: inner * rows
|
|
: outer;
|
|
}
|
|
|
|
// Helper used to runtime validate some properties of ref adn tells ref
|
|
// if a local copy and reconstruction needs to be done
|
|
template<typename Ref, typename Input> bool validate_ref(Ref& ref, Input& in)
|
|
{
|
|
using stride_type = typename Ref::stride_type;
|
|
using parent = typename Ref::parent;
|
|
using map_base = typename parent::parent;
|
|
|
|
auto r = rows(in);
|
|
auto c = cols(in);
|
|
|
|
if (Ref::RowsAtCompileTime == 1)
|
|
{
|
|
ROTGEN_ASSERT(rows(in) == 1 || cols(in) == 1,
|
|
"Incompatible rows/cols in ref binding");
|
|
r = 1;
|
|
c = in.size();
|
|
}
|
|
else if (Ref::ColsAtCompileTime == 1)
|
|
{
|
|
ROTGEN_ASSERT(rows(in) == 1 || cols(in) == 1,
|
|
"Incompatible rows/cols in ref binding");
|
|
r = in.size();
|
|
c = 1;
|
|
}
|
|
|
|
// Verify that the sizes are valid.
|
|
ROTGEN_ASSERT((Ref::RowsAtCompileTime == Dynamic) ||
|
|
(Ref::RowsAtCompileTime == r),
|
|
"Incompatible static rows/cols in ref binding");
|
|
ROTGEN_ASSERT((Ref::ColsAtCompileTime == Dynamic) ||
|
|
(Ref::ColsAtCompileTime == c),
|
|
"Incompatible static rows/cols in ref binding");
|
|
|
|
// Swap stride if we are a vector and we changed rows as such
|
|
bool transpose = Ref::IsVectorAtCompileTime && (r != rows(in));
|
|
|
|
// Swap stride if storage ordder doesn't match
|
|
constexpr bool row_major = Ref::IsRowMajor;
|
|
constexpr bool input_row_major = Input::IsRowMajor;
|
|
constexpr bool storage_differs = (row_major != input_row_major);
|
|
|
|
bool swap_stride = (transpose != storage_differs);
|
|
|
|
// Determine expr's actual strides, resolving any defaults if zero.
|
|
Index inner_actual = proper_inner_stride(innerStride(in));
|
|
Index outer_actual =
|
|
proper_outer_stride(inner_actual, outerStride(in), rows(in), cols(in),
|
|
Input::IsVectorAtCompileTime, input_row_major);
|
|
|
|
bool row_vector = (r == 1);
|
|
bool col_vector = (c == 1);
|
|
|
|
// Adapt inner stride based on row/col vector status
|
|
Index inner_stride =
|
|
((!row_major && row_vector) || (row_major && col_vector))
|
|
? (stride_type::InnerStrideAtCompileTime > 0
|
|
? stride_type::InnerStrideAtCompileTime
|
|
: 1)
|
|
: swap_stride ? outer_actual
|
|
: inner_actual;
|
|
|
|
// Adapt outer stride based on row/col vector status
|
|
Index outer_stride =
|
|
((!row_major && col_vector) || (row_major && row_vector))
|
|
? (stride_type::OuterStrideAtCompileTime > 0
|
|
? stride_type::OuterStrideAtCompileTime
|
|
: r * c * inner_stride)
|
|
: swap_stride ? inner_actual
|
|
: outer_actual;
|
|
|
|
// Validate compatibility of strides with compile-time strides
|
|
bool inner_valid =
|
|
(stride_type::InnerStrideAtCompileTime == Dynamic) ||
|
|
(proper_inner_stride(Index(stride_type::InnerStrideAtCompileTime)) ==
|
|
inner_stride);
|
|
|
|
if (!inner_valid) return false;
|
|
|
|
bool outer_valid =
|
|
(stride_type::OuterStrideAtCompileTime == Dynamic) ||
|
|
(proper_outer_stride(
|
|
inner_stride, Index(stride_type::OuterStrideAtCompileTime), r, c,
|
|
Ref::IsVectorAtCompileTime != 0, row_major) == outer_stride);
|
|
|
|
if (!outer_valid) return false;
|
|
|
|
dynamic_stride proper_stride(
|
|
stride_type::OuterStrideAtCompileTime == 0 ? 1 : outer_stride,
|
|
stride_type::InnerStrideAtCompileTime == 0 ? 1 : inner_stride);
|
|
|
|
auto actual = map_base(in.data(), r, c, proper_stride);
|
|
|
|
ref.base().base().storage().swap(actual.storage());
|
|
|
|
return true;
|
|
}
|
|
}
|