Feat/block implementation

See merge request oss/rotgen!10
This commit is contained in:
Karen Kaspar 2025-06-23 15:22:11 +02:00 committed by Joel Falcou
parent b868398a77
commit 3d23a07e90
6 changed files with 406 additions and 287 deletions

View file

@ -11,6 +11,37 @@
#include <rotgen/extract.hpp>
#include <Eigen/Dense>
template<typename Type1, typename Type2>
void test_comparison(const Type1& t1, const Type2& t2)
{
TTS_EQUAL(static_cast<std::ptrdiff_t>(t1.rows()), static_cast<std::ptrdiff_t>(t2.rows()));
TTS_EQUAL(static_cast<std::ptrdiff_t>(t1.cols()), static_cast<std::ptrdiff_t>(t2.cols()));
for (std::size_t r = 0; r < static_cast<std::size_t>(t1.rows()); ++r)
for (std::size_t c = 0; c < static_cast<std::size_t>(t1.cols()); ++c)
TTS_EQUAL(t1(r, c), t2(r, c));
}
template<typename Matrix1, typename Matrix2, typename Block1, typename Block2>
void test_block_unary_ops(const Matrix1& original_matrix, const Matrix2& ref_matrix,
Block1 original_block, Block2 ref_block)
{
test_comparison(original_block.transpose(), ref_block.transpose());
test_comparison(original_block.conjugate(), ref_block.conjugate());
test_comparison(original_block.adjoint(), ref_block.adjoint());
if (original_block.rows() == original_block.cols()) {
original_block.transposeInPlace();
ref_block.transposeInPlace();
test_comparison(original_block, ref_block);
test_comparison(original_matrix, ref_matrix);
original_block.adjointInPlace();
ref_block.adjointInPlace();
test_comparison(original_block, ref_block);
test_comparison(original_matrix, ref_matrix);
}
}
template<typename Block1, typename Block2>
void compare_reductions(const Block1& block, const Block2& ref)
{
@ -68,8 +99,11 @@ void test_dynamic_block_reductions(rotgen::tests::matrix_block_test_case<MatrixT
ref.bottomRows(matrix_construct.ni) },
};
for (const auto& [original_block, ref_block] : test_cases)
for (const auto& [original_block, ref_block] : test_cases) {
compare_reductions(original_block, ref_block);
test_block_unary_ops(original, ref, original_block, ref_block);
}
}
TTS_CASE_TPL("Test dynamic block reductions", rotgen::tests::types)
@ -79,19 +113,19 @@ TTS_CASE_TPL("Test dynamic block reductions", rotgen::tests::types)
std::vector<rotgen::tests::matrix_block_test_case<mat_t>> test_cases =
{
{6, 5, [](auto r, auto c) {return r + c; }, 1, 2, 3, 2},
{9, 11, [](auto r, auto c) {return r + c; }, 0, 1, 4, 9},
{3, 3, [](auto , auto ) {return 0.0; }, 1, 1, 1, 1},
{1, 4, [](auto r, auto c) {return -r -c*c - 1234; }, 0, 0, 1, 1},
{4, 1, [](auto , auto ) {return 7.0; }, 2, 0, 2, 1},
{1, 1, [](auto , auto ) {return 42.0; }, 0, 0, 1, 1},
{12, 13, [](auto r, auto c) {return std::sin(r + c); }, 2, 3, 4, 5 },
{4, 9, [](auto r, auto c) {return -1.5 * r + 2.56 * c; }, 0, 1, 2, 3 },
{2, 5, [](auto r, auto c) {return (r == c ? 1.0 : 0.0); }, 1, 1, 1, 1},
{6, 5, [](auto r, auto c) { return T(r + c); }, 1, 2, 3, 2},
{9, 11, [](auto r, auto c) {return T(r + c); }, 0, 1, 4, 9},
{3, 3, [](auto , auto ) {return T(0.0); }, 1, 1, 1, 1},
{1, 4, [](auto r, auto c) {return T(-r -c*c - 1234); }, 0, 0, 1, 1},
{9, 9, [](auto r, auto c) {return T(-r + 2*c); }, 0, 1, 3, 3},
{11, 13, [](auto r, auto c) {return T(std::tan(r+c)); }, 1, 1, 2, 2},
{4, 1, [](auto , auto ) {return T(7.0); }, 2, 0, 2, 1},
{1, 1, [](auto , auto ) {return T(42.0); }, 0, 0, 1, 1},
{12, 13, [](auto r, auto c) {return T(std::sin(r + c)); }, 2, 3, 4, 5 },
{4, 9, [](auto r, auto c) {return T(-1.5 * r + 2.56 * c); }, 0, 1, 2, 3 },
{2, 5, [](auto r, auto c) {return T(r == c ? 1.0 : 0.0); }, 1, 1, 1, 1},
};
for (const auto& test_case : test_cases)
test_dynamic_block_reductions<mat_t, T>(test_case);
};