//================================================================================================== /* ROTGEN - Runtime Overlay for Eigen Copyright : CODE RECKONS SPDX-License-Identifier: BSL-1.0 */ //================================================================================================== #define TTS_MAIN #include #include "tts.hpp" template void test_matrix_scalar_multiplication(std::size_t rows, std::size_t cols, double scalar, const std::function& init_fn) { MatrixType a(rows, cols); MatrixType ref(rows, cols); for (std::size_t r = 0; r < rows; ++r) { for (std::size_t c = 0; c < cols; ++c) { init_fn(a, r, c); ref(r, c) = a(r, c) * scalar; } } TTS_EQUAL(a * scalar, ref); TTS_EQUAL(scalar * a, ref); a *= scalar; TTS_EQUAL(a, ref); } template void test_matrix_scalar_division(std::size_t rows, std::size_t cols, double scalar, const std::function& init_fn) { MatrixType a(rows, cols); MatrixType ref(rows, cols); for (std::size_t r = 0; r < rows; ++r) { for (std::size_t c = 0; c < cols; ++c) { init_fn(a, r, c); ref(r, c) = a(r, c) / scalar; } } TTS_EQUAL(a / scalar, ref); a /= scalar; TTS_EQUAL(a, ref); } template void test_matrix_multiplication(std::size_t n, std::size_t m, std::size_t p, InitA&& a_init_fn, InitB&& b_init_fn) { MatrixType a(n, m); MatrixType b(m, p); MatrixType ref(n, p); for (std::size_t r = 0; r < n; ++r) for (std::size_t c = 0; c < m; ++c) a_init_fn(a, r, c); for (std::size_t r = 0; r < m; ++r) for (std::size_t c = 0; c < p; ++c) b_init_fn(b, r, c); for (std::size_t i = 0; i < n; ++i) for (std::size_t j = 0; j < p; ++j) { ref(i, j) = 0; for (std::size_t k = 0; k < m; ++k) ref(i, j) += a(i, k) * b(k, j); } TTS_EQUAL(a * b, ref); a *= b; TTS_EQUAL(a, ref); } template void test_matrix_addition(std::size_t rows, std::size_t cols, InitA&& a_init_fn, InitB&& b_init_fn) { MatrixType a(rows, cols); MatrixType b(rows, cols); MatrixType ref(rows, cols); for (std::size_t r = 0; r < rows; ++r) { for (std::size_t c = 0; c < cols; ++c) { a_init_fn(a, r, c); b_init_fn(b, r, c); ref(r, c) = a(r,c) + b(r,c); } } TTS_EQUAL(a + b, ref); TTS_EQUAL(b + a, ref); a += b; TTS_EQUAL(a, ref); } template void test_matrix_substraction(std::size_t rows, std::size_t cols, InitA&& a_init_fn, InitB&& b_init_fn) { MatrixType a(rows, cols); MatrixType b(rows, cols); MatrixType ref(rows, cols); MatrixType a_minus_ref(rows, cols); MatrixType b_minus_ref(rows, cols); for (std::size_t r = 0; r < rows; ++r) { for (std::size_t c = 0; c < cols; ++c) { a_init_fn(a, r, c); b_init_fn(b, r, c); ref(r, c) = a(r,c) - b(r,c); a_minus_ref(r, c) = -a(r,c); b_minus_ref(r, c) = -b(r,c); } } MatrixType a_unary = -a; MatrixType b_unary = -b; TTS_EQUAL(a - b, ref); TTS_EQUAL(a_unary, a_minus_ref); TTS_EQUAL(-a, a_minus_ref); TTS_EQUAL(-b, b_minus_ref); TTS_EQUAL(-(-a), a); TTS_EQUAL(-(-b), b); a -= b; TTS_EQUAL(a, ref); }; TTS_CASE("Check matrix * scalar and scalar * matrix with default values") { test_matrix_scalar_multiplication>(2, 2, 10.5, [](auto& a, std::size_t r, std::size_t c) { a(r, c) = (1 + c) + 10 * (1 + r); }); }; TTS_CASE("Check matrix * scalar with zero scalar multiplication") { test_matrix_scalar_multiplication>(3, 2, 0.0, [](auto& a, std::size_t r, std::size_t c) { a(r, c) = 5 * c - r; }); }; TTS_CASE("Check matrix * scalar with one scalar multiplication") { test_matrix_scalar_multiplication>(3, 2, 1, [](auto& a, std::size_t r, std::size_t c) { a(r, c) = 3.3*r - 6; }); }; TTS_CASE("Check matrix - scalar with negative scalar multiplication") { test_matrix_scalar_multiplication>(3, 2, -36.2, [](auto& a, std::size_t r, std::size_t c) { a(r, c) = r * r - c + 3.9; }); }; TTS_CASE("Check static matrix - scalar with float scalar multiplication") { test_matrix_scalar_multiplication>(3, 4, 5.6, [](auto& a, std::size_t r, std::size_t c) { a(r, c) = 1.2*r+4.5*c-6; }); }; TTS_CASE("Check matrix / scalar with default values") { test_matrix_scalar_division>(6, 7, 10.5, [](auto& a, std::size_t r, std::size_t c) { a(r, c) = (1 + c) + 10 * (1 + r); }); }; TTS_CASE("Check matrix * scalar with one scalar multiplication") { test_matrix_scalar_division>(3, 2, 1, [](auto& a, std::size_t r, std::size_t c) { a(r, c) = 3.3*r - 4.5*c + 1.1; }); }; TTS_CASE("Check matrix - scalar with negative scalar multiplication") { test_matrix_scalar_division>(2, 7, -36.2, [](auto& a, std::size_t r, std::size_t c) { a(r, c) = 3.4 * r * r - c + 3.9; }); }; TTS_CASE("Check static matrix - scalar with float scalar multiplication") { test_matrix_scalar_division>(3, 4, 5.6, [](auto& a, std::size_t r, std::size_t c) { a(r, c) = 1.2*r+4.5*c-6; }); }; TTS_CASE("Matrix multiplication") { test_matrix_multiplication>(2, 3, 4, [](auto& mat, std::size_t r, std::size_t c) { mat(r, c) = r + c; }, [](auto& mat, std::size_t r, std::size_t c) { mat(r, c) = r * c; }); }; TTS_CASE("Matrix multiplication with floats") { test_matrix_multiplication>(3, 6, 3, [](auto& mat, std::size_t r, std::size_t c) { mat(r, c) = 3.4 * r + 5.4 * c*c - 5.2; }, [](auto& mat, std::size_t r, std::size_t c) { mat(r, c) = r * c + 5 * r - 3.33; }); }; TTS_CASE("Matrix multiplication with zero matrix") { test_matrix_multiplication>(3, 4, 5, [](auto& mat, std::size_t r, std::size_t c) { mat(r, c) = (-4) * r + 3*c; }, [](auto& mat, std::size_t r, std::size_t c) { mat(r, c) = 0; }); }; TTS_CASE("Matrix multiplication with itself") { test_matrix_multiplication>(3, 3, 3, [](auto& mat, std::size_t r, std::size_t c) { mat(r, c) = 3*r - 7*r*c + 14.4; }, [](auto& mat, std::size_t r, std::size_t c) { mat(r, c) = 3*r - 7*r*c + 14.4; }); }; TTS_CASE("Matrix multiplication with identity matrix") { test_matrix_multiplication>(4, 4, 4, [](auto& mat, std::size_t r, std::size_t c) { mat(r, c) = r*r*r - c + 12; }, [](auto& mat, std::size_t r, std::size_t c) { mat(r, c) = (r==c); }); }; TTS_CASE("Matrix addition") { test_matrix_addition>(3, 4, [](auto& mat, std::size_t r, std::size_t c) { mat(r, c) = 9.9*r*r*r - 6*c -12; }, [](auto& mat, std::size_t r, std::size_t c) { mat(r, c) = 3.1*r + 4.2*c - 12.3; }); }; TTS_CASE("Matrix addition with zero matrix") { test_matrix_addition>(3, 4, [](auto& mat, std::size_t r, std::size_t c) { mat(r, c) = 2*r*r + c + 3.4; }, [](auto& mat, std::size_t r, std::size_t c) { mat(r, c) = 0; }); }; TTS_CASE("Matrix substraction") { test_matrix_substraction>(3, 4, [](auto& mat, std::size_t r, std::size_t c) { mat(r, c) = 6.78*r - 5.2*c - 0.01; }, [](auto& mat, std::size_t r, std::size_t c) { mat(r, c) = 3.1*r + 33.456*c*c; }); }; TTS_CASE("Matrix substraction with zero matrix") { test_matrix_substraction>(3, 4, [](auto& mat, std::size_t r, std::size_t c) { mat(r, c) = r + c*c*c*c - 56.6; }, [](auto& mat, std::size_t r, std::size_t c) { mat(r, c) = 0; }); };