//================================================================================================== /* ROTGEN - Runtime Overlay for Eigen Copyright : CODE RECKONS SPDX-License-Identifier: BSL-1.0 */ //================================================================================================== #include "unit/tests.hpp" #include template void test_matrix_operations(rotgen::Index rows, rotgen::Index cols, auto a_init_fn, auto b_init_fn, auto ops, auto self_ops) { MatrixType a(rows, cols); MatrixType b(rows, cols); MatrixType ref(rows, cols); for (rotgen::Index r = 0; r < rows; ++r) { for (rotgen::Index c = 0; c < cols; ++c) { a(r,c) = a_init_fn(r, c); b(r,c) = b_init_fn(r, c); ref(r, c) = ops(a(r,c),b(r,c)); } } TTS_EQUAL(ops(a, b), ref); self_ops(a,b); TTS_EQUAL(a, ref); TTS_EXPECT(verify_rotgen_reentrance(ops(a, b))); TTS_EXPECT(verify_rotgen_reentrance(self_ops(a, b))); } template void test_scalar_operations(rotgen::Index rows, rotgen::Index cols, auto a_init_fn, auto s, auto ops, auto self_ops) { MatrixType a(rows, cols); MatrixType ref(rows, cols); for (rotgen::Index r = 0; r < rows; ++r) { for (rotgen::Index c = 0; c < cols; ++c) { a(r,c) = a_init_fn(r, c); ref(r, c) = ops(a(r,c),s); } } TTS_EQUAL(ops(a, s), ref); self_ops(a,s); TTS_EQUAL(a, ref); TTS_EXPECT(verify_rotgen_reentrance(ops(a, s))); TTS_EXPECT(verify_rotgen_reentrance(self_ops(a, s))); } template void test_scalar_multiplications(rotgen::Index rows, rotgen::Index cols, auto fn, auto s) { MatrixType a(rows, cols); MatrixType ref(rows, cols); for (rotgen::Index r = 0; r < rows; ++r) { for (rotgen::Index c = 0; c < cols; ++c) { a(r,c) = fn(r, c); ref(r, c) = a(r,c) * s; } } TTS_EQUAL(a * s, ref); TTS_EQUAL(s * a, ref); a *= s; TTS_EQUAL(a, ref); TTS_EXPECT(verify_rotgen_reentrance(a*s)); TTS_EXPECT(verify_rotgen_reentrance(s*a)); TTS_EXPECT(verify_rotgen_reentrance(a*=s)); } template void test_matrix_multiplication(rotgen::Index rows, rotgen::Index cols, auto a_init_fn, auto b_init_fn) { MatrixType a(rows, cols); MatrixType b(cols, rows); MatrixType ref(rows, rows); for (rotgen::Index r = 0; r < a.rows(); ++r) for (rotgen::Index c = 0; c < a.cols(); ++c) a(r,c) = a_init_fn(r, c); for (rotgen::Index r = 0; r < b.rows(); ++r) for (rotgen::Index c = 0; c < b.cols(); ++c) b(r,c) = b_init_fn(r, c); for (rotgen::Index i = 0; i < a.rows(); ++i) { for (rotgen::Index j = 0; j < b.cols(); ++j) { ref(i, j) = 0; for (rotgen::Index k = 0; k < a.cols(); ++k) ref(i, j) += a(i, k) * b(k, j); } } TTS_EQUAL(a * b, ref); TTS_EXPECT(verify_rotgen_reentrance(a*b)); } // Basic initializers inline constexpr auto init_a = [](auto r, auto c) { return 9.9*r*r*r - 6*c -12; }; inline constexpr auto init_b = [](auto r, auto c) { return 3.1*r + 4.2*c - 12.3; }; inline constexpr auto init_0 = [](auto , auto ) { return 0; }; TTS_CASE_TPL("Check matrix addition", rotgen::tests::types) ( tts::type< tts::types> ) { using mat_t = rotgen::matrix; auto op = [](auto a, auto b) { return a + b; }; auto s_op = [](auto& a, auto b) { return a += b; }; test_matrix_operations(1 , 1, init_a, init_b, op, s_op); test_matrix_operations(3 , 5, init_a, init_b, op, s_op); test_matrix_operations(5 , 3, init_a, init_b, op, s_op); test_matrix_operations(5 , 5, init_a, init_b, op, s_op); test_matrix_operations(5 , 5, init_b, init_a, op, s_op); test_matrix_operations(10, 1, init_a, init_b, op, s_op); test_matrix_operations(1 ,10, init_a, init_b, op, s_op); test_matrix_operations(5 , 5, init_0, init_b, op, s_op); test_matrix_operations(5 , 5, init_a, init_0, op, s_op); }; TTS_CASE_TPL("Check matrix substraction", rotgen::tests::types) ( tts::type< tts::types> ) { using mat_t = rotgen::matrix; auto op = [](auto a, auto b) { return a - b; }; auto s_op = [](auto& a, auto b) { return a -= b; }; test_matrix_operations(1 , 1, init_a, init_b, op, s_op); test_matrix_operations(3 , 5, init_a, init_b, op, s_op); test_matrix_operations(5 , 3, init_a, init_b, op, s_op); test_matrix_operations(5 , 5, init_a, init_b, op, s_op); test_matrix_operations(5 , 5, init_b, init_a, op, s_op); test_matrix_operations(10, 1, init_a, init_b, op, s_op); test_matrix_operations(1 ,10, init_a, init_b, op, s_op); test_matrix_operations(5 , 5, init_0, init_b, op, s_op); test_matrix_operations(5 , 5, init_a, init_0, op, s_op); }; TTS_CASE_TPL("Check matrix multiplications", rotgen::tests::types) ( tts::type< tts::types> ) { using mat_t = rotgen::matrix; auto init_id = [](auto r, auto c) { return r == c ? 1 : 0; }; test_matrix_multiplication(1 , 1, init_a , init_b ); test_matrix_multiplication(3 , 5, init_a , init_b ); test_matrix_multiplication(5 , 3, init_a , init_b ); test_matrix_multiplication(5 , 5, init_a , init_b ); test_matrix_multiplication(5 , 5, init_b , init_a ); test_matrix_multiplication(5 , 5, init_a , init_a ); test_matrix_multiplication(5 , 5, init_a , init_id); test_matrix_multiplication(5 , 5, init_id, init_a ); test_matrix_multiplication(10, 1, init_a , init_b ); test_matrix_multiplication(1 ,10, init_a , init_b ); test_matrix_multiplication(5 , 5, init_0 , init_b ); test_matrix_multiplication(5 , 5, init_a , init_0 ); }; TTS_CASE_TPL("Check matrix multiplication with scalar", rotgen::tests::types) ( tts::type< tts::types> ) { using mat_t = rotgen::matrix; test_scalar_multiplications(1 , 1, init_a, T{ 3.5}); test_scalar_multiplications(3 , 5, init_a, T{-2.5}); test_scalar_multiplications(5 , 3, init_a, T{ 4. }); test_scalar_multiplications(5 , 5, init_a, T{-5. }); test_scalar_multiplications(5 , 5, init_a, T{ 1. }); test_scalar_multiplications(5 , 5, init_a, T{ 6. }); test_scalar_multiplications(10, 1, init_a, T{ 10.}); test_scalar_multiplications(1 ,10, init_a, T{-0.5}); }; TTS_CASE_TPL("Check matrix division with scalar", rotgen::tests::types) ( tts::type< tts::types> ) { using mat_t = rotgen::matrix; auto op = [](auto a, auto b) { return a / b; }; auto s_op = [](auto& a, auto b) { return a /= b; }; test_scalar_operations(1 , 1, init_a, T{ 3.5}, op, s_op); test_scalar_operations(3 , 5, init_a, T{-2.5}, op, s_op); test_scalar_operations(5 , 3, init_a, T{ 4. }, op, s_op); test_scalar_operations(5 , 5, init_a, T{-5. }, op, s_op); test_scalar_operations(5 , 5, init_a, T{ 1. }, op, s_op); test_scalar_operations(5 , 5, init_a, T{ 6. }, op, s_op); test_scalar_operations(10, 1, init_a, T{ 10.}, op, s_op); test_scalar_operations(1 ,10, init_a, T{-0.5}, op, s_op); };