diff --git a/include/rotgen/impl/matrix_impl64.hpp b/include/rotgen/impl/matrix_impl64.hpp index b031e3b..b61c544 100644 --- a/include/rotgen/impl/matrix_impl64.hpp +++ b/include/rotgen/impl/matrix_impl64.hpp @@ -35,6 +35,7 @@ namespace rotgen matrix_impl64& operator*=(double d); friend std::ostream& operator<<(std::ostream&,matrix_impl64 const&); + friend bool operator==(matrix_impl64 const& lhs, matrix_impl64 const& rhs); private: struct payload; diff --git a/include/rotgen/matrix.hpp b/include/rotgen/matrix.hpp index c5bf594..fac61df 100644 --- a/include/rotgen/matrix.hpp +++ b/include/rotgen/matrix.hpp @@ -11,14 +11,21 @@ namespace rotgen { - template + template< typename Scalar, int Rows = -1 , int Cols = -1 + , int Options = 0, int MaxRows = Rows, int MaxCols = Cols + > class matrix : public matrix_impl64 { using parent = matrix_impl64; public: - matrix() : parent(Rows==-1?0:Rows,Cols==-1?0:Cols) {} - matrix(std::size_t r, std::size_t c) : parent(r,c) {} + matrix() : parent(Rows==-1?0:Rows,Cols==-1?0:Cols) {} + matrix(std::size_t r, std::size_t c) : parent(r,c) {} + + friend bool operator==(matrix const& lhs, matrix const& rhs) + { + return static_cast(lhs) == static_cast(rhs); + } matrix& operator+=(matrix const& rhs) { @@ -31,19 +38,38 @@ namespace rotgen static_cast(*this) *= static_cast(rhs); return *this; } + + matrix& operator*=(double rhs) + { + static_cast(*this) *= rhs; + return *this; + } }; - template - matrix operator+(matrix const& lhs, matrix const& rhs) + template + matrix operator+(matrix const& lhs, matrix const& rhs) { - matrix that(lhs); + matrix that(lhs); return that += rhs; } - template - matrix operator*(matrix const& lhs, matrix const& rhs) + template + matrix operator*(matrix const& lhs, matrix const& rhs) { - matrix that(lhs); + matrix that(lhs); return that *= rhs; } + + template + matrix operator*(matrix const& lhs, double rhs) + { + matrix that(lhs); + return that *= rhs; + } + + template + matrix operator*(double lhs, matrix const& rhs) + { + return rhs * lhs; + } } \ No newline at end of file diff --git a/test/basic/io.cpp b/test/basic/io.cpp index 3a4e5ee..e1a0af8 100644 --- a/test/basic/io.cpp +++ b/test/basic/io.cpp @@ -21,6 +21,6 @@ TTS_CASE("Sample test") "0 0 0 0 0\n" "0 0 0 0 0\n" "0 0 0 0 0"; - + TTS_EQUAL(os.str(), ref); }; \ No newline at end of file diff --git a/test/basic/operators.cpp b/test/basic/operators.cpp new file mode 100644 index 0000000..7afcd1a --- /dev/null +++ b/test/basic/operators.cpp @@ -0,0 +1,31 @@ +//================================================================================================== +/* + ROTGEN - Runtime Overlay for Eigen + Copyright : CODE RECKONS + SPDX-License-Identifier: BSL-1.0 +*/ +//================================================================================================== +#define TTS_MAIN +#include +#include "tts.hpp" + +TTS_CASE("Check operator*") +{ + rotgen::matrix a(2,2); + rotgen::matrix ref(2,2); + + for(int r=0;r