Fix some operations API mismatch

* Compound operators were not usable properly.
* std::size_t was used in the API in places where Index should have been used.
This commit is contained in:
Joel Falcou 2025-10-29 20:33:59 +01:00
parent f8cb289529
commit c7aa4a0afa
12 changed files with 152 additions and 120 deletions

View file

@ -87,62 +87,52 @@ namespace rotgen
return *this;
}
block(Ref const& r,
std::size_t i0,
std::size_t j0,
std::size_t ni,
std::size_t nj)
block(Ref const& r, Index i0, Index j0, Index ni, Index nj)
requires(!requires { typename Ref::rotgen_block_tag; } && is_immutable)
: parent(r.base(), i0, j0, ni, nj)
{
}
block(Ref const& r,
std::size_t i0,
std::size_t j0,
std::size_t ni,
std::size_t nj)
block(Ref const& r, Index i0, Index j0, Index ni, Index nj)
requires(requires { typename Ref::rotgen_block_tag; } && is_immutable)
: parent(r.base(), i0, j0, ni, nj)
{
}
block(Ref const& r, std::size_t i0, std::size_t j0)
block(Ref const& r, Index i0, Index j0)
requires(!requires { typename Ref::rotgen_block_tag; } && Rows != -1 &&
Cols != -1 && is_immutable)
: parent(r.base(), i0, j0, Rows, Cols)
{
}
block(Ref const& r, std::size_t i0, std::size_t j0)
block(Ref const& r, Index i0, Index j0)
requires(requires { typename Ref::rotgen_block_tag; } && Rows != -1 &&
Cols != -1 && is_immutable)
: parent(r.base(), i0, j0, Rows, Cols)
{
}
block(
Ref& r, std::size_t i0, std::size_t j0, std::size_t ni, std::size_t nj)
block(Ref& r, Index i0, Index j0, Index ni, Index nj)
requires(!requires { typename Ref::rotgen_block_tag; } && !is_immutable)
: parent(r.base(), i0, j0, ni, nj)
{
}
block(
Ref& r, std::size_t i0, std::size_t j0, std::size_t ni, std::size_t nj)
block(Ref& r, Index i0, Index j0, Index ni, Index nj)
requires(requires { typename Ref::rotgen_block_tag; } && !is_immutable)
: parent(r.base(), i0, j0, ni, nj)
{
}
block(Ref& r, std::size_t i0, std::size_t j0)
block(Ref& r, Index i0, Index j0)
requires(!requires { typename Ref::rotgen_block_tag; } && Rows != -1 &&
Cols != -1 && !is_immutable)
: parent(r.base(), i0, j0, Rows, Cols)
{
}
block(Ref& r, std::size_t i0, std::size_t j0)
block(Ref& r, Index i0, Index j0)
requires(requires { typename Ref::rotgen_block_tag; } && Rows != -1 &&
Cols != -1 && !is_immutable)
: parent(r.base(), i0, j0, Rows, Cols)

View file

@ -109,30 +109,25 @@ namespace rotgen
block& operator=(block const&) = default;
block& operator=(block&&) = default;
block(Ref const& r,
std::size_t i0,
std::size_t j0,
std::size_t ni,
std::size_t nj)
block(Ref const& r, Index i0, Index j0, Index ni, Index nj)
requires(is_immutable)
: parent(r.base(), i0, j0, ni, nj)
{
}
block(Ref const& r, std::size_t i0, std::size_t j0)
block(Ref const& r, Index i0, Index j0)
requires(Rows != -1 && Cols != -1 && is_immutable)
: parent(r.base(), i0, j0, Rows, Cols)
{
}
block(
Ref& r, std::size_t i0, std::size_t j0, std::size_t ni, std::size_t nj)
block(Ref& r, Index i0, Index j0, Index ni, Index nj)
requires(!is_immutable)
: parent(r.base(), i0, j0, ni, nj)
{
}
block(Ref& r, std::size_t i0, std::size_t j0)
block(Ref& r, Index i0, Index j0)
requires(Rows != -1 && Cols != -1 && !is_immutable)
: parent(r.base(), i0, j0, Rows, Cols)
{

View file

@ -94,7 +94,7 @@ namespace rotgen
"Mismatched between dynamic and static row size");
if constexpr (Cols != -1)
{
[[maybe_unused]] std::size_t c = 0;
[[maybe_unused]] Index c = 0;
if (init.size()) c = init.begin()->size();
assert(c == Cols &&
"Mismatched between dynamic and static column size");

View file

@ -15,10 +15,8 @@ class ROTGEN_EXPORT CLASSNAME
{
public:
CLASSNAME();
CLASSNAME(std::size_t rows, std::size_t cols);
CLASSNAME(std::size_t rows,
std::size_t cols,
std::initializer_list<TYPE> init);
CLASSNAME(Index rows, Index cols);
CLASSNAME(Index rows, Index cols, std::initializer_list<TYPE> init);
CLASSNAME(std::initializer_list<std::initializer_list<TYPE>> init);
@ -34,8 +32,8 @@ public:
Index cols() const;
Index size() const;
void resize(std::size_t new_rows, std::size_t new_cols);
void conservativeResize(std::size_t new_rows, std::size_t new_cols);
void resize(Index new_rows, Index new_cols);
void conservativeResize(Index new_rows, Index new_cols);
CLASSNAME normalized() const;
CLASSNAME transpose() const;
@ -64,11 +62,11 @@ public:
TYPE norm() const;
TYPE lp_norm(int p) const;
TYPE& operator()(std::size_t i, std::size_t j);
TYPE const& operator()(std::size_t i, std::size_t j) const;
TYPE& operator()(Index i, Index j);
TYPE const& operator()(Index i, Index j) const;
TYPE& operator()(std::size_t index);
TYPE const& operator()(std::size_t index) const;
TYPE& operator()(Index index);
TYPE const& operator()(Index index) const;
CLASSNAME& operator+=(CLASSNAME const& rhs);
CLASSNAME& operator-=(CLASSNAME const& rhs);
@ -89,17 +87,17 @@ public:
const TYPE* data() const;
TYPE* data();
static CLASSNAME Zero(std::size_t rows, std::size_t cols);
static CLASSNAME Ones(std::size_t rows, std::size_t cols);
static CLASSNAME Constant(std::size_t rows, std::size_t cols, TYPE value);
static CLASSNAME Random(std::size_t rows, std::size_t cols);
static CLASSNAME Identity(std::size_t rows, std::size_t cols);
static CLASSNAME Zero(Index rows, Index cols);
static CLASSNAME Ones(Index rows, Index cols);
static CLASSNAME Constant(Index rows, Index cols, TYPE value);
static CLASSNAME Random(Index rows, Index cols);
static CLASSNAME Identity(Index rows, Index cols);
void setOnes(std::size_t rows, std::size_t cols);
void setZero(std::size_t rows, std::size_t cols);
void setConstant(std::size_t rows, std::size_t cols, TYPE value);
void setRandom(std::size_t rows, std::size_t cols);
void setIdentity(std::size_t rows, std::size_t cols);
void setOnes(Index rows, Index cols);
void setZero(Index rows, Index cols);
void setConstant(Index rows, Index cols, TYPE value);
void setRandom(Index rows, Index cols);
void setIdentity(Index rows, Index cols);
private:
struct payload;

View file

@ -86,6 +86,45 @@ namespace rotgen
using parent::cwiseInverse;
using parent::cwiseSqrt;
// Compound Operators
template<typename A, int O, typename S>
ref& operator+=(ref<A, O, S> rhs)
requires(!is_immutable)
{
base() += rhs.base();
return *this;
}
template<typename A, int O, typename S>
ref& operator-=(ref<A, O, S> rhs)
requires(!is_immutable)
{
base() -= rhs.base();
return *this;
}
template<typename A, int O, typename S>
ref& operator*=(ref<A, O, S> rhs)
requires(!is_immutable)
{
base() *= rhs.base();
return *this;
}
ref& operator*=(std::convertible_to<value_type> auto s)
requires(!is_immutable)
{
base() *= s;
return *this;
}
ref& operator/=(std::convertible_to<value_type> auto s)
requires(!is_immutable)
{
base() /= s;
return *this;
}
// Shape modifications
using parent::adjoint;
using parent::conjugate;

View file

@ -132,6 +132,45 @@ namespace rotgen
using parent::sum;
using parent::trace;
// Compound Operators
template<typename A, int O, typename S>
ref& operator+=(ref<A, O, S> rhs)
requires(!is_immutable)
{
base() += rhs.base();
return *this;
}
template<typename A, int O, typename S>
ref& operator-=(ref<A, O, S> rhs)
requires(!is_immutable)
{
base() -= rhs.base();
return *this;
}
template<typename A, int O, typename S>
ref& operator*=(ref<A, O, S> rhs)
requires(!is_immutable)
{
base() *= rhs.base();
return *this;
}
ref& operator*=(std::convertible_to<value_type> auto s)
requires(!is_immutable)
{
base() *= s;
return *this;
}
ref& operator/=(std::convertible_to<value_type> auto s)
requires(!is_immutable)
{
base() /= s;
return *this;
}
// Shape modifications
using parent::adjoint;
using parent::conjugate;

View file

@ -23,35 +23,6 @@ namespace rotgen
return lhs.base() != rhs.base();
}
template<typename A, int O, typename S, typename B, int P, typename T>
auto operator+=(ref<A, O, S> lhs, ref<B, P, T> rhs)
{
lhs.base() += rhs.base();
return lhs;
}
template<typename A, int O, typename S, typename B, int P, typename T>
auto operator-=(ref<A, O, S> lhs, ref<B, P, T> rhs)
{
lhs.base() -= rhs.base();
return lhs;
}
template<typename A, int O, typename S, typename B, int P, typename T>
auto operator*=(ref<A, O, S> lhs, ref<B, P, T> rhs)
{
lhs.base() *= rhs.base();
return lhs;
}
template<typename A, int O, typename S, typename B, int P, typename T>
auto operator/=(ref<A, O, S> lhs,
std::convertible_to<typename A::value_type> auto s)
{
lhs.base() /= s;
return lhs;
}
template<typename A, int O, typename S>
auto operator*(std::convertible_to<typename A::value_type> auto s,
ref<A, O, S> rhs)

View file

@ -26,7 +26,7 @@ namespace rotgen
data_type data;
payload(std::size_t r = 0, std::size_t c = 0) : data(r, c) {}
payload(Index r = 0, Index c = 0) : data(r, c) {}
payload(std::initializer_list<std::initializer_list<double>> init)
: data(init)
@ -49,7 +49,7 @@ namespace rotgen
data_type data;
payload(std::size_t r = 0, std::size_t c = 0) : data(r, c) {}
payload(Index r = 0, Index c = 0) : data(r, c) {}
payload(std::initializer_list<std::initializer_list<double>> init)
: data(init)
@ -72,7 +72,7 @@ namespace rotgen
data_type data;
payload(std::size_t r = 0, std::size_t c = 0) : data(r, c) {}
payload(Index r = 0, Index c = 0) : data(r, c) {}
payload(std::initializer_list<std::initializer_list<float>> init)
: data(init)
@ -95,7 +95,7 @@ namespace rotgen
data_type data;
payload(std::size_t r = 0, std::size_t c = 0) : data(r, c) {}
payload(Index r = 0, Index c = 0) : data(r, c) {}
payload(std::initializer_list<std::initializer_list<float>> init)
: data(init)

View file

@ -14,19 +14,19 @@ namespace rotgen
//-----------------------------------------------------------------------------------------------
// Infos & Shape
//-----------------------------------------------------------------------------------------------
std::size_t rows(auto const& m)
Index rows(auto const& m)
requires(requires { m.rows(); })
{
return m.rows();
}
std::size_t cols(auto const& m)
Index cols(auto const& m)
requires(requires { m.cols(); })
{
return m.cols();
}
std::size_t size(auto const& m)
Index size(auto const& m)
requires(requires { m.size(); })
{
return m.size();

View file

@ -8,6 +8,7 @@
#pragma once
#include <rotgen/concepts.hpp>
#include <cassert>
#include <iosfwd>
@ -81,7 +82,7 @@ namespace rotgen
// Compounds operators across types
template<typename A, typename B>
auto operator+=(A& a, B const& b)
requires(concepts::entity<A> || concepts::entity<B>)
requires(concepts::entity<A> && concepts::entity<B>)
{
if constexpr (!use_expression_templates)
return generalize_t<A>(a) += generalize_t<B const>(b);
@ -90,7 +91,7 @@ namespace rotgen
template<typename A, typename B>
auto operator-=(A& a, B const& b)
requires(concepts::entity<A> || concepts::entity<B>)
requires(concepts::entity<A> && concepts::entity<B>)
{
if constexpr (!use_expression_templates)
return generalize_t<A>(a) -= generalize_t<B const>(b);
@ -99,7 +100,7 @@ namespace rotgen
template<typename A, typename B>
auto operator*=(A& a, B const& b)
requires(concepts::entity<A> || concepts::entity<B>)
requires(concepts::entity<A> && concepts::entity<B>)
{
if constexpr (!use_expression_templates)
return generalize_t<A>(a) *= generalize_t<B const>(b);