shape.cpp (3219B)
1 /** 2 * KFR (https://www.kfrlib.com) 3 * Copyright (C) 2016-2023 Dan Cazarin 4 * See LICENSE.txt for details 5 */ 6 7 #include <kfr/base/shape.hpp> 8 9 namespace kfr 10 { 11 inline namespace CMT_ARCH_NAME 12 { 13 14 TEST(shape) 15 { 16 using internal_generic::increment_indices_return; 17 using internal_generic::null_index; 18 CHECK(size_of_shape(shape{ 4, 3 }) == 12); 19 CHECK(size_of_shape(shape{ 1 }) == 1); 20 CHECK(size_of_shape<1>(1) == 1); 21 shape<1> sh1 = 1; 22 sh1 = 2; 23 24 CHECK(internal_generic::strides_for_shape(shape{ 2, 3, 4 }) == shape{ 12, 4, 1 }); 25 26 CHECK(internal_generic::strides_for_shape(shape{ 2, 3, 4 }, 10) == shape{ 120, 40, 10 }); 27 28 CHECK(increment_indices_return(shape{ 0, 0, 0 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) == shape{ 0, 0, 1 }); 29 CHECK(increment_indices_return(shape{ 0, 0, 3 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) == shape{ 0, 1, 0 }); 30 CHECK(increment_indices_return(shape{ 0, 2, 0 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) == shape{ 0, 2, 1 }); 31 CHECK(increment_indices_return(shape{ 0, 2, 3 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) == shape{ 1, 0, 0 }); 32 CHECK(increment_indices_return(shape{ 1, 2, 3 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) == 33 shape{ null_index, null_index, null_index }); 34 35 CHECK(shape{ 3, 4, 5 }.to_flat(shape{ 0, 0, 0 }) == 0); 36 CHECK(shape{ 3, 4, 5 }.to_flat(shape{ 2, 3, 4 }) == 59); 37 38 CHECK(shape{ 3, 4, 5 }.from_flat(0) == shape{ 0, 0, 0 }); 39 CHECK(shape{ 3, 4, 5 }.from_flat(59) == shape{ 2, 3, 4 }); 40 41 CHECK(shape{ 3, 4, 5 }.transpose() == shape{ 5, 4, 3 }); 42 } 43 TEST(shape_broadcast) 44 { 45 using internal_generic::can_assign_from; 46 using internal_generic::common_shape; 47 using internal_generic::same_layout; 48 49 CHECK(common_shape(shape{ 1, 5 }, shape{ 5, 1 }) == shape{ 5, 5 }); 50 CHECK(common_shape(shape{ 5 }, shape{ 5, 1 }) == shape{ 5, 5 }); 51 CHECK(common_shape(shape{ 1, 1, 1 }, shape{ 2, 5, 1 }) == shape{ 2, 5, 1 }); 52 CHECK(common_shape(shape{ 1 }, shape{ 2, 5, 7 }) == shape{ 2, 5, 7 }); 53 54 CHECK(common_shape(shape{}, shape{ 0 }) == shape{ 0 }); 55 CHECK(common_shape(shape{}, shape{ 0, 0 }) == shape{ 0, 0 }); 56 CHECK(common_shape(shape{ 0 }, shape{ 0, 0 }) == shape{ 0, 0 }); 57 58 CHECK(common_shape<true>(shape{}, shape{ 0 }) == shape{ 0 }); 59 CHECK(common_shape<true>(shape{}, shape{ 0, 0 }) == shape{ 0, 0 }); 60 CHECK(common_shape<true>(shape{ 0 }, shape{ 0, 0 }) == shape{ 0, 0 }); 61 62 CHECK(can_assign_from(shape{ 1, 4 }, shape{ 1, 4 })); 63 CHECK(!can_assign_from(shape{ 1, 4 }, shape{ 4, 1 })); 64 CHECK(can_assign_from(shape{ 1, 4 }, shape{ 1, 1 })); 65 CHECK(can_assign_from(shape{ 1, 4 }, shape{ 1 })); 66 CHECK(can_assign_from(shape{ 1, 4 }, shape{})); 67 68 CHECK(same_layout(shape{ 2, 3, 4 }, shape{ 2, 3, 4 })); 69 CHECK(same_layout(shape{ 1, 2, 3, 4 }, shape{ 2, 3, 4 })); 70 CHECK(same_layout(shape{ 2, 3, 4 }, shape{ 2, 1, 1, 3, 4 })); 71 CHECK(same_layout(shape{ 2, 3, 4 }, shape{ 2, 3, 4, 1 })); 72 CHECK(same_layout(shape{ 2, 1, 3, 4 }, shape{ 1, 2, 3, 4, 1 })); 73 74 CHECK(!same_layout(shape{ 2, 1, 3, 4 }, shape{ 1, 2, 4, 3, 1 })); 75 CHECK(!same_layout(shape{ 2, 1, 3, 4 }, shape{ 1, 2, 4, 3, 0 })); 76 } 77 } // namespace CMT_ARCH_NAME 78 } // namespace kfr