//================================================================================================== /* ROTGEN - Runtime Overlay for Eigen Copyright : CODE RECKONS SPDX-License-Identifier: BSL-1.0 */ //================================================================================================== #include #include "unit/tests.hpp" template void test_matrix_operations(rotgen::Index rows, rotgen::Index cols, auto a_init_fn, auto b_init_fn, auto ops, auto self_ops) { MatrixType a(rows, cols); MatrixType b(rows, cols); MatrixType ref(rows, cols); for (rotgen::Index r = 0; r < rows; ++r) { for (rotgen::Index c = 0; c < cols; ++c) { a(r, c) = a_init_fn(r, c); b(r, c) = b_init_fn(r, c); ref(r, c) = ops(a(r, c), b(r, c)); } } TTS_EQUAL(ops(a, b), ref); self_ops(a, b); TTS_EQUAL(a, ref); TTS_EXPECT(verify_rotgen_reentrance(ops(a, b))); TTS_EXPECT(verify_rotgen_reentrance(self_ops(a, b))); } template void test_scalar_operations(rotgen::Index rows, rotgen::Index cols, auto a_init_fn, auto s, auto ops, auto self_ops) { MatrixType a(rows, cols); MatrixType ref(rows, cols); for (rotgen::Index r = 0; r < rows; ++r) { for (rotgen::Index c = 0; c < cols; ++c) { a(r, c) = a_init_fn(r, c); ref(r, c) = ops(a(r, c), s); } } TTS_EQUAL(ops(a, s), ref); self_ops(a, s); TTS_EQUAL(a, ref); TTS_EXPECT(verify_rotgen_reentrance(ops(a, s))); TTS_EXPECT(verify_rotgen_reentrance(self_ops(a, s))); } template void test_scalar_multiplications(rotgen::Index rows, rotgen::Index cols, auto fn, auto s) { MatrixType a(rows, cols); MatrixType ref(rows, cols); for (rotgen::Index r = 0; r < rows; ++r) { for (rotgen::Index c = 0; c < cols; ++c) { a(r, c) = fn(r, c); ref(r, c) = a(r, c) * s; } } TTS_EQUAL(a * s, ref); TTS_EQUAL(s * a, ref); a *= s; TTS_EQUAL(a, ref); TTS_EXPECT(verify_rotgen_reentrance(a * s)); TTS_EXPECT(verify_rotgen_reentrance(s * a)); TTS_EXPECT(verify_rotgen_reentrance(a *= s)); } template void test_matrix_multiplication(rotgen::Index rows, rotgen::Index cols, auto a_init_fn, auto b_init_fn) { MatrixType a(rows, cols); MatrixType b(cols, rows); MatrixType ref(rows, rows); for (rotgen::Index r = 0; r < a.rows(); ++r) for (rotgen::Index c = 0; c < a.cols(); ++c) a(r, c) = a_init_fn(r, c); for (rotgen::Index r = 0; r < b.rows(); ++r) for (rotgen::Index c = 0; c < b.cols(); ++c) b(r, c) = b_init_fn(r, c); for (rotgen::Index i = 0; i < a.rows(); ++i) { for (rotgen::Index j = 0; j < b.cols(); ++j) { ref(i, j) = 0; for (rotgen::Index k = 0; k < a.cols(); ++k) ref(i, j) += a(i, k) * b(k, j); } } TTS_EQUAL(a * b, ref); TTS_EXPECT(verify_rotgen_reentrance(a * b)); } // Basic initializers inline constexpr auto init_a = [](auto r, auto c) { return 9.9 * r * r * r - 6 * c - 12; }; inline constexpr auto init_b = [](auto r, auto c) { return 3.1 * r + 4.2 * c - 12.3; }; inline constexpr auto init_0 = [](auto, auto) { return 0; }; TTS_CASE_TPL("Check matrix addition", rotgen::tests::types)( tts::type>) { using mat_t = rotgen::matrix; auto op = [](auto a, auto b) { return a + b; }; auto s_op = [](auto& a, auto b) { return a += b; }; test_matrix_operations(1, 1, init_a, init_b, op, s_op); test_matrix_operations(3, 5, init_a, init_b, op, s_op); test_matrix_operations(5, 3, init_a, init_b, op, s_op); test_matrix_operations(5, 5, init_a, init_b, op, s_op); test_matrix_operations(5, 5, init_b, init_a, op, s_op); test_matrix_operations(10, 1, init_a, init_b, op, s_op); test_matrix_operations(1, 10, init_a, init_b, op, s_op); test_matrix_operations(5, 5, init_0, init_b, op, s_op); test_matrix_operations(5, 5, init_a, init_0, op, s_op); }; TTS_CASE_TPL("Check matrix substraction", rotgen::tests::types)( tts::type>) { using mat_t = rotgen::matrix; auto op = [](auto a, auto b) { return a - b; }; auto s_op = [](auto& a, auto b) { return a -= b; }; test_matrix_operations(1, 1, init_a, init_b, op, s_op); test_matrix_operations(3, 5, init_a, init_b, op, s_op); test_matrix_operations(5, 3, init_a, init_b, op, s_op); test_matrix_operations(5, 5, init_a, init_b, op, s_op); test_matrix_operations(5, 5, init_b, init_a, op, s_op); test_matrix_operations(10, 1, init_a, init_b, op, s_op); test_matrix_operations(1, 10, init_a, init_b, op, s_op); test_matrix_operations(5, 5, init_0, init_b, op, s_op); test_matrix_operations(5, 5, init_a, init_0, op, s_op); }; TTS_CASE_TPL("Check matrix multiplications", rotgen::tests::types)( tts::type>) { using mat_t = rotgen::matrix; auto init_id = [](auto r, auto c) { return r == c ? 1 : 0; }; test_matrix_multiplication(1, 1, init_a, init_b); test_matrix_multiplication(3, 5, init_a, init_b); test_matrix_multiplication(5, 3, init_a, init_b); test_matrix_multiplication(5, 5, init_a, init_b); test_matrix_multiplication(5, 5, init_b, init_a); test_matrix_multiplication(5, 5, init_a, init_a); test_matrix_multiplication(5, 5, init_a, init_id); test_matrix_multiplication(5, 5, init_id, init_a); test_matrix_multiplication(10, 1, init_a, init_b); test_matrix_multiplication(1, 10, init_a, init_b); test_matrix_multiplication(5, 5, init_0, init_b); test_matrix_multiplication(5, 5, init_a, init_0); }; TTS_CASE_TPL("Check matrix multiplication with scalar", rotgen::tests::types)( tts::type>) { using mat_t = rotgen::matrix; test_scalar_multiplications(1, 1, init_a, T{3.5}); test_scalar_multiplications(3, 5, init_a, T{-2.5}); test_scalar_multiplications(5, 3, init_a, T{4.}); test_scalar_multiplications(5, 5, init_a, T{-5.}); test_scalar_multiplications(5, 5, init_a, T{1.}); test_scalar_multiplications(5, 5, init_a, T{6.}); test_scalar_multiplications(10, 1, init_a, T{10.}); test_scalar_multiplications(1, 10, init_a, T{-0.5}); }; TTS_CASE_TPL("Check matrix division with scalar", rotgen::tests::types)( tts::type>) { using mat_t = rotgen::matrix; auto op = [](auto a, auto b) { return a / b; }; auto s_op = [](auto& a, auto b) { return a /= b; }; test_scalar_operations(1, 1, init_a, T{3.5}, op, s_op); test_scalar_operations(3, 5, init_a, T{-2.5}, op, s_op); test_scalar_operations(5, 3, init_a, T{4.}, op, s_op); test_scalar_operations(5, 5, init_a, T{-5.}, op, s_op); test_scalar_operations(5, 5, init_a, T{1.}, op, s_op); test_scalar_operations(5, 5, init_a, T{6.}, op, s_op); test_scalar_operations(10, 1, init_a, T{10.}, op, s_op); test_scalar_operations(1, 10, init_a, T{-0.5}, op, s_op); };