diff --git a/include/rotgen/impl/matrix_impl64.hpp b/include/rotgen/impl/matrix_impl64.hpp index 1651f71..d2e8a1a 100644 --- a/include/rotgen/impl/matrix_impl64.hpp +++ b/include/rotgen/impl/matrix_impl64.hpp @@ -70,7 +70,12 @@ namespace rotgen const double* data() const; - private: + static matrix_impl64 Zero(std::size_t rows, std::size_t cols); + static matrix_impl64 Constant(std::size_t rows, std::size_t cols, double value); + static matrix_impl64 Random(std::size_t rows, std::size_t cols); + static matrix_impl64 Identity(std::size_t rows, std::size_t cols); + + private: struct payload; std::unique_ptr storage_; }; diff --git a/include/rotgen/matrix.hpp b/include/rotgen/matrix.hpp index febe15b..6c138d4 100644 --- a/include/rotgen/matrix.hpp +++ b/include/rotgen/matrix.hpp @@ -118,6 +118,58 @@ namespace rotgen static_cast(*this) /= rhs; return *this; } + + static matrix Zero() + requires (Rows != -1 && Cols != -1) + { + return parent::Zero(Rows, Cols); + } + + static matrix Zero(int rows, int cols) + { + if constexpr(Rows != -1) assert(rows == Rows && "Mismatched between dynamic and static row size"); + if constexpr(Cols != -1) assert(cols == Cols && "Mismatched between dynamic and static column size"); + return parent::Zero(rows, cols); + } + + static matrix Constant(Scalar value) + requires (Rows != -1 && Cols != -1) + { + return parent::Constant(Rows, Cols, static_cast(value)); + } + + static matrix Constant(int rows, int cols, Scalar value) + { + if constexpr(Rows != -1) assert(rows == Rows && "Mismatched between dynamic and static row size"); + if constexpr(Cols != -1) assert(cols == Cols && "Mismatched between dynamic and static column size"); + return parent::Constant(rows, cols, static_cast(value)); + } + + static matrix Random() + requires (Rows != -1 && Cols != -1) + { + return parent::Random(Rows, Cols); + } + + static matrix Random(int rows, int cols) + { + if constexpr(Rows != -1) assert(rows == Rows && "Mismatched between dynamic and static row size"); + if constexpr(Cols != -1) assert(cols == Cols && "Mismatched between dynamic and static column size"); + return parent::Random(rows, cols); + } + + static matrix Identity() + requires (Rows != -1 && Cols != -1) + { + return parent::Identity(Rows, Cols); + } + + static matrix Identity(int rows, int cols) + { + if constexpr(Rows != -1) assert(rows == Rows && "Mismatched between dynamic and static row size"); + if constexpr(Cols != -1) assert(cols == Cols && "Mismatched between dynamic and static column size"); + return parent::Identity(rows, cols); + } }; template diff --git a/src/matrix_impl64.cpp b/src/matrix_impl64.cpp index f70259a..e63f076 100644 --- a/src/matrix_impl64.cpp +++ b/src/matrix_impl64.cpp @@ -15,12 +15,14 @@ namespace rotgen //================================================================================================ struct matrix_impl64::payload { - Eigen::Matrix data; + using data_type = Eigen::Matrix; + + data_type data; payload(std::size_t r=0, std::size_t c=0) : data(r, c) {} payload(std::initializer_list> init) : data(init) {} + payload(data_type&& matrix) : data(std::move(matrix)) {} }; - //================================================================================================== // Constructors & Special Members //================================================================================================== @@ -169,4 +171,37 @@ namespace rotgen storage_->data /= s; return *this; } + + //================================================================================================== + //================================================================================================== + // Static functions + //================================================================================================== + + matrix_impl64 matrix_impl64::Zero(std::size_t rows, std::size_t cols) { + matrix_impl64 m; + m.storage_ = std::make_unique(payload::data_type::Zero(rows, cols)); + return m; + } + + matrix_impl64 matrix_impl64::Constant(std::size_t rows, std::size_t cols, double value) + { + matrix_impl64 m; + m.storage_ = std::make_unique(payload::data_type::Constant(rows, cols, value)); + return m; + } + + matrix_impl64 matrix_impl64::Random(std::size_t rows, std::size_t cols) + { + matrix_impl64 m; + m.storage_ = std::make_unique(payload::data_type::Random(rows, cols)); + return m; + } + + matrix_impl64 matrix_impl64::Identity(std::size_t rows, std::size_t cols) + { + matrix_impl64 m; + m.storage_ = std::make_unique(payload::data_type::Identity(rows, cols)); + return m; + } + } \ No newline at end of file diff --git a/test/basic/static_functions.cpp b/test/basic/static_functions.cpp new file mode 100644 index 0000000..4330b63 --- /dev/null +++ b/test/basic/static_functions.cpp @@ -0,0 +1,96 @@ +//================================================================================================== +/* + ROTGEN - Runtime Overlay for Eigen + Copyright : CODE RECKONS + SPDX-License-Identifier: BSL-1.0 +*/ +//================================================================================================== +#define TTS_MAIN +#include +#include "tts.hpp" + +template +void test_zero(const MatrixType& matrix, std::size_t rows, std::size_t cols) +{ + for(std::size_t r=0;r +void test_constant(const MatrixType& matrix, std::size_t rows, std::size_t cols, double constant) +{ + for(std::size_t r=0;r +void test_random(const MatrixType& matrix, std::size_t rows, std::size_t cols) +{ + for(std::size_t r=0;r +void test_identity(const MatrixType& matrix, std::size_t rows, std::size_t cols) +{ + for(std::size_t r=0;r::Zero(), 3, 4); + test_zero(rotgen::matrix::Zero(), 1, 1); + test_zero(rotgen::matrix::Zero(), 10, 10); + test_zero(rotgen::matrix::Zero(3, 4), 3, 4); + test_zero(rotgen::matrix::Zero(7, 5), 7, 5); + test_zero(rotgen::matrix::Zero(9, 3), 9, 3); + test_zero(rotgen::matrix::Zero(2, 3), 2, 3); +}; + +TTS_CASE("Test constant") +{ + test_constant(rotgen::matrix::Constant(5.12), 3, 8, 5.12); + test_constant(rotgen::matrix::Constant(2.2), 1, 1, 2.2); + test_constant(rotgen::matrix::Constant(13), 11, 12, 13); + test_constant(rotgen::matrix::Constant(2, 7, 5.6), 2, 7, 5.6); + test_constant(rotgen::matrix::Constant(2, 2, 2.0), 2, 2, 2.0); + test_constant(rotgen::matrix::Constant(9, 3, 1.1), 9, 3, 1.1); + test_constant(rotgen::matrix::Constant(5, 9, 42), 5, 9, 42); +}; + +TTS_CASE("Test random") +{ + test_random(rotgen::matrix::Random(), 2, 3); + test_random(rotgen::matrix::Random(), 1, 1); + test_random(rotgen::matrix::Random(), 11, 17); + test_random(rotgen::matrix::Random(7, 3), 7, 3); + test_random(rotgen::matrix::Random(2, 2), 2, 2); + test_random(rotgen::matrix::Random(4, 3), 4, 3); + test_random(rotgen::matrix::Random(5, 5), 5, 5); +}; + +TTS_CASE("Test identity") +{ + test_identity(rotgen::matrix::Identity(), 4, 5); + test_identity(rotgen::matrix::Identity(), 1, 1); + test_identity(rotgen::matrix::Identity(), 21, 3); + test_identity(rotgen::matrix::Identity(2, 7), 2, 7); + test_identity(rotgen::matrix::Identity(2, 2), 2, 2); + test_identity(rotgen::matrix::Identity(3, 3), 3, 3); + test_identity(rotgen::matrix::Identity(5, 11), 5, 11); +};