shuffle.cpp (10020B)
1 /** 2 * KFR (https://www.kfrlib.com) 3 * Copyright (C) 2016-2023 Dan Cazarin 4 * See LICENSE.txt for details 5 */ 6 7 #include <kfr/io.hpp> 8 #include <kfr/simd/shuffle.hpp> 9 10 namespace kfr 11 { 12 inline namespace CMT_ARCH_NAME 13 { 14 TEST(concat) 15 { 16 CHECK(concat(vec<f32, 1>{ 1 }, vec<f32, 2>{ 2, 3 }, vec<f32, 1>{ 4 }, vec<f32, 3>{ 5, 6, 7 }) // 17 == vec<f32, 7>{ 1, 2, 3, 4, 5, 6, 7 }); 18 } 19 20 TEST(reverse) 21 { 22 CHECK(reverse(pack(0, 1, 2, 3, 4, 5, 6, 7)) == pack(7, 6, 5, 4, 3, 2, 1, 0)); 23 CHECK(reverse<2>(pack(0, 1, 2, 3, 4, 5, 6, 7)) == pack(6, 7, 4, 5, 2, 3, 0, 1)); 24 CHECK(reverse<4>(pack(0, 1, 2, 3, 4, 5, 6, 7)) == pack(4, 5, 6, 7, 0, 1, 2, 3)); 25 } 26 27 TEST(shuffle) 28 { 29 const vec<int, 8> numbers1 = enumerate<int, 8>(); 30 const vec<int, 8> numbers2 = enumerate<int, 8, 100>(); 31 CHECK(shuffle(numbers1, numbers2, elements_t<0, 8, 2, 10, 4, 12, 6, 14>()) == 32 vec<int, 8>{ 0, 100, 2, 102, 4, 104, 6, 106 }); 33 CHECK(shuffle(numbers1, numbers2, elements_t<0, 8>()) == vec<int, 8>{ 0, 100, 2, 102, 4, 104, 6, 106 }); 34 } 35 36 TEST(permute) 37 { 38 const vec<int, 8> numbers1 = enumerate<int, 8>(); 39 CHECK(permute(numbers1, elements_t<0, 2, 1, 3, 4, 6, 5, 7>()) == vec<int, 8>{ 0, 2, 1, 3, 4, 6, 5, 7 }); 40 CHECK(permute(numbers1, elements_t<0, 2, 1, 3>()) == vec<int, 8>{ 0, 2, 1, 3, 4, 6, 5, 7 }); 41 } 42 43 TEST(blend) 44 { 45 const vec<int, 8> numbers1 = enumerate<int, 8>(); 46 const vec<int, 8> numbers2 = enumerate<int, 8, 100>(); 47 CHECK(blend(numbers1, numbers2, elements_t<0, 1, 1, 0, 1, 1, 0, 1>()) == 48 vec<int, 8>{ 0, 101, 102, 3, 104, 105, 6, 107 }); 49 CHECK(blend(numbers1, numbers2, elements_t<0, 1, 1>()) == 50 vec<int, 8>{ 0, 101, 102, 3, 104, 105, 6, 107 }); 51 } 52 53 TEST(duplicate_shuffle) 54 { 55 CHECK(dup(pack(0, 1, 2, 3)) == pack(0, 0, 1, 1, 2, 2, 3, 3)); 56 CHECK(duphalves(pack(0, 1, 2, 3)) == pack(0, 1, 2, 3, 0, 1, 2, 3)); 57 CHECK(dupeven(pack(0, 1, 2, 3, 4, 5, 6, 7)) == pack(0, 0, 2, 2, 4, 4, 6, 6)); 58 CHECK(dupodd(pack(0, 1, 2, 3, 4, 5, 6, 7)) == pack(1, 1, 3, 3, 5, 5, 7, 7)); 59 } 60 61 TEST(split_interleave) 62 { 63 vec<f32, 1> a1; 64 vec<f32, 2> a23; 65 vec<f32, 1> a4; 66 vec<f32, 3> a567; 67 split(vec<f32, 7>{ 1, 2, 3, 4, 5, 6, 7 }, a1, a23, a4, a567); 68 CHECK(a1 == vec<f32, 1>{ 1 }); 69 CHECK(a23 == vec<f32, 2>{ 2, 3 }); 70 CHECK(a4 == vec<f32, 1>{ 4 }); 71 CHECK(a567 == vec<f32, 3>{ 5, 6, 7 }); 72 73 CHECK(splitpairs(pack(0, 1, 2, 3, 4, 5, 6, 7)) == pack(0, 2, 4, 6, 1, 3, 5, 7)); 74 CHECK(splitpairs<2>(pack(0, 1, 2, 3, 4, 5, 6, 7)) == pack(0, 1, 4, 5, 2, 3, 6, 7)); 75 76 CHECK(interleavehalves(pack(0, 1, 2, 3, 4, 5, 6, 7)) == pack(0, 4, 1, 5, 2, 6, 3, 7)); 77 CHECK(interleavehalves<2>(pack(0, 1, 2, 3, 4, 5, 6, 7)) == pack(0, 1, 4, 5, 2, 3, 6, 7)); 78 } 79 80 TEST(zip) 81 { 82 CHECK(zip(pack(1, 2, 3, 4), pack(10, 20, 30, 40)) == 83 pack(pack(1, 10), pack(2, 20), pack(3, 30), pack(4, 40))); 84 85 CHECK(zip(pack(1, 2, 3, 4), pack(10, 20, 30, 40), pack(111, 222, 333, 444), pack(-1, -2, -3, -4)) == 86 pack(pack(1, 10, 111, -1), pack(2, 20, 222, -2), pack(3, 30, 333, -3), pack(4, 40, 444, -4))); 87 } 88 89 TEST(column) 90 { 91 CHECK(column<1>(pack(pack(0, 1), pack(2, 3), pack(4, 5), pack(6, 7))) == pack(1, 3, 5, 7)); 92 93 CHECK(column<0>(pack(pack(0., 1.), pack(2., 3.), pack(4., 5.), pack(6., 7.))) == pack(0., 2., 4., 6.)); 94 } 95 96 TEST(broadcast) 97 { 98 CHECK(broadcast<8>(1) == pack(1, 1, 1, 1, 1, 1, 1, 1)); 99 CHECK(broadcast<8>(1, 2) == pack(1, 2, 1, 2, 1, 2, 1, 2)); 100 CHECK(broadcast<8>(1, 2, 3, 4) == pack(1, 2, 3, 4, 1, 2, 3, 4)); 101 CHECK(broadcast<8>(1, 2, 3, 4, 5, 6, 7, 8) == pack(1, 2, 3, 4, 5, 6, 7, 8)); 102 103 CHECK(broadcast<5>(3.f) == vec<f32, 5>{ 3, 3, 3, 3, 3 }); 104 CHECK(broadcast<6>(1.f, 2.f) == vec<f32, 6>{ 1, 2, 1, 2, 1, 2 }); 105 CHECK(broadcast<6>(1.f, 2.f, 3.f) == vec<f32, 6>{ 1, 2, 3, 1, 2, 3 }); 106 } 107 108 TEST(resize) 109 { 110 CHECK(resize<5>(make_vector(3.f)) == vec<f32, 5>{ 3, 3, 3, 3, 3 }); 111 CHECK(resize<6>(make_vector(1.f, 2.f)) == vec<f32, 6>{ 1, 2, 1, 2, 1, 2 }); 112 CHECK(resize<6>(make_vector(1.f, 2.f, 3.f)) == vec<f32, 6>{ 1, 2, 3, 1, 2, 3 }); 113 } 114 115 TEST(make_vector) 116 { 117 const signed char ch = -1; 118 CHECK(make_vector(1, 2, ch) == vec<i32, 3>{ 1, 2, -1 }); 119 const i64 v = -100; 120 CHECK(make_vector(1, 2, v) == vec<i64, 3>{ 1, 2, -100 }); 121 CHECK(make_vector<i64>(1, 2, ch) == vec<i64, 3>{ 1, 2, -1 }); 122 CHECK(make_vector<f32>(1, 2, ch) == vec<f32, 3>{ 1, 2, -1 }); 123 124 CHECK(make_vector(f64x2{ 1, 2 }, f64x2{ 10, 20 }) == 125 vec<vec<f64, 2>, 2>{ f64x2{ 1, 2 }, f64x2{ 10, 20 } }); 126 CHECK(make_vector(1.f, f32x2{ 10, 20 }) == vec<vec<f32, 2>, 2>{ f32x2{ 1, 1 }, f32x2{ 10, 20 } }); 127 } 128 129 TEST(zerovector) 130 { 131 CHECK(zerovector<f32, 3>() == f32x3{ 0, 0, 0 }); 132 // CHECK(zerovector<i16, 3>() == i16x3{ 0, 0, 0 }); // clang 3.9 (trunk) crashes here 133 CHECK(zerovector(f64x8{}) == f64x8{ 0, 0, 0, 0, 0, 0, 0, 0 }); 134 } 135 136 TEST(allonesvector) 137 { 138 CHECK(bitcast<u32>(special_constants<f32>::allones()) == 0xFFFFFFFFu); 139 CHECK(bitcast<u64>(special_constants<f64>::allones()) == 0xFFFFFFFFFFFFFFFFull); 140 141 CHECK(allonesvector<i16, 3>() == i16x3{ -1, -1, -1 }); 142 CHECK(allonesvector<u8, 3>() == u8x3{ 255, 255, 255 }); 143 } 144 145 TEST(transpose) 146 { 147 const auto sixteen = enumerate<float, 16>(); 148 CHECK(transpose<4>(sixteen) == vec<float, 16>(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15)); 149 } 150 151 TEST(odd_even) 152 { 153 CHECK(even(pack(0, 1, 2, 3, 4, 5, 6, 7)) == pack(0, 2, 4, 6)); 154 CHECK(odd(pack(0, 1, 2, 3, 4, 5, 6, 7)) == pack(1, 3, 5, 7)); 155 156 CHECK(even<2>(pack(0, 1, 2, 3, 4, 5, 6, 7)) == pack(0, 1, 4, 5)); 157 CHECK(odd<2>(pack(0, 1, 2, 3, 4, 5, 6, 7)) == pack(2, 3, 6, 7)); 158 } 159 160 TEST(low_high) 161 { 162 CHECK(low(vec<u8, 8>(1, 2, 3, 4, 5, 6, 7, 8)) == vec<u8, 4>(1, 2, 3, 4)); 163 CHECK(high(vec<u8, 8>(1, 2, 3, 4, 5, 6, 7, 8)) == vec<u8, 4>(5, 6, 7, 8)); 164 165 CHECK(low(vec<u8, 7>(1, 2, 3, 4, 5, 6, 7)) == vec<u8, 4>(1, 2, 3, 4)); 166 CHECK(high(vec<u8, 7>(1, 2, 3, 4, 5, 6, 7)) == vec<u8, 3>(5, 6, 7)); 167 168 CHECK(low(vec<u8, 6>(1, 2, 3, 4, 5, 6)) == vec<u8, 4>(1, 2, 3, 4)); 169 CHECK(high(vec<u8, 6>(1, 2, 3, 4, 5, 6)) == vec<u8, 2>(5, 6)); 170 171 CHECK(low(vec<u8, 5>(1, 2, 3, 4, 5)) == vec<u8, 4>(1, 2, 3, 4)); 172 CHECK(high(vec<u8, 5>(1, 2, 3, 4, 5)) == vec<u8, 1>(5)); 173 174 CHECK(low(vec<u8, 4>(1, 2, 3, 4)) == vec<u8, 2>(1, 2)); 175 CHECK(high(vec<u8, 4>(1, 2, 3, 4)) == vec<u8, 2>(3, 4)); 176 177 CHECK(low(vec<u8, 3>(1, 2, 3)) == vec<u8, 2>(1, 2)); 178 CHECK(high(vec<u8, 3>(1, 2, 3)) == vec<u8, 1>(3)); 179 180 CHECK(low(vec<u8, 2>(1, 2)) == vec<u8, 1>(1)); 181 CHECK(high(vec<u8, 2>(1, 2)) == vec<u8, 1>(2)); 182 } 183 TEST(enumerate) 184 { 185 CHECK(enumerate(vec_shape<int, 4>{}, 4) == vec{ 0, 4, 8, 12 }); 186 CHECK(enumerate(vec_shape<int, 8>{}, 3) == vec{ 0, 3, 6, 9, 12, 15, 18, 21 }); 187 CHECK(enumerate(vec_shape<int, 7>{}, 3) == vec{ 0, 3, 6, 9, 12, 15, 18 }); 188 } 189 190 191 TEST(test_basic) 192 { 193 // How to make a vector: 194 195 // * Use constructor 196 const vec<double, 4> first{ 1, 2.5, -infinity, 3.1415926 }; 197 CHECK(first == vec<double, 4>{ 1, 2.5, -infinity, 3.1415926 }); 198 199 // * Use make_vector function 200 const auto second = make_vector(-1, +1); 201 CHECK(second == vec<int, 2>{ -1, 1 }); 202 203 // * Convert from vector of other type: 204 const vec<int, 4> int_vector{ 10, 20, 30, 40 }; 205 const vec<double, 4> double_vector = cast<double>(int_vector); 206 CHECK(double_vector == vec<double, 4>{ 10, 20, 30, 40 }); 207 208 // * Concat two vectors: 209 const vec<int, 1> left_part{ 1 }; 210 const vec<int, 1> right_part{ 2 }; 211 const vec<int, 2> pair{ left_part, right_part }; 212 CHECK(pair == vec<int, 2>{ 1, 2 }); 213 214 // * Same, but using make_vector and concat: 215 const vec<int, 2> pair2 = concat(make_vector(10), make_vector(20)); 216 CHECK(pair2 == vec<int, 2>{ 10, 20 }); 217 218 // * Repeat vector multiple times: 219 const vec<short, 8> repeated = repeat<4>(make_vector<short>(0, -1)); 220 CHECK(repeated == vec<short, 8>{ 0, -1, 0, -1, 0, -1, 0, -1 }); 221 222 // * Use enumerate to generate sequence of numbers: 223 const vec<int, 8> eight = enumerate<int, 8>(); 224 CHECK(eight == vec<int, 8>{ 0, 1, 2, 3, 4, 5, 6, 7 }); 225 226 // * Vectors can be of any length... 227 const vec<int, 1> one{ 42 }; 228 const vec<int, 2> two = concat(one, make_vector(42)); 229 CHECK(two == vec<int, 2>{ 42, 42 }); 230 231 const vec<u8, 256> very_long_vector = repeat<64>(make_vector<u8>(1, 2, 4, 8)); 232 CHECK(slice<0, 17>(very_long_vector) == 233 vec<unsigned char, 17>{ 1, 2, 4, 8, 1, 2, 4, 8, 1, 2, 4, 8, 1, 2, 4, 8, 1 }); 234 235 // * ...really any: 236 using big_vector = vec<i16, 107>; 237 big_vector v107 = enumerate<i16, 107>(); 238 CHECK(hadd(v107) == static_cast<short>(5671)); 239 240 using color = vec<u8, 3>; 241 const color green = cast<u8>(make_vector(0.0, 1.0, 0.0) * 255); 242 CHECK(green == vec<unsigned char, 3>{ 0, 255, 0 }); 243 244 // Vectors support all standard operators: 245 const auto op1 = make_vector(0, 1, 10, 100); 246 const auto op2 = make_vector(20, 2, -2, 200); 247 const auto result = op1 * op2 - 4; 248 CHECK(result == vec<int, 4>{ -4, -2, -24, 19996 }); 249 250 // * Transform vector: 251 const vec<int, 8> numbers1 = enumerate<int, 8>(); 252 const vec<int, 8> numbers2 = enumerate<int, 8>() + 100; 253 CHECK(odd(numbers1) == vec<int, 4>{ 1, 3, 5, 7 }); 254 CHECK(even(numbers2) == vec<int, 4>{ 100, 102, 104, 106 }); 255 256 CHECK(subadd(pack(0, 1, 2, 3, 4, 5, 6, 7), pack(10, 10, 10, 10, 10, 10, 10, 10)) == 257 pack(-10, 11, -8, 13, -6, 15, -4, 17)); 258 CHECK(addsub(pack(0, 1, 2, 3, 4, 5, 6, 7), pack(10, 10, 10, 10, 10, 10, 10, 10)) == 259 pack(10, -9, 12, -7, 14, -5, 16, -3)); 260 261 CHECK(digitreverse4(pack(0.f, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)) == 262 pack(0.f, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15)); 263 264 CHECK(inrange(pack(1, 2, 3), 1, 3) == make_mask<int>(true, true, true)); 265 CHECK(inrange(pack(1, 2, 3), 1, 2) == make_mask<int>(true, true, false)); 266 CHECK(inrange(pack(1, 2, 3), 1, 1) == make_mask<int>(true, false, false)); 267 } 268 } // namespace CMT_ARCH_NAME 269 } // namespace kfr