reduce.cpp (3013B)
1 /** 2 * KFR (https://www.kfrlib.com) 3 * Copyright (C) 2016-2023 Dan Cazarin 4 * See LICENSE.txt for details 5 */ 6 7 #include "kfr/base/tensor.hpp" 8 #include <kfr/base/reduce.hpp> 9 #include <kfr/base/simd_expressions.hpp> 10 #include <kfr/base/univector.hpp> 11 12 namespace kfr 13 { 14 inline namespace CMT_ARCH_NAME 15 { 16 17 TEST(reduce) 18 { 19 testo::epsilon_scope<void> e(100); 20 { 21 univector<float, 6> a({ 1, 2, 3, 4, 5, -9 }); 22 CHECK(sum(a) == 6); 23 CHECK(mean(a) == 1); 24 CHECK(minof(a) == -9); 25 CHECK(maxof(a) == 5); 26 CHECK(absminof(a) == 1); 27 CHECK(absmaxof(a) == 9); 28 CHECK(sumsqr(a) == 136); 29 CHECK(rms(a) == 4.760952285695233f); 30 CHECK(product(a) == -1080); 31 } 32 { 33 univector<double, 6> a({ 1, 2, 3, 4, 5, -9 }); 34 CHECK(sum(a) == 6); 35 CHECK(mean(a) == 1); 36 CHECK(minof(a) == -9); 37 CHECK(maxof(a) == 5); 38 CHECK(absminof(a) == 1); 39 CHECK(absmaxof(a) == 9); 40 CHECK(sumsqr(a) == 136); 41 CHECK(rms(a) == 4.760952285695233); 42 CHECK(product(a) == -1080); 43 } 44 { 45 univector<int, 6> a({ 1, 2, 3, 4, 5, -9 }); 46 CHECK(sum(a) == 6); 47 CHECK(mean(a) == 1); 48 CHECK(minof(a) == -9); 49 CHECK(maxof(a) == 5); 50 CHECK(absminof(a) == 1); 51 CHECK(absmaxof(a) == 9); 52 CHECK(sumsqr(a) == 136); 53 CHECK(product(a) == -1080); 54 } 55 } 56 57 TEST(dotproduct) 58 { 59 univector<float, 177> v1 = counter(); 60 univector<float, 177> v2 = counter() * 2 + 10; 61 CHECK(dotproduct(v1, v2) == 3821312); 62 } 63 64 TEST(histogram) 65 { 66 univector<int, 16> v{ 1, 9, 5, 2, 1, -3, 100, 19, -4, -3, 1, 5, 9, 8, 0, 1 }; 67 auto h = histogram<10>(v); 68 CHECK(h.total() == 16); 69 CHECK(h.below() == 3); 70 CHECK(h.above() == 2); 71 CHECK(h[0] == 1); 72 CHECK(h[1] == 4); 73 CHECK(h[2] == 1); 74 CHECK(h[3] == 0); 75 CHECK(h[4] == 0); 76 CHECK(h[5] == 2); 77 CHECK(h[6] == 0); 78 CHECK(h[7] == 0); 79 CHECK(h[8] == 1); 80 CHECK(h[9] == 2); 81 82 univector<double, 16> v2{ 0.1, 0.9, 0.5, 0.2, 0.1, -0.3, 10.0, 1.9, 83 -0.4, -0.3, 0.1, 0.5, 0.9, 0.8, 0.0, 0.1 }; 84 auto h2 = histogram<10>(v2); 85 CHECK(h2.total() == 16); 86 CHECK(h2.below() == 3); 87 CHECK(h2.above() == 2); 88 CHECK(h2[0] == 1); 89 CHECK(h2[1] == 4); 90 CHECK(h2[2] == 1); 91 CHECK(h2[3] == 0); 92 CHECK(h2[4] == 0); 93 CHECK(h2[5] == 2); 94 CHECK(h2[6] == 0); 95 CHECK(h2[7] == 0); 96 CHECK(h2[8] == 1); 97 CHECK(h2[9] == 2); 98 } 99 100 TEST(reduce_multidim) 101 { 102 CHECK(sum(tensor<int, 2>(shape{ 3, 3 }, { 1, 2, 3, 4, 5, 6, 7, 8, 9 })) == 45); // 103 CHECK(sum(tensor<int, 3>(shape{ 2, 2, 2 }, { 1, 2, 3, 4, 5, 6, 7, 8 })) == 36); // 104 105 CHECK(maxof(tensor<int, 3>(shape{ 2, 2, 2 }, { 1, 2, 3, 4, 5, 6, 7, 8 })) == 8); // 106 CHECK(minof(tensor<int, 3>(shape{ 2, 2, 2 }, { 1, 2, 3, 4, 5, 6, 7, 8 })) == 1); // 107 CHECK(product(tensor<int, 3>(shape{ 2, 2, 2 }, { 1, 2, 3, 4, 5, 6, 7, 8 })) == 40320); // 108 } 109 110 } // namespace CMT_ARCH_NAME 111 } // namespace kfr