From 3d23a07e90f332b5a2e74bb8d556eaaaa352afa8 Mon Sep 17 00:00:00 2001 From: Karen Kaspar Date: Mon, 23 Jun 2025 15:22:11 +0200 Subject: [PATCH] Feat/block implementation See merge request oss/rotgen!10 --- include/rotgen/block.hpp | 298 ++++++++++++----------- include/rotgen/impl/block_model.hpp | 16 +- include/rotgen/impl/payload.hpp | 8 + src/block_model.cpp | 211 +++++++++------- test/unit/block/arithmetic_functions.cpp | 58 ++++- test/unit/block/operators.cpp | 102 ++++---- 6 files changed, 406 insertions(+), 287 deletions(-) diff --git a/include/rotgen/block.hpp b/include/rotgen/block.hpp index 5c1c963..e082471 100644 --- a/include/rotgen/block.hpp +++ b/include/rotgen/block.hpp @@ -35,163 +35,183 @@ namespace rotgen block(parent const& base) : parent(base) {} -/* -block transpose() const -{ - return block(static_cast(*this).transpose()); -} + concrete_type transpose() const + { + return concrete_type(static_cast(*this).transpose()); + } -block conjugate() const -{ - return block(static_cast(*this).conjugate()); -} + concrete_type conjugate() const + { + return concrete_type(static_cast(*this).conjugate()); + } -block adjoint() const -{ - return block(static_cast(*this).adjoint()); -} -*/ + concrete_type adjoint() const + { + return concrete_type(static_cast(*this).adjoint()); + } - friend bool operator==(block const& lhs, block const& rhs) + void transposeInPlace() { parent::transposeInPlace(); } + + void adjointInPlace() { parent::adjointInPlace(); } + + friend bool operator==(block const& lhs, block const& rhs) + { + return static_cast(lhs) == static_cast(rhs); + } + + block& operator+=(block const& rhs) + { + static_cast(*this) += static_cast(rhs); + return *this; + } + + block& operator-=(block const& rhs) + { + static_cast(*this) -= static_cast(rhs); + return *this; + } + + concrete_type operator-() const + { + return concrete_type(static_cast(*this).operator-()); + } + + concrete_type matrix_addition(block const& rhs) const + { + return concrete_type(static_cast(*this).matrix_addition(rhs)); + } + + concrete_type matrix_subtraction(block const& rhs) const + { + return concrete_type(static_cast(*this).matrix_subtraction(rhs)); + } + + concrete_type matrix_multiplication(block const& rhs) const + { + return concrete_type(static_cast(*this).matrix_multiplication(rhs)); + } + + concrete_type matrix_multiplication(scalar_type rhs) const + { + return concrete_type(static_cast(*this).matrix_multiplication(rhs)); + } + + concrete_type matrix_division(scalar_type rhs) const + { + return concrete_type(static_cast(*this).matrix_division(rhs)); + } + + block& operator*=(block const& rhs) + { + static_cast(*this) *= static_cast(rhs); + return *this; + } + + block& operator*=(scalar_type rhs) + { + static_cast(*this) *= rhs; + return *this; + } + + block& operator/=(scalar_type rhs) + { + static_cast(*this) /= rhs; + return *this; + } + + static concrete_type Zero() + requires (Rows != -1 && Cols != -1) + { + return parent::Zero(Rows, Cols); + } + + static concrete_type Zero(int rows, int cols) + { + if constexpr(Rows != -1) assert(rows == Rows && "Mismatched between dynamic and static row size"); + if constexpr(Cols != -1) assert(cols == Cols && "Mismatched between dynamic and static column size"); + return parent::Zero(rows, cols); + } + + static concrete_type Constant(scalar_type value) + requires (Rows != -1 && Cols != -1) + { + return parent::Constant(Rows, Cols, static_cast(value)); + } + + static concrete_type Constant(int rows, int cols, scalar_type value) + { + if constexpr(Rows != -1) assert(rows == Rows && "Mismatched between dynamic and static row size"); + if constexpr(Cols != -1) assert(cols == Cols && "Mismatched between dynamic and static column size"); + return parent::Constant(rows, cols, static_cast(value)); + } + + static concrete_type Random() + requires (Rows != -1 && Cols != -1) + { + return parent::Random(Rows, Cols); + } + + static concrete_type Random(int rows, int cols) + { + if constexpr(Rows != -1) assert(rows == Rows && "Mismatched between dynamic and static row size"); + if constexpr(Cols != -1) assert(cols == Cols && "Mismatched between dynamic and static column size"); + return parent::Random(rows, cols); + } + + static concrete_type Identity() + requires (Rows != -1 && Cols != -1) + { + return parent::Identity(Rows, Cols); + } + + static concrete_type Identity(int rows, int cols) + { + if constexpr(Rows != -1) assert(rows == Rows && "Mismatched between dynamic and static row size"); + if constexpr(Cols != -1) assert(cols == Cols && "Mismatched between dynamic and static column size"); + return parent::Identity(rows, cols); + } + + template + double lpNorm() const + { + assert(P == 1 || P == 2 || P == Infinity); + return parent::lpNorm(P); + } + }; + + template + typename block::concrete_type operator+(block const& lhs, block const& rhs) { - return static_cast(lhs) == static_cast(rhs); + return lhs.matrix_addition(rhs); } - block& operator+=(block const& rhs) + template + typename block::concrete_type operator-(block const& lhs, block const& rhs) { - static_cast(*this) += static_cast(rhs); - return *this; + return lhs.matrix_subtraction(rhs); } - block& operator-=(block const& rhs) + template + typename block::concrete_type operator*(block const& lhs, block const& rhs) { - static_cast(*this) -= static_cast(rhs); - return *this; + return lhs.matrix_multiplication(rhs); } - concrete_type operator-() const + template + typename block::concrete_type operator*(block const& lhs, double rhs) { - return concrete_type(static_cast(*this).operator-()); + return lhs.matrix_multiplication(rhs); } - block& operator*=(block const& rhs) + template + typename block::concrete_type operator*(double lhs, block const& rhs) { - static_cast(*this) *= static_cast(rhs); - return *this; + return rhs.matrix_multiplication(lhs); } - block& operator*=(scalar_type rhs) + template + typename block::concrete_type operator/(block const& lhs, double rhs) { - static_cast(*this) *= rhs; - return *this; + return lhs.matrix_division(rhs); } - - block& operator/=(scalar_type rhs) - { - static_cast(*this) /= rhs; - return *this; - } - - static concrete_type Zero() - requires (Rows != -1 && Cols != -1) - { - return parent::Zero(Rows, Cols); - } - - static concrete_type Zero(int rows, int cols) - { - if constexpr(Rows != -1) assert(rows == Rows && "Mismatched between dynamic and static row size"); - if constexpr(Cols != -1) assert(cols == Cols && "Mismatched between dynamic and static column size"); - return parent::Zero(rows, cols); - } - - static concrete_type Constant(scalar_type value) - requires (Rows != -1 && Cols != -1) - { - return parent::Constant(Rows, Cols, static_cast(value)); - } - - static concrete_type Constant(int rows, int cols, scalar_type value) - { - if constexpr(Rows != -1) assert(rows == Rows && "Mismatched between dynamic and static row size"); - if constexpr(Cols != -1) assert(cols == Cols && "Mismatched between dynamic and static column size"); - return parent::Constant(rows, cols, static_cast(value)); - } - - static concrete_type Random() - requires (Rows != -1 && Cols != -1) - { - return parent::Random(Rows, Cols); - } - - static concrete_type Random(int rows, int cols) - { - if constexpr(Rows != -1) assert(rows == Rows && "Mismatched between dynamic and static row size"); - if constexpr(Cols != -1) assert(cols == Cols && "Mismatched between dynamic and static column size"); - return parent::Random(rows, cols); - } - - static concrete_type Identity() - requires (Rows != -1 && Cols != -1) - { - return parent::Identity(Rows, Cols); - } - - static concrete_type Identity(int rows, int cols) - { - if constexpr(Rows != -1) assert(rows == Rows && "Mismatched between dynamic and static row size"); - if constexpr(Cols != -1) assert(cols == Cols && "Mismatched between dynamic and static column size"); - return parent::Identity(rows, cols); - } - - template - double lpNorm() const - { - assert(P == 1 || P == 2 || P == Infinity); - return parent::lpNorm(P); - } -}; - -/* -template -block operator+(block const& lhs, block const& rhs) -{ - block that(lhs); - return that += rhs; -} - -template -block operator-(block const& lhs, block const& rhs) -{ - block that(lhs); - return that -= rhs; -} - -template -block operator*(block const& lhs, block const& rhs) -{ - block that(lhs); - return that *= rhs; - } - - template - block operator*(block const& lhs, double rhs) - { - block that(lhs); - return that *= rhs; - } - - template - block operator*(double lhs, block const& rhs) - { - return rhs * lhs; - } - - template - block operator/(block const& lhs, double rhs) - { - block that(lhs); - return that /= rhs; - } -*/ } \ No newline at end of file diff --git a/include/rotgen/impl/block_model.hpp b/include/rotgen/impl/block_model.hpp index c68db37..52ecb99 100644 --- a/include/rotgen/impl/block_model.hpp +++ b/include/rotgen/impl/block_model.hpp @@ -28,12 +28,12 @@ class CLASSNAME std::size_t cols() const; std::size_t size() const; - // CLASSNAME transpose() const; - // CLASSNAME conjugate() const; - // CLASSNAME adjoint() const; + SOURCENAME transpose() const; + SOURCENAME conjugate() const; + SOURCENAME adjoint() const; - // void transposeInPlace(); - // void adjointInPlace(); + void transposeInPlace(); + void adjointInPlace(); TYPE sum() const; TYPE prod() const; @@ -61,6 +61,12 @@ class CLASSNAME CLASSNAME& operator*=(TYPE d); CLASSNAME& operator/=(TYPE d); + SOURCENAME matrix_addition(CLASSNAME const& rhs) const; + SOURCENAME matrix_subtraction(CLASSNAME const& rhs) const; + SOURCENAME matrix_multiplication(CLASSNAME const& rhs) const; + SOURCENAME matrix_multiplication(TYPE s) const; + SOURCENAME matrix_division(TYPE s) const; + friend std::ostream& operator<<(std::ostream&,CLASSNAME const&); friend bool operator==(CLASSNAME const& lhs, CLASSNAME const& rhs); diff --git a/include/rotgen/impl/payload.hpp b/include/rotgen/impl/payload.hpp index 9685228..b1af8c1 100644 --- a/include/rotgen/impl/payload.hpp +++ b/include/rotgen/impl/payload.hpp @@ -27,6 +27,8 @@ namespace rotgen payload(data_type&& matrix) : data(std::move(matrix)) {} void assign(Eigen::Block ref) { data = ref; } + void assign(data_type const& mat) { data = mat; } + void assign(data_type&& mat) { data = std::move(mat); } }; struct matrix_impl64_row::payload @@ -39,6 +41,8 @@ namespace rotgen payload(data_type&& matrix) : data(std::move(matrix)) {} void assign(Eigen::Block ref) { data = ref; } + void assign(data_type const& mat) { data = mat; } + void assign(data_type&& mat) { data = std::move(mat); } }; struct matrix_impl32_col::payload @@ -51,6 +55,8 @@ namespace rotgen payload(data_type&& matrix) : data(std::move(matrix)) {} void assign(Eigen::Block ref) { data = ref; } + void assign(data_type const& mat) { data = mat; } + void assign(data_type&& mat) { data = std::move(mat); } }; struct matrix_impl32_row::payload @@ -63,5 +69,7 @@ namespace rotgen payload(data_type&& matrix) : data(std::move(matrix)) {} void assign(Eigen::Block ref) { data = ref; } + void assign(data_type const& mat) { data = mat; } + void assign(data_type&& mat) { data = std::move(mat); } }; } diff --git a/src/block_model.cpp b/src/block_model.cpp index 1756a50..79a907c 100644 --- a/src/block_model.cpp +++ b/src/block_model.cpp @@ -21,6 +21,7 @@ struct CLASSNAME::payload using data_type = Eigen::Block; data_type data; + payload (data_type const& o) : data(o) {} payload (base_type& r, std::size_t i0, std::size_t j0, std::size_t ni, std::size_t nj) @@ -59,114 +60,144 @@ struct CLASSNAME::payload TYPE& CLASSNAME::operator()(std::size_t i, std::size_t j) { return storage_->data(i,j); } TYPE const& CLASSNAME::operator()(std::size_t i, std::size_t j) const { return storage_->data(i,j); } -/* -TYPE& CLASSNAME::operator()(std::size_t index) { return storage_->data(index); } -TYPE const& CLASSNAME::operator()(std::size_t index) const { return storage_->data(index); } -*/ + /* + TYPE& CLASSNAME::operator()(std::size_t index) { return storage_->data(index); } + TYPE const& CLASSNAME::operator()(std::size_t index) const { return storage_->data(index); } + */ const TYPE* CLASSNAME::data() const { return storage_->data.data(); } -/* -CLASSNAME CLASSNAME::transpose() const -{ - CLASSNAME result(*this); - result.storage_->data.transposeInPlace(); - return result; + SOURCENAME CLASSNAME::transpose() const + { + SOURCENAME result; + result.storage()->assign(storage_->data.transpose().eval()); + return result; + } -} + SOURCENAME CLASSNAME::conjugate() const { + SOURCENAME result; + result.storage()->assign(storage_->data.conjugate().eval()); + return result; + } -CLASSNAME CLASSNAME::conjugate() const -{ - CLASSNAME result(*this); - result.storage_->data = storage_->data.conjugate(); - return result; -} + SOURCENAME CLASSNAME::adjoint() const { + SOURCENAME result; + result.storage()->assign(storage_->data.adjoint().eval()); + return result; + } -CLASSNAME CLASSNAME::adjoint() const -{ - CLASSNAME result(*this); - result.storage_->data.adjointInPlace(); - return result; -} + void CLASSNAME::transposeInPlace() + { + storage_->data.transposeInPlace(); + } -void CLASSNAME::transposeInPlace() -{ - storage_->data.transposeInPlace(); -} + void CLASSNAME::adjointInPlace() + { + storage_->data.adjointInPlace(); + } -void CLASSNAME::adjointInPlace() -{ - storage_->data.adjointInPlace(); -} -*/ + TYPE CLASSNAME::sum() const { return storage_->data.sum(); } + TYPE CLASSNAME::prod() const { return storage_->data.prod(); } + TYPE CLASSNAME::mean() const { return storage_->data.mean(); } + TYPE CLASSNAME::trace() const { return storage_->data.trace(); } -TYPE CLASSNAME::sum() const { return storage_->data.sum(); } -TYPE CLASSNAME::prod() const { return storage_->data.prod(); } -TYPE CLASSNAME::mean() const { return storage_->data.mean(); } -TYPE CLASSNAME::trace() const { return storage_->data.trace(); } + TYPE CLASSNAME::minCoeff() const { return storage_->data.minCoeff(); } + TYPE CLASSNAME::maxCoeff() const { return storage_->data.maxCoeff(); } -TYPE CLASSNAME::minCoeff() const { return storage_->data.minCoeff(); } -TYPE CLASSNAME::maxCoeff() const { return storage_->data.maxCoeff(); } + TYPE CLASSNAME::minCoeff(std::ptrdiff_t* row, std::ptrdiff_t* col) const { return storage_->data.minCoeff(row, col); } + TYPE CLASSNAME::maxCoeff(std::ptrdiff_t* row, std::ptrdiff_t* col) const { return storage_->data.maxCoeff(row, col); } -TYPE CLASSNAME::minCoeff(std::ptrdiff_t* row, std::ptrdiff_t* col) const { return storage_->data.minCoeff(row, col); } -TYPE CLASSNAME::maxCoeff(std::ptrdiff_t* row, std::ptrdiff_t* col) const { return storage_->data.maxCoeff(row, col); } + TYPE CLASSNAME::squaredNorm() const { return storage_->data.squaredNorm(); } + TYPE CLASSNAME::norm() const { return storage_->data.norm(); } -TYPE CLASSNAME::squaredNorm() const { return storage_->data.squaredNorm(); } -TYPE CLASSNAME::norm() const { return storage_->data.norm(); } + TYPE CLASSNAME::lpNorm(int p) const + { + if (p == 1) return storage_->data.lpNorm<1>(); + else if (p == 2) return storage_->data.lpNorm<2>(); + else return storage_->data.lpNorm(); + } -TYPE CLASSNAME::lpNorm(int p) const -{ - if (p == 1) return storage_->data.lpNorm<1>(); - else if (p == 2) return storage_->data.lpNorm<2>(); - else return storage_->data.lpNorm(); -} + //================================================================================================== + // Operators + //================================================================================================== + std::ostream& operator<<(std::ostream& os,CLASSNAME const& m) + { + return os << m.storage_->data; + } -//================================================================================================== -// Operators -//================================================================================================== -std::ostream& operator<<(std::ostream& os,CLASSNAME const& m) -{ - return os << m.storage_->data; -} + bool operator==(CLASSNAME const& lhs, CLASSNAME const& rhs) + { + return lhs.storage_->data == rhs.storage_->data; + } -bool operator==(CLASSNAME const& lhs, CLASSNAME const& rhs) -{ - return lhs.storage_->data == rhs.storage_->data; -} + CLASSNAME& CLASSNAME::operator+=(CLASSNAME const& rhs) + { + storage_->data += rhs.storage_->data; + return *this; + } -CLASSNAME& CLASSNAME::operator+=(CLASSNAME const& rhs) -{ - storage_->data += rhs.storage_->data; - return *this; -} + CLASSNAME& CLASSNAME::operator-=(CLASSNAME const& rhs) + { + storage_->data -= rhs.storage_->data; + return *this; + } -CLASSNAME& CLASSNAME::operator-=(CLASSNAME const& rhs) -{ - storage_->data -= rhs.storage_->data; - return *this; -} + SOURCENAME CLASSNAME::operator-() const + { + SOURCENAME result; + result.storage()->assign(storage_->data); + return -result; + } -SOURCENAME CLASSNAME::operator-() const -{ - SOURCENAME result; - result.storage()->assign(storage_->data); - return -result; -} + CLASSNAME& CLASSNAME::operator*=(CLASSNAME const& rhs) + { + storage_->data *= rhs.storage_->data; + return *this; + } -CLASSNAME& CLASSNAME::operator*=(CLASSNAME const& rhs) -{ - storage_->data *= rhs.storage_->data; - return *this; -} + CLASSNAME& CLASSNAME::operator*=(TYPE s) + { + storage_->data *= s; + return *this; + } -CLASSNAME& CLASSNAME::operator*=(TYPE s) -{ - storage_->data *= s; - return *this; -} + CLASSNAME& CLASSNAME::operator/=(TYPE s) + { + storage_->data /= s; + return *this; + } -CLASSNAME& CLASSNAME::operator/=(TYPE s) -{ - storage_->data /= s; - return *this; -} + SOURCENAME CLASSNAME::matrix_addition(CLASSNAME const& rhs) const + { + SOURCENAME result; + result.storage()->assign(storage_->data + rhs.storage_->data); + return result; + } + + SOURCENAME CLASSNAME::matrix_subtraction(CLASSNAME const& rhs) const + { + SOURCENAME result; + result.storage()->assign(storage_->data - rhs.storage_->data); + return result; + } + + SOURCENAME CLASSNAME::matrix_multiplication(CLASSNAME const& rhs) const + { + SOURCENAME result; + result.storage()->assign(storage_->data * rhs.storage_->data); + return result; + } + + SOURCENAME CLASSNAME::matrix_multiplication(TYPE s) const + { + SOURCENAME result; + result.storage()->assign(storage_->data * s); + return result; + } + + SOURCENAME CLASSNAME::matrix_division(TYPE s) const + { + SOURCENAME result; + result.storage()->assign(storage_->data / s); + return result; + } \ No newline at end of file diff --git a/test/unit/block/arithmetic_functions.cpp b/test/unit/block/arithmetic_functions.cpp index 4609d44..02e4f27 100644 --- a/test/unit/block/arithmetic_functions.cpp +++ b/test/unit/block/arithmetic_functions.cpp @@ -11,6 +11,37 @@ #include #include +template +void test_comparison(const Type1& t1, const Type2& t2) +{ + TTS_EQUAL(static_cast(t1.rows()), static_cast(t2.rows())); + TTS_EQUAL(static_cast(t1.cols()), static_cast(t2.cols())); + for (std::size_t r = 0; r < static_cast(t1.rows()); ++r) + for (std::size_t c = 0; c < static_cast(t1.cols()); ++c) + TTS_EQUAL(t1(r, c), t2(r, c)); +} + +template +void test_block_unary_ops(const Matrix1& original_matrix, const Matrix2& ref_matrix, + Block1 original_block, Block2 ref_block) +{ + test_comparison(original_block.transpose(), ref_block.transpose()); + test_comparison(original_block.conjugate(), ref_block.conjugate()); + test_comparison(original_block.adjoint(), ref_block.adjoint()); + + if (original_block.rows() == original_block.cols()) { + original_block.transposeInPlace(); + ref_block.transposeInPlace(); + test_comparison(original_block, ref_block); + test_comparison(original_matrix, ref_matrix); + + original_block.adjointInPlace(); + ref_block.adjointInPlace(); + test_comparison(original_block, ref_block); + test_comparison(original_matrix, ref_matrix); + } +} + template void compare_reductions(const Block1& block, const Block2& ref) { @@ -68,8 +99,11 @@ void test_dynamic_block_reductions(rotgen::tests::matrix_block_test_case> test_cases = { - {6, 5, [](auto r, auto c) {return r + c; }, 1, 2, 3, 2}, - {9, 11, [](auto r, auto c) {return r + c; }, 0, 1, 4, 9}, - {3, 3, [](auto , auto ) {return 0.0; }, 1, 1, 1, 1}, - {1, 4, [](auto r, auto c) {return -r -c*c - 1234; }, 0, 0, 1, 1}, - {4, 1, [](auto , auto ) {return 7.0; }, 2, 0, 2, 1}, - {1, 1, [](auto , auto ) {return 42.0; }, 0, 0, 1, 1}, - {12, 13, [](auto r, auto c) {return std::sin(r + c); }, 2, 3, 4, 5 }, - {4, 9, [](auto r, auto c) {return -1.5 * r + 2.56 * c; }, 0, 1, 2, 3 }, - {2, 5, [](auto r, auto c) {return (r == c ? 1.0 : 0.0); }, 1, 1, 1, 1}, + {6, 5, [](auto r, auto c) { return T(r + c); }, 1, 2, 3, 2}, + {9, 11, [](auto r, auto c) {return T(r + c); }, 0, 1, 4, 9}, + {3, 3, [](auto , auto ) {return T(0.0); }, 1, 1, 1, 1}, + {1, 4, [](auto r, auto c) {return T(-r -c*c - 1234); }, 0, 0, 1, 1}, + {9, 9, [](auto r, auto c) {return T(-r + 2*c); }, 0, 1, 3, 3}, + {11, 13, [](auto r, auto c) {return T(std::tan(r+c)); }, 1, 1, 2, 2}, + {4, 1, [](auto , auto ) {return T(7.0); }, 2, 0, 2, 1}, + {1, 1, [](auto , auto ) {return T(42.0); }, 0, 0, 1, 1}, + {12, 13, [](auto r, auto c) {return T(std::sin(r + c)); }, 2, 3, 4, 5 }, + {4, 9, [](auto r, auto c) {return T(-1.5 * r + 2.56 * c); }, 0, 1, 2, 3 }, + {2, 5, [](auto r, auto c) {return T(r == c ? 1.0 : 0.0); }, 1, 1, 1, 1}, }; for (const auto& test_case : test_cases) test_dynamic_block_reductions(test_case); - - }; diff --git a/test/unit/block/operators.cpp b/test/unit/block/operators.cpp index 62ea4a7..64b91a5 100644 --- a/test/unit/block/operators.cpp +++ b/test/unit/block/operators.cpp @@ -12,7 +12,7 @@ template void test_block_matrix_operations(rotgen::tests::matrix_block_test_case const& matrix_construct, - auto b_init_fn, auto self_ops) + auto b_init_fn, auto ops, auto self_ops) { using EigenMatrix = Eigen::Matrix; @@ -39,12 +39,12 @@ void test_block_matrix_operations(rotgen::tests::matrix_block_test_case void test_block_scalar_operations(rotgen::tests::matrix_block_test_case const& matrix_construct, - auto scalar, auto self_ops) + auto scalar, auto ops, auto self_ops) { using EigenMatrix = Eigen::Matrix; @@ -81,7 +81,12 @@ void test_block_scalar_operations(rotgen::tests::matrix_block_test_case auto ref_b_block = ref_b.block(b_matrix_construct.i0, b_matrix_construct.j0, b_matrix_construct.ni, b_matrix_construct.nj); - // a * b + auto a_b_product_original = a_block * b_block; + auto a_b_product_ref = ref_a_block * ref_b_block; + + for (std::size_t r = 0; r < a_matrix_construct.ni; ++r) + for (std::size_t c = 0; c < a_matrix_construct.nj; ++c) + TTS_EQUAL(a_b_product_original (r, c), a_b_product_ref(r, c)); + a_block *= b_block; ref_a_block *= ref_b_block; @@ -183,19 +205,18 @@ TTS_CASE_TPL("Check block addition", rotgen::tests::types) ( tts::type< tts::types> ) { using mat_t = rotgen::matrix; - //auto op = [](auto a, auto b) { return a + b; }; + auto op = [](auto a, auto b) { return a + b; }; auto s_op = [](auto& a, auto b) { return a += b; }; - //test_block_matrix_operations({1 , 1, init_a, 0, 0, 1, 1}, init_b, op, s_op); - test_block_matrix_operations({1 , 1, init_a, 0, 0, 1, 1}, init_b, s_op); - test_block_matrix_operations({13 , 15, init_a, 1, 2, 3, 4}, init_b, s_op); - test_block_matrix_operations({5 , 9, init_a, 2, 2, 2, 2}, init_b, s_op); - test_block_matrix_operations({15 , 15, init_a, 3, 4, 5, 5}, init_b, s_op); - test_block_matrix_operations({5 , 5, init_b, 1, 0, 3, 2}, init_a, s_op); - test_block_matrix_operations({10, 1, init_a, 0, 0, 5, 1}, init_b, s_op); - test_block_matrix_operations({1 , 10, init_a, 0, 0, 1, 5}, init_b, s_op); - test_block_matrix_operations({21 , 5, init_0, 4, 4, 10, 1}, init_b, s_op); - test_block_matrix_operations({11 , 7, init_a, 2, 0, 7, 5}, init_0, s_op); + test_block_matrix_operations({1 , 1, init_a, 0, 0, 1, 1}, init_b, op, s_op); + test_block_matrix_operations({13 , 15, init_a, 1, 2, 3, 4}, init_b, op, s_op); + test_block_matrix_operations({5 , 9, init_a, 2, 2, 2, 2}, init_b, op, s_op); + test_block_matrix_operations({15 , 15, init_a, 3, 4, 5, 5}, init_b, op, s_op); + test_block_matrix_operations({5 , 5, init_b, 1, 0, 3, 2}, init_a, op, s_op); + test_block_matrix_operations({10, 1, init_a, 0, 0, 5, 1}, init_b, op, s_op); + test_block_matrix_operations({1 , 10, init_a, 0, 0, 1, 5}, init_b, op, s_op); + test_block_matrix_operations({21 , 5, init_0, 4, 4, 10, 1}, init_b, op, s_op); + test_block_matrix_operations({11 , 7, init_a, 2, 0, 7, 5}, init_0, op, s_op); }; @@ -203,19 +224,19 @@ TTS_CASE_TPL("Check block subtraction", rotgen::tests::types) ( tts::type< tts::types> ) { using mat_t = rotgen::matrix; - //auto op = [](auto a, auto b) { return a - b; }; + auto op = [](auto a, auto b) { return a - b; }; auto s_op = [](auto& a, auto b) { return a -= b; }; - //test_block_matrix_operations({1 , 1, init_a, 0, 0, 1, 1}, init_b, op, s_op); - test_block_matrix_operations({1 , 1, init_a, 0, 0, 1, 1}, init_b, s_op); - test_block_matrix_operations({13 , 15, init_a, 1, 2, 3, 4}, init_b, s_op); - test_block_matrix_operations({5 , 9, init_a, 2, 2, 2, 2}, init_b, s_op); - test_block_matrix_operations({15 , 15, init_a, 3, 4, 5, 5}, init_b, s_op); - test_block_matrix_operations({5 , 5, init_b, 1, 0, 3, 2}, init_a, s_op); - test_block_matrix_operations({10, 1, init_a, 0, 0, 5, 1}, init_b, s_op); - test_block_matrix_operations({1 , 10, init_a, 0, 0, 1, 5}, init_b,s_op); - test_block_matrix_operations({21 , 5, init_0, 4, 4, 10, 1}, init_b, s_op); - test_block_matrix_operations({11 , 7, init_a, 2, 0, 7, 5}, init_0, s_op); + test_block_matrix_operations({1 , 1, init_a, 0, 0, 1, 1}, init_b, op, s_op); + test_block_matrix_operations({1 , 1, init_a, 0, 0, 1, 1}, init_b, op, s_op); + test_block_matrix_operations({13 , 15, init_a, 1, 2, 3, 4}, init_b, op, s_op); + test_block_matrix_operations({5 , 9, init_a, 2, 2, 2, 2}, init_b, op, s_op); + test_block_matrix_operations({15 , 15, init_a, 3, 4, 5, 5}, init_b, op, s_op); + test_block_matrix_operations({5 , 5, init_b, 1, 0, 3, 2}, init_a, op, s_op); + test_block_matrix_operations({10, 1, init_a, 0, 0, 5, 1}, init_b, op, s_op); + test_block_matrix_operations({1 , 10, init_a, 0, 0, 1, 5}, init_b,op, s_op); + test_block_matrix_operations({21 , 5, init_0, 4, 4, 10, 1}, init_b, op, s_op); + test_block_matrix_operations({11 , 7, init_a, 2, 0, 7, 5}, init_0, op, s_op); }; @@ -258,17 +279,16 @@ TTS_CASE_TPL("Check block division with scalar", rotgen::tests::types) ( tts::type< tts::types> ) { using mat_t = rotgen::matrix; - //auto op = [](auto a, auto b) { return a / b; }; + auto op = [](auto a, auto b) { return a / b; }; auto s_op = [](auto& a, auto b) { return a /= b; }; - //test_block_scalar_operations({1 , 1, init_a, 0, 0, 1, 1}, T{ 3.5}, op, s_op); - test_block_scalar_operations({1 , 1, init_a, 0, 0, 1, 1}, T{ 3.5}, s_op); - test_block_scalar_operations({13 , 15, init_a, 1, 2, 3, 4}, T{-2.5}, s_op); - test_block_scalar_operations({5 , 9, init_a, 2, 2, 2, 2}, T{ 42. }, s_op); - test_block_scalar_operations({15 , 15, init_a, 3, 4, 5, 5}, T{-5. }, s_op); - test_block_scalar_operations({5 , 5, init_b, 1, 0, 3, 2}, T{ 1. }, s_op); - test_block_scalar_operations({10, 1, init_a, 0, 0, 5, 1}, T{ 0. }, s_op); - test_block_scalar_operations({1 , 10, init_a, 0, 0, 1, 5}, T{ 6. },s_op); - test_block_scalar_operations({21 , 5, init_0, 4, 4, 10, 1}, T{ 10.}, s_op); - test_block_scalar_operations({11 , 7, init_a, 2, 0, 7, 5}, T{-0.5}, s_op); + test_block_scalar_operations({1 , 1, init_a, 0, 0, 1, 1}, T{ 3.5}, op, s_op); + test_block_scalar_operations({13 , 15, init_a, 1, 2, 3, 4}, T{-2.5}, op, s_op); + test_block_scalar_operations({5 , 9, init_a, 2, 2, 2, 2}, T{ 42. }, op, s_op); + test_block_scalar_operations({15 , 15, init_a, 3, 4, 5, 5}, T{-5. }, op, s_op); + test_block_scalar_operations({5 , 5, init_b, 1, 0, 3, 2}, T{ 1. }, op, s_op); + test_block_scalar_operations({10, 1, init_a, 0, 0, 5, 1}, T{ 0. }, op, s_op); + test_block_scalar_operations({1 , 10, init_a, 0, 0, 1, 5}, T{ 6. }, op, s_op); + test_block_scalar_operations({21 , 5, init_0, 4, 4, 10, 1}, T{ 10.}, op, s_op); + test_block_scalar_operations({11 , 7, init_a, 2, 0, 7, 5}, T{-0.5}, op, s_op); };