diff --git a/include/rotgen/common/reshaper.hpp b/include/rotgen/common/reshaper.hpp new file mode 100644 index 0000000..3bc460a --- /dev/null +++ b/include/rotgen/common/reshaper.hpp @@ -0,0 +1,159 @@ +//================================================================================================== +/* + ROTGEN - Runtime Overlay for Eigen + Copyright : CODE RECKONS + SPDX-License-Identifier: BSL-1.0 +*/ +//================================================================================================== +#pragma once + +namespace rotgen +{ + template struct rowwise_adaptor + { + using concrete_type = typename std::remove_cvref_t::concrete_type; + Ref& target_; + + concrete_type sum() const + { + concrete_type res(target_.rows(),1); + apply([&](auto r, auto i){ res(i) = r.sum(); }); + return res; + } + + concrete_type mean() const + { + concrete_type res(target_.rows(),1); + apply([&](auto r, auto i){ res(i) = r.mean(); }); + return res; + } + + concrete_type prod() const + { + concrete_type res(target_.rows(),1); + apply([&](auto r, auto i){ res(i) = r.prod(); }); + return res; + } + + concrete_type maxCoeff() const + { + concrete_type res(target_.rows(),1); + apply([&](auto r, auto i){ res(i) = r.maxCoeff(); }); + return res; + } + + concrete_type minCoeff() const + { + concrete_type res(target_.rows(),1); + apply([&](auto r, auto i){ res(i) = r.minCoeff(); }); + return res; + } + + concrete_type squaredNorm() const + { + concrete_type res(target_.rows(),1); + apply([&](auto r, auto i){ res(i) = r.squaredNorm(); }); + return res; + } + + concrete_type norm() const + { + concrete_type res(target_.rows(),1); + apply([&](auto r, auto i){ res(i) = r.norm(); }); + return res; + } + + private: + template void apply(Func f) + { + for(Index i = 0; i < target_.rows(); ++i) + f(row(target_,i), i); + } + + template void apply(Func f) const + { + for(Index i = 0; i < target_.rows(); ++i) + f(row(target_,i), i); + } + }; + + template struct colwise_adaptor + { + using concrete_type = typename std::remove_cvref_t::concrete_type; + Ref& target_; + + concrete_type sum() const + { + concrete_type res(1, target_.cols()); + apply([&](auto r, auto i){ res(i) = r.sum(); }); + return res; + } + + concrete_type mean() const + { + concrete_type res(1, target_.cols()); + apply([&](auto r, auto i){ res(i) = r.mean(); }); + return res; + } + + concrete_type prod() const + { + concrete_type res(1, target_.cols()); + apply([&](auto r, auto i){ res(i) = r.prod(); }); + return res; + } + + concrete_type maxCoeff() const + { + concrete_type res(1, target_.cols()); + apply([&](auto r, auto i){ res(i) = r.maxCoeff(); }); + return res; + } + + concrete_type minCoeff() const + { + concrete_type res(1, target_.cols()); + apply([&](auto r, auto i){ res(i) = r.minCoeff(); }); + return res; + } + + concrete_type squaredNorm() const + { + concrete_type res(1, target_.cols()); + apply([&](auto r, auto i){ res(i) = r.squaredNorm(); }); + return res; + } + + concrete_type norm() const + { + concrete_type res(1, target_.cols()); + apply([&](auto r, auto i){ res(i) = r.norm(); }); + return res; + } + + private: + template void apply(Func f) + { + for(Index i = 0; i < target_.cols(); ++i) + f(col(target_,i), i); + } + + template void apply(Func f) const + { + for(Index i = 0; i < target_.cols(); ++i) + f(col(target_,i), i); + } + }; + + template auto rowwise(T&& t) + { + if constexpr(use_expression_templates) return t.base().rowwise(); + else return rowwise_adaptor{t}; + } + + template auto colwise(T&& t) + { + if constexpr(use_expression_templates) return t.base().colwise(); + else return colwise_adaptor{t}; + } +} \ No newline at end of file diff --git a/include/rotgen/rotgen.hpp b/include/rotgen/rotgen.hpp index bebb7e9..9207428 100644 --- a/include/rotgen/rotgen.hpp +++ b/include/rotgen/rotgen.hpp @@ -26,5 +26,6 @@ #include #include #include +#include #include #include diff --git a/test/unit/functions/rowwise.cpp b/test/unit/functions/rowwise.cpp new file mode 100644 index 0000000..a356f2c --- /dev/null +++ b/test/unit/functions/rowwise.cpp @@ -0,0 +1,54 @@ +//================================================================================================== +/* + ROTGEN - Runtime Overlay for Eigen + Copyright : CODE RECKONS + SPDX-License-Identifier: BSL-1.0 +*/ +//================================================================================================== +#include "unit/tests.hpp" +#include + +TTS_CASE_TPL("rowwise API", rotgen::tests::types) +( tts::type< tts::types> ) +{ + using e_t = Eigen::Matrix; + e_t ref = e_t::Random(4,4); + auto ref_rw = ref.rowwise(); + + rotgen::matrix mat(4,4); + rotgen::tests::prepare([&](auto r, auto c) { return ref(r,c); }, mat); + + auto rw = rotgen::rowwise(mat); + + for(rotgen::Index i=0;i( tts::type< tts::types> ) +{ + using e_t = Eigen::Matrix; + e_t ref = e_t::Random(4,4); + auto ref_rw = ref.colwise(); + + rotgen::matrix mat(4,4); + rotgen::tests::prepare([&](auto r, auto c) { return ref(r,c); }, mat); + + auto rw = rotgen::colwise(mat); + + for(rotgen::Index i=0;i