diff --git a/include/rotgen/impl/matrix_impl64.hpp b/include/rotgen/impl/matrix_impl64.hpp index 07a3fce..0764624 100644 --- a/include/rotgen/impl/matrix_impl64.hpp +++ b/include/rotgen/impl/matrix_impl64.hpp @@ -36,6 +36,13 @@ namespace rotgen void resize(std::size_t new_rows, std::size_t new_cols); void conservativeResize(std::size_t new_rows, std::size_t new_cols); + matrix_impl64 transpose() const; + matrix_impl64 conjugate() const; + matrix_impl64 adjoint() const; + + void transposeInPlace(); + void adjointInPlace(); + double& operator()(std::size_t i, std::size_t j); double const& operator()(std::size_t i, std::size_t j) const; diff --git a/include/rotgen/matrix.hpp b/include/rotgen/matrix.hpp index 9c78d00..febe15b 100644 --- a/include/rotgen/matrix.hpp +++ b/include/rotgen/matrix.hpp @@ -60,6 +60,25 @@ namespace rotgen parent::conservativeResize(new_rows, new_cols); } + matrix transpose() const + { + return matrix(static_cast(*this).transpose()); + } + + matrix conjugate() const + { + return matrix(static_cast(*this).conjugate()); + } + + matrix adjoint() const + { + return matrix(static_cast(*this).adjoint()); + } + + void transposeInPlace() { parent::transposeInPlace(); } + + void adjointInPlace() { parent::adjointInPlace(); } + friend bool operator==(matrix const& lhs, matrix const& rhs) { return static_cast(lhs) == static_cast(rhs); diff --git a/src/matrix_impl64.cpp b/src/matrix_impl64.cpp index 658b22c..40d616b 100644 --- a/src/matrix_impl64.cpp +++ b/src/matrix_impl64.cpp @@ -76,6 +76,36 @@ namespace rotgen const double* matrix_impl64::data() const { return storage_->data.data(); } + matrix_impl64 matrix_impl64::transpose() const + { + matrix_impl64 result(*this); + result.storage_->data = storage_->data.transpose(); + return result; + } + + matrix_impl64 matrix_impl64::conjugate() const + { + matrix_impl64 result(*this); + result.storage_->data = storage_->data.conjugate(); + return result; + } + + matrix_impl64 matrix_impl64::adjoint() const + { + matrix_impl64 result(*this); + result.storage_->data = storage_->data.adjoint(); + return result; + } + + void matrix_impl64::transposeInPlace() + { + storage_->data.transposeInPlace(); + } + void matrix_impl64::adjointInPlace() + { + storage_->data.adjointInPlace(); + } + //================================================================================================== // Operators //================================================================================================== diff --git a/test/basic/arithmetic_functions.cpp b/test/basic/arithmetic_functions.cpp new file mode 100644 index 0000000..8d8ad26 --- /dev/null +++ b/test/basic/arithmetic_functions.cpp @@ -0,0 +1,63 @@ +//================================================================================================== +/* + ROTGEN - Runtime Overlay for Eigen + Copyright : CODE RECKONS + SPDX-License-Identifier: BSL-1.0 +*/ +//================================================================================================== +#define TTS_MAIN +#include +#include "tts.hpp" + +template +struct MatrixDescriptor +{ + std::size_t rows, cols; + std::function init_fn; +}; + +template +void test_matrix_unary_ops(std::size_t rows, std::size_t cols, + const std::function &init_fn) +{ + + MatrixType original(rows, cols); + MatrixType transposed_matrix(cols, rows); + + for (std::size_t r = 0; r < rows; ++r) + for (std::size_t c = 0; c < cols; ++c) + init_fn(original, r, c); + + for (std::size_t r = 0; r < rows; ++r) + for (std::size_t c = 0; c < cols; ++c) + transposed_matrix(c, r) = original(r, c); + + TTS_EQUAL(original.transpose(), transposed_matrix); + TTS_EQUAL(original.conjugate(), original); + TTS_EQUAL(original.adjoint(), transposed_matrix); + + original.transposeInPlace(); + TTS_EQUAL(original, transposed_matrix); + + original.transposeInPlace(); + original.adjointInPlace(); + TTS_EQUAL(original, transposed_matrix); +} + +TTS_CASE("Matrix unary operations: transpose, adjoint, conjugate") +{ + std::vector>> test_matrices = { + {3, 3, [](auto& m, std::size_t r, std::size_t c) { m(r, c) = r + 3 * c - 2.5; }}, + {4, 9, [](auto& m, std::size_t r, std::size_t c) { m(r, c) = r*r + 3.12 * c + 6.87; }}, + {2, 7, [](auto& m, std::size_t r, std::size_t c) { m(r, c) = 1.1 * (r - c); }}, + {1, 5, [](auto &m, std::size_t r, std::size_t c) { m(r, c) = 9.99; }}, + {4, 2, [](auto& m, std::size_t r, std::size_t c) { m(r, c) = 0.0; }}, + {3, 3, [](auto& m, std::size_t r, std::size_t c) { m(r, c) = (r == c) ? 1.0 : 0.0; }}, + {2, 2, [](auto& m, std::size_t r, std::size_t c) { m(r, c) = (r + c) * 1e-10; }}, + {2, 2, [](auto& m, std::size_t r, std::size_t c) { m(r, c) = (r + 1) * 1e+10; }}, + }; + + for (auto const &desc : test_matrices) + test_matrix_unary_ops>(desc.rows, desc.cols, desc.init_fn); +}; +