//==================================================================================================
/*
  ROTGEN - Runtime Overlay for Eigen
  Copyright : CODE RECKONS
  SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
// Unit tests for sub-block extraction helpers (topLeftCorner, bottomRightCorner, row, col,
// extract, segment) and for writing through the extracted blocks.
//
// NOTE(review): in this view of the file several template argument lists and the target of the
// first #include are missing; the comments below only describe what the remaining code shows.
#include
#include "unit/tests.hpp"

// Chained extractions: takes a block of a block of a block of a block and, at every level,
// writes a constant through the extraction and asserts the new value through the outermost
// matrix `a`.
TTS_CASE_TPL("Chains of extraction", rotgen::tests::types)
(tts::type>)
{
  // NOTE(review): N is presumably consumed by template arguments that are not visible here.
  constexpr int N = 8;
  auto a = rotgen::matrix::Random();

  // Level 1: 5x5 top-left corner of `a`; expected to start at (0, 0).
  auto b = topLeftCorner(a, 5, 5);
  TTS_EQUAL(b.startRow(), 0);
  TTS_EQUAL(b.startCol(), 0);
  setConstant(b, -7);
  // The write through `b` must be observable in rows/cols [0, 5) of `a`.
  for (rotgen::Index r = 0; r < 5; r++)
    for (rotgen::Index c = 0; c < 5; c++) TTS_EQUAL(a(r, c), -7);

  // Level 2: 3x3 bottom-right corner of the 5x5 block, expected at offset (2, 2).
  auto bb = bottomRightCorner(b, 3, 3);
  TTS_EQUAL(bb.startRow(), 2);
  TTS_EQUAL(bb.startCol(), 2);
  setConstant(bb, 42);
  // Checked through `a` on rows/cols [2, 5).
  for (rotgen::Index r = 2; r < 5; r++)
    for (rotgen::Index c = 2; c < 5; c++) TTS_EQUAL(a(r, c), 42);

  // Level 3: row 1 of `bb`. The expected start row is 1 while the values are checked on row 3
  // of `a`, i.e. the reported offset is relative to `bb`, not to `a`.
  auto bbb = row(bb, 1);
  TTS_EQUAL(bbb.startRow(), 1);
  TTS_EQUAL(bbb.startCol(), 0);
  setConstant(bbb, 99.5);
  // NOTE(review): only columns 3 and 4 of a's row 3 are checked; a(3, 2) presumably belongs to
  // this row extraction as well but is not asserted here.
  for (rotgen::Index c = 3; c < 5; c++) TTS_EQUAL(a(3, c), 99.5);

  // Level 4: column 1 of the row extraction; the write is checked at a(3, 3).
  auto bbbb = col(bbb, 1);
  TTS_EQUAL(bbbb.startRow(), 0);
  TTS_EQUAL(bbbb.startCol(), 1);
  setConstant(bbbb, 0.125);
  TTS_EQUAL(a(3, 3), 0.125);

  // Prints the final matrix only when the test binary is run with --verbose.
  bool verbose = ::tts::arguments()[{"--verbose"}];
  if (verbose) std::cout << a << "\n\n";
};

// Returns the extraction rotgen::extract(m, 0, 0, 3, 4) of a rotgen::ref parameter.
// The test below checks it as 3 rows by 4 columns starting at (0, 0).
auto ref_extract(rotgen::ref> m)
{
  return rotgen::extract(m, 0, 0, 3, 4);
}

// Same as above for the `const` flavour of rotgen::ref, with arguments (3, 4, 4, 3).
// The test below reads the result as 4 rows by 3 columns.
auto ref_cextract(rotgen::ref const> m)
{
  return rotgen::extract(m, 3, 4, 4, 3);
}

// Extraction through rotgen::ref parameters: assigning to the extraction returned by
// ref_extract must modify `m`, and the extraction returned by ref_cextract must expose the
// values previously written into `m`.
TTS_CASE("Extraction of ref/ref const")
{
  auto m = rotgen::setRandom>();

  // Assign ones through the returned extraction, then verify them through `m` itself.
  auto extracted = ref_extract(m);
  extracted = rotgen::setOnes>();
  for (rotgen::Index r = 0; r < 3; r++)
    for (rotgen::Index c = 0; c < 4; c++) TTS_EQUAL(m(r, c), 1.f);

  // Fill the region later requested by ref_cextract (same extract arguments) with 5, then
  // read it back through the const-ref extraction.
  rotgen::extract(m, 3, 4, 4, 3) = rotgen::setConstant>(5);
  auto sliced = ref_cextract(m);
  for (rotgen::Index r = 0; r < 4; r++)
    for (rotgen::Index c = 0; c < 3; c++) TTS_EQUAL(sliced(r, c), 5.f);
};

// Adds `n` to column `c` of `m` in place, using a compound assignment on the column extraction.
void process_col(rotgen::matrix& m, rotgen::matrix const& n, int c)
{
  col(m, c) += n;
}

// Compound assignment on extractions: starting from a zeroed `m`, adding `n` (all 10) to every
// column must yield a matrix equal to `reference` (all 10).
// NOTE(review): the matrix dimensions are part of template arguments not visible here.
TTS_CASE("Compound operators on extractions")
{
  rotgen::matrix m;
  rotgen::matrix reference;
  rotgen::matrix n;
  setZero(m);
  setConstant(n, 10);
  setConstant(reference, 10);
  for (int i = 0; i < m.cols(); i++) process_col(m, n, i);
  TTS_EQUAL(m, reference);
};

// 1D block compatibility: a 2-element segment of `C` is assigned both to a 1x2 block and to a
// 2x1 block of `V`.
TTS_CASE("Compatibility of 1D blocks")
{
  rotgen::matrix V;
  setConstant(V, 99);
  rotgen::matrix C{12, 34, 56};

  // 1x2 block at (0, 0) receives C(0), C(1); 2x1 block at (1, 0) receives C(1), C(2).
  extract<1, 2>(V, 0, 0) = segment<2>(C, 0);
  extract<2, 1>(V, 1, 0) = segment<2>(C, 1);

  TTS_EQUAL(V(0, 0), C(0));
  TTS_EQUAL(V(0, 1), C(1));
  TTS_EQUAL(V(1, 0), C(1));
  TTS_EQUAL(V(2, 0), C(2));
};