rotgen/test/unit/common/arithmetic.hpp
Jules Pénuchot 648dd768ee Adding clang-format configuration file and formatting all source files
Co-authored-by: Jules Pénuchot <jules@penuchot.com>
Co-authored-by: Joel FALCOU <jfalcou@codereckons.com>

See merge request oss/rotgen!41
2025-10-14 16:19:03 +02:00

87 lines
2.5 KiB
C++

//==================================================================================================
/*
ROTGEN - Runtime Overlay for Eigen
Copyright : CODE RECKONS
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#include <rotgen/rotgen.hpp>
#include <Eigen/Dense>
#include <vector>
#include <tuple>
namespace rotgen::tests
{
  //================================================================================================
  // Checks transpose / conjugate / adjoint, in both free-function and in-place forms, of
  // `original` against a reference transpose built element by element.
  //
  // `original` is taken by value because the in-place checks mutate it. After the in-place round
  // trip it holds its initial values again, so the later reentrance checks still see the input.
  //
  // NOTE(review): the expectations conjugate(x) == x and adjoint(x) == transpose(x) only hold for
  // real scalar types — confirm this helper is never instantiated with a complex value_type.
  //================================================================================================
  template<typename T> void check_shape_functions(T original)
  {
    using mat_t =
      matrix<typename T::value_type, Dynamic, Dynamic, T::storage_order>;

    // Reference transpose: result(r, c) == original(c, r).
    mat_t result(original.cols(), original.rows());
    prepare([&](auto r, auto c) { return original(c, r); }, result);

    TTS_EQUAL(transpose(original), result);
    TTS_EQUAL(conjugate(original), original);
    TTS_EQUAL(adjoint(original), result);

    // In-place round trip: transposing must yield the reference transpose, and the adjoint of
    // that must restore the starting values. The lambda is generic on purpose, so its body is
    // only instantiated when it is actually called (i.e. only for shapes checked below).
    auto check_in_place = [&](auto& m) {
      mat_t ref = m;
      transposeInPlace(m);
      TTS_EQUAL(m, result);
      adjointInPlace(m);
      TTS_EQUAL(m, ref);
    };

    // Only square shapes are checked in place. For fixed-size types the decision is made at
    // compile time, so non-square static types never instantiate the in-place calls.
    if constexpr (T::is_defined_static)
    {
      if constexpr (T::RowsAtCompileTime == T::ColsAtCompileTime)
        check_in_place(original);
    }
    else if (original.rows() == original.cols())
    {
      check_in_place(original);
    }

    if constexpr (!rotgen::use_expression_templates)
    {
      TTS_EXPECT(verify_rotgen_reentrance(original.transpose()));
      TTS_EXPECT(verify_rotgen_reentrance(original.conjugate()));
      TTS_EXPECT(verify_rotgen_reentrance(original.adjoint()));
    }
  }

  //================================================================================================
  // Checks rotgen's reduction functions on `input` against the same reductions computed by a
  // plain dynamic-size Eigen matrix holding a copy of `input`'s coefficients.
  //================================================================================================
  template<typename T> void check_reduction_functions(T const& input)
  {
    using EigenMatrix =
      Eigen::Matrix<typename T::value_type, Eigen::Dynamic, Eigen::Dynamic>;

    EigenMatrix ref(input.rows(), input.cols());
    prepare([&](auto r, auto c) { return input(r, c); }, ref);

    // Arithmetic reductions: the implementation under test may accumulate in a different order
    // than the dynamic Eigen reference, so allow a small ULP error. trace() is a sum over the
    // diagonal and therefore gets the same tolerance as sum().
    TTS_ULP_EQUAL(sum(input), ref.sum(), 2);
    TTS_ULP_EQUAL(prod(input), ref.prod(), 2);
    TTS_ULP_EQUAL(mean(input), ref.mean(), 2);
    TTS_ULP_EQUAL(trace(input), ref.trace(), 2);

    // Selection reductions perform no arithmetic: values must match exactly.
    TTS_EQUAL(minCoeff(input), ref.minCoeff());
    TTS_EQUAL(maxCoeff(input), ref.maxCoeff());

    // Index-reporting overloads.
    // NOTE(review): `ref` is always column-major while T may use another storage order; if the
    // input contains repeated extreme values, the reported position depends on traversal order —
    // confirm both sides break ties identically for the inputs used by callers.
    {
      int row     = -1;
      int col     = -1;
      int ref_row = -1;
      int ref_col = -1;

      TTS_EQUAL(minCoeff(input, &row, &col), ref.minCoeff(&ref_row, &ref_col));
      TTS_EQUAL(row, ref_row);
      TTS_EQUAL(col, ref_col);

      TTS_EQUAL(maxCoeff(input, &row, &col), ref.maxCoeff(&ref_row, &ref_col));
      TTS_EQUAL(row, ref_row);
      TTS_EQUAL(col, ref_col);
    }
  }
} // namespace rotgen::tests