commit 956b4899cbfee4f0e5ab2725f0c61c1f42738ce1
parent c7cd3a9bcc73a10bbe0ba767d7bb6e0804ef8821
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Tue, 13 Nov 2018 19:41:11 +0300
AVX512 support
Diffstat:
11 files changed, 315 insertions(+), 53 deletions(-)
diff --git a/include/kfr/base/abs.hpp b/include/kfr/base/abs.hpp
@@ -64,6 +64,17 @@ KFR_SINTRIN u16avx abs(const u16avx& x) { return x; }
KFR_SINTRIN u8avx abs(const u8avx& x) { return x; }
#endif
+#if defined CMT_ARCH_AVX512
+KFR_SINTRIN i64avx512 abs(const i64avx512& x) { return select(x >= 0, x, -x); }
+KFR_SINTRIN i32avx512 abs(const i32avx512& x) { return _mm512_abs_epi32(*x); }
+KFR_SINTRIN i16avx512 abs(const i16avx512& x) { return _mm512_abs_epi16(*x); }
+KFR_SINTRIN i8avx512 abs(const i8avx512& x) { return _mm512_abs_epi8(*x); }
+KFR_SINTRIN u64avx512 abs(const u64avx512& x) { return x; }
+KFR_SINTRIN u32avx512 abs(const u32avx512& x) { return x; }
+KFR_SINTRIN u16avx512 abs(const u16avx512& x) { return x; }
+KFR_SINTRIN u8avx512 abs(const u8avx512& x) { return x; }
+#endif
+
KFR_HANDLE_ALL_SIZES_NOT_F_1(abs)
#elif defined CMT_ARCH_NEON && defined KFR_NATIVE_INTRINSICS
@@ -108,7 +119,7 @@ KFR_SINTRIN vec<T, N> abs(const vec<T, N>& x)
}
#endif
KFR_I_CONVERTER(abs)
-}
+} // namespace intrinsics
KFR_I_FN(abs)
/**
@@ -128,4 +139,4 @@ KFR_INTRIN internal::expression_function<fn::abs, E1> abs(E1&& x)
{
return { fn::abs(), std::forward<E1>(x) };
}
-}
+} // namespace kfr
diff --git a/include/kfr/base/function.hpp b/include/kfr/base/function.hpp
@@ -78,6 +78,17 @@ using u16avx = vec<u16, 16>;
using u32avx = vec<u32, 8>;
using u64avx = vec<u64, 4>;
+using f32avx512 = vec<f32, 16>;
+using f64avx512 = vec<f64, 8>;
+using i8avx512 = vec<i8, 64>;
+using i16avx512 = vec<i16, 32>;
+using i32avx512 = vec<i32, 16>;
+using i64avx512 = vec<i64, 8>;
+using u8avx512 = vec<u8, 64>;
+using u16avx512 = vec<u16, 32>;
+using u32avx512 = vec<u32, 16>;
+using u64avx512 = vec<u64, 8>;
+
#else
using f32neon = vec<f32, 4>;
using f64neon = vec<f64, 2>;
@@ -252,6 +263,6 @@ inline T to_scalar(const vec<T, 1>& value)
{
return value[0];
}
-}
-}
+} // namespace intrinsics
+} // namespace kfr
CMT_PRAGMA_GNU(GCC diagnostic pop)
diff --git a/include/kfr/base/logical.hpp b/include/kfr/base/logical.hpp
@@ -53,6 +53,7 @@ struct bitmask
#if defined CMT_ARCH_SSE41
+// horizontal OR
KFR_SINTRIN bool bittestany(const u8sse& x) { return !_mm_testz_si128(*x, *x); }
KFR_SINTRIN bool bittestany(const u16sse& x) { return !_mm_testz_si128(*x, *x); }
KFR_SINTRIN bool bittestany(const u32sse& x) { return !_mm_testz_si128(*x, *x); }
@@ -62,6 +63,7 @@ KFR_SINTRIN bool bittestany(const i16sse& x) { return !_mm_testz_si128(*x, *x);
KFR_SINTRIN bool bittestany(const i32sse& x) { return !_mm_testz_si128(*x, *x); }
KFR_SINTRIN bool bittestany(const i64sse& x) { return !_mm_testz_si128(*x, *x); }
+// horizontal AND
KFR_SINTRIN bool bittestall(const u8sse& x) { return _mm_testc_si128(*x, *allonesvector(x)); }
KFR_SINTRIN bool bittestall(const u16sse& x) { return _mm_testc_si128(*x, *allonesvector(x)); }
KFR_SINTRIN bool bittestall(const u32sse& x) { return _mm_testc_si128(*x, *allonesvector(x)); }
@@ -73,17 +75,13 @@ KFR_SINTRIN bool bittestall(const i64sse& x) { return _mm_testc_si128(*x, *allon
#endif
#if defined CMT_ARCH_AVX
+// horizontal OR
KFR_SINTRIN bool bittestany(const f32sse& x) { return !_mm_testz_ps(*x, *x); }
KFR_SINTRIN bool bittestany(const f64sse& x) { return !_mm_testz_pd(*x, *x); }
-KFR_SINTRIN bool bittestall(const f32sse& x) { return _mm_testc_ps(*x, *allonesvector(x)); }
-KFR_SINTRIN bool bittestall(const f64sse& x) { return _mm_testc_pd(*x, *allonesvector(x)); }
KFR_SINTRIN bool bittestany(const f32avx& x) { return !_mm256_testz_ps(*x, *x); }
KFR_SINTRIN bool bittestany(const f64avx& x) { return !_mm256_testz_pd(*x, *x); }
-KFR_SINTRIN bool bittestnall(const f32avx& x) { return _mm256_testc_ps(*x, *allonesvector(x)); }
-KFR_SINTRIN bool bittestnall(const f64avx& x) { return _mm256_testc_pd(*x, *allonesvector(x)); }
-
KFR_SINTRIN bool bittestany(const u8avx& x) { return !_mm256_testz_si256(*x, *x); }
KFR_SINTRIN bool bittestany(const u16avx& x) { return !_mm256_testz_si256(*x, *x); }
KFR_SINTRIN bool bittestany(const u32avx& x) { return !_mm256_testz_si256(*x, *x); }
@@ -93,6 +91,13 @@ KFR_SINTRIN bool bittestany(const i16avx& x) { return !_mm256_testz_si256(*x, *x
KFR_SINTRIN bool bittestany(const i32avx& x) { return !_mm256_testz_si256(*x, *x); }
KFR_SINTRIN bool bittestany(const i64avx& x) { return !_mm256_testz_si256(*x, *x); }
+// horizontal AND
+KFR_SINTRIN bool bittestall(const f32sse& x) { return _mm_testc_ps(*x, *allonesvector(x)); }
+KFR_SINTRIN bool bittestall(const f64sse& x) { return _mm_testc_pd(*x, *allonesvector(x)); }
+
+KFR_SINTRIN bool bittestall(const f32avx& x) { return _mm256_testc_ps(*x, *allonesvector(x)); }
+KFR_SINTRIN bool bittestall(const f64avx& x) { return _mm256_testc_pd(*x, *allonesvector(x)); }
+
KFR_SINTRIN bool bittestall(const u8avx& x) { return _mm256_testc_si256(*x, *allonesvector(x)); }
KFR_SINTRIN bool bittestall(const u16avx& x) { return _mm256_testc_si256(*x, *allonesvector(x)); }
KFR_SINTRIN bool bittestall(const u32avx& x) { return _mm256_testc_si256(*x, *allonesvector(x)); }
@@ -101,6 +106,34 @@ KFR_SINTRIN bool bittestall(const i8avx& x) { return _mm256_testc_si256(*x, *all
KFR_SINTRIN bool bittestall(const i16avx& x) { return _mm256_testc_si256(*x, *allonesvector(x)); }
KFR_SINTRIN bool bittestall(const i32avx& x) { return _mm256_testc_si256(*x, *allonesvector(x)); }
KFR_SINTRIN bool bittestall(const i64avx& x) { return _mm256_testc_si256(*x, *allonesvector(x)); }
+
+#if defined CMT_ARCH_AVX512
+// horizontal OR
+KFR_SINTRIN bool bittestany(const f32avx512& x) { return !_mm512_test_epi32_mask(*x, *x); }
+KFR_SINTRIN bool bittestany(const f64avx512& x) { return !_mm512_test_epi64_mask(*x, *x); }
+KFR_SINTRIN bool bittestany(const u8avx512& x) { return !_mm512_test_epi8_mask(*x, *x); }
+KFR_SINTRIN bool bittestany(const u16avx512& x) { return !_mm512_test_epi16_mask(*x, *x); }
+KFR_SINTRIN bool bittestany(const u32avx512& x) { return !_mm512_test_epi32_mask(*x, *x); }
+KFR_SINTRIN bool bittestany(const u64avx512& x) { return !_mm512_test_epi64_mask(*x, *x); }
+KFR_SINTRIN bool bittestany(const i8avx512& x) { return !_mm512_test_epi8_mask(*x, *x); }
+KFR_SINTRIN bool bittestany(const i16avx512& x) { return !_mm512_test_epi16_mask(*x, *x); }
+KFR_SINTRIN bool bittestany(const i32avx512& x) { return !_mm512_test_epi32_mask(*x, *x); }
+KFR_SINTRIN bool bittestany(const i64avx512& x) { return !_mm512_test_epi64_mask(*x, *x); }
+
+// horizontal AND
+KFR_SINTRIN bool bittestall(const f32avx512& x) { return ~bittestany(~x); }
+KFR_SINTRIN bool bittestall(const f64avx512& x) { return ~bittestany(~x); }
+KFR_SINTRIN bool bittestall(const u8avx512& x) { return ~bittestany(~x); }
+KFR_SINTRIN bool bittestall(const u16avx512& x) { return ~bittestany(~x); }
+KFR_SINTRIN bool bittestall(const u32avx512& x) { return ~bittestany(~x); }
+KFR_SINTRIN bool bittestall(const u64avx512& x) { return ~bittestany(~x); }
+KFR_SINTRIN bool bittestall(const i8avx512& x) { return ~bittestany(~x); }
+KFR_SINTRIN bool bittestall(const i16avx512& x) { return ~bittestany(~x); }
+KFR_SINTRIN bool bittestall(const i32avx512& x) { return ~bittestany(~x); }
+KFR_SINTRIN bool bittestall(const i64avx512& x) { return ~bittestany(~x); }
+
+#endif
+
#elif defined CMT_ARCH_SSE41
KFR_SINTRIN bool bittestany(const f32sse& x) { return !_mm_testz_si128(*bitcast<u8>(x), *bitcast<u8>(x)); }
KFR_SINTRIN bool bittestany(const f64sse& x) { return !_mm_testz_si128(*bitcast<u8>(x), *bitcast<u8>(x)); }
@@ -249,7 +282,7 @@ KFR_SINTRIN bool bittestall(const vec<T, N>& x, const vec<T, N>& y)
return !bittestany(~x & y);
}
#endif
-}
+} // namespace intrinsics
/**
* @brief Returns x[0] && x[1] && ... && x[N-1]
@@ -268,4 +301,4 @@ KFR_SINTRIN bool any(const mask<T, N>& x)
{
return intrinsics::bittestany(x.asvec());
}
-}
+} // namespace kfr
diff --git a/include/kfr/base/min_max.hpp b/include/kfr/base/min_max.hpp
@@ -42,15 +42,11 @@ KFR_SINTRIN f32sse min(const f32sse& x, const f32sse& y) { return _mm_min_ps(*x,
KFR_SINTRIN f64sse min(const f64sse& x, const f64sse& y) { return _mm_min_pd(*x, *y); }
KFR_SINTRIN u8sse min(const u8sse& x, const u8sse& y) { return _mm_min_epu8(*x, *y); }
KFR_SINTRIN i16sse min(const i16sse& x, const i16sse& y) { return _mm_min_epi16(*x, *y); }
-KFR_SINTRIN i64sse min(const i64sse& x, const i64sse& y) { return select(x < y, x, y); }
-KFR_SINTRIN u64sse min(const u64sse& x, const u64sse& y) { return select(x < y, x, y); }
KFR_SINTRIN f32sse max(const f32sse& x, const f32sse& y) { return _mm_max_ps(*x, *y); }
KFR_SINTRIN f64sse max(const f64sse& x, const f64sse& y) { return _mm_max_pd(*x, *y); }
KFR_SINTRIN u8sse max(const u8sse& x, const u8sse& y) { return _mm_max_epu8(*x, *y); }
KFR_SINTRIN i16sse max(const i16sse& x, const i16sse& y) { return _mm_max_epi16(*x, *y); }
-KFR_SINTRIN i64sse max(const i64sse& x, const i64sse& y) { return select(x > y, x, y); }
-KFR_SINTRIN u64sse max(const u64sse& x, const u64sse& y) { return select(x > y, x, y); }
#if defined CMT_ARCH_AVX2
KFR_SINTRIN u8avx min(const u8avx& x, const u8avx& y) { return _mm256_min_epu8(*x, *y); }
@@ -67,6 +63,35 @@ KFR_SINTRIN u16avx max(const u16avx& x, const u16avx& y) { return _mm256_max_epu
KFR_SINTRIN i32avx max(const i32avx& x, const i32avx& y) { return _mm256_max_epi32(*x, *y); }
KFR_SINTRIN u32avx max(const u32avx& x, const u32avx& y) { return _mm256_max_epu32(*x, *y); }
+#endif
+
+#if defined CMT_ARCH_AVX512
+KFR_SINTRIN u8avx512 min(const u8avx512& x, const u8avx512& y) { return _mm512_min_epu8(*x, *y); }
+KFR_SINTRIN i16avx512 min(const i16avx512& x, const i16avx512& y) { return _mm512_min_epi16(*x, *y); }
+KFR_SINTRIN i8avx512 min(const i8avx512& x, const i8avx512& y) { return _mm512_min_epi8(*x, *y); }
+KFR_SINTRIN u16avx512 min(const u16avx512& x, const u16avx512& y) { return _mm512_min_epu16(*x, *y); }
+KFR_SINTRIN i32avx512 min(const i32avx512& x, const i32avx512& y) { return _mm512_min_epi32(*x, *y); }
+KFR_SINTRIN u32avx512 min(const u32avx512& x, const u32avx512& y) { return _mm512_min_epu32(*x, *y); }
+KFR_SINTRIN u8avx512 max(const u8avx512& x, const u8avx512& y) { return _mm512_max_epu8(*x, *y); }
+KFR_SINTRIN i16avx512 max(const i16avx512& x, const i16avx512& y) { return _mm512_max_epi16(*x, *y); }
+KFR_SINTRIN i8avx512 max(const i8avx512& x, const i8avx512& y) { return _mm512_max_epi8(*x, *y); }
+KFR_SINTRIN u16avx512 max(const u16avx512& x, const u16avx512& y) { return _mm512_max_epu16(*x, *y); }
+KFR_SINTRIN i32avx512 max(const i32avx512& x, const i32avx512& y) { return _mm512_max_epi32(*x, *y); }
+KFR_SINTRIN u32avx512 max(const u32avx512& x, const u32avx512& y) { return _mm512_max_epu32(*x, *y); }
+KFR_SINTRIN i64avx512 min(const i64avx512& x, const i64avx512& y) { return _mm512_min_epi64(*x, *y); }
+KFR_SINTRIN u64avx512 min(const u64avx512& x, const u64avx512& y) { return _mm512_min_epu64(*x, *y); }
+KFR_SINTRIN i64avx512 max(const i64avx512& x, const i64avx512& y) { return _mm512_max_epi64(*x, *y); }
+KFR_SINTRIN u64avx512 max(const u64avx512& x, const u64avx512& y) { return _mm512_max_epu64(*x, *y); }
+
+KFR_SINTRIN i64avx min(const i64avx& x, const i64avx& y) { return _mm256_min_epi64(*x, *y); }
+KFR_SINTRIN u64avx min(const u64avx& x, const u64avx& y) { return _mm256_min_epu64(*x, *y); }
+KFR_SINTRIN i64avx max(const i64avx& x, const i64avx& y) { return _mm256_max_epi64(*x, *y); }
+KFR_SINTRIN u64avx max(const u64avx& x, const u64avx& y) { return _mm256_max_epu64(*x, *y); }
+#else
+KFR_SINTRIN i64sse min(const i64sse& x, const i64sse& y) { return select(x < y, x, y); }
+KFR_SINTRIN u64sse min(const u64sse& x, const u64sse& y) { return select(x < y, x, y); }
+KFR_SINTRIN i64sse max(const i64sse& x, const i64sse& y) { return select(x > y, x, y); }
+KFR_SINTRIN u64sse max(const u64sse& x, const u64sse& y) { return select(x > y, x, y); }
KFR_SINTRIN i64avx min(const i64avx& x, const i64avx& y) { return select(x < y, x, y); }
KFR_SINTRIN u64avx min(const u64avx& x, const u64avx& y) { return select(x < y, x, y); }
KFR_SINTRIN i64avx max(const i64avx& x, const i64avx& y) { return select(x > y, x, y); }
@@ -193,7 +218,7 @@ KFR_I_CONVERTER(min)
KFR_I_CONVERTER(max)
KFR_I_CONVERTER(absmin)
KFR_I_CONVERTER(absmax)
-}
+} // namespace intrinsics
KFR_I_FN(min)
KFR_I_FN(max)
KFR_I_FN(absmin)
@@ -274,4 +299,4 @@ KFR_INTRIN internal::expression_function<fn::absmax, E1, E2> absmax(E1&& x, E2&&
{
return { fn::absmax(), std::forward<E1>(x), std::forward<E2>(y) };
}
-}
+} // namespace kfr
diff --git a/include/kfr/base/round.hpp b/include/kfr/base/round.hpp
@@ -34,25 +34,32 @@ namespace kfr
namespace intrinsics
{
-#define KFR_mm_trunc_ps(V) _mm_round_ps((V), _MM_FROUND_TRUNC)
-#define KFR_mm_roundnearest_ps(V) _mm_round_ps((V), _MM_FROUND_NINT)
-#define KFR_mm_trunc_pd(V) _mm_round_pd((V), _MM_FROUND_TRUNC)
-#define KFR_mm_roundnearest_pd(V) _mm_round_pd((V), _MM_FROUND_NINT)
-
-#define KFR_mm_trunc_ss(V) _mm_round_ss(_mm_setzero_ps(), (V), _MM_FROUND_TRUNC)
-#define KFR_mm_roundnearest_ss(V) _mm_round_ss(_mm_setzero_ps(), (V), _MM_FROUND_NINT)
-#define KFR_mm_trunc_sd(V) _mm_round_sd(_mm_setzero_pd(), (V), _MM_FROUND_TRUNC)
-#define KFR_mm_roundnearest_sd(V) _mm_round_sd(_mm_setzero_pd(), (V), _MM_FROUND_NINT)
+#define KFR_mm_trunc_ps(V) _mm_round_ps((V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)
+#define KFR_mm_roundnearest_ps(V) _mm_round_ps((V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
+#define KFR_mm_trunc_pd(V) _mm_round_pd((V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)
+#define KFR_mm_roundnearest_pd(V) _mm_round_pd((V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
+
+#define KFR_mm_trunc_ss(V) _mm_round_ss(_mm_setzero_ps(), (V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)
+#define KFR_mm_roundnearest_ss(V) \
+ _mm_round_ss(_mm_setzero_ps(), (V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
+#define KFR_mm_trunc_sd(V) _mm_round_sd(_mm_setzero_pd(), (V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)
+#define KFR_mm_roundnearest_sd(V) \
+ _mm_round_sd(_mm_setzero_pd(), (V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
#define KFR_mm_floor_ss(V) _mm_floor_ss(_mm_setzero_ps(), (V))
#define KFR_mm_floor_sd(V) _mm_floor_sd(_mm_setzero_pd(), (V))
#define KFR_mm_ceil_ss(V) _mm_ceil_ss(_mm_setzero_ps(), (V))
#define KFR_mm_ceil_sd(V) _mm_ceil_sd(_mm_setzero_pd(), (V))
-#define KFR_mm256_trunc_ps(V) _mm256_round_ps((V), _MM_FROUND_TRUNC)
-#define KFR_mm256_roundnearest_ps(V) _mm256_round_ps((V), _MM_FROUND_NINT)
-#define KFR_mm256_trunc_pd(V) _mm256_round_pd((V), _MM_FROUND_TRUNC)
-#define KFR_mm256_roundnearest_pd(V) _mm256_round_pd((V), _MM_FROUND_NINT)
+#define KFR_mm256_trunc_ps(V) _mm256_round_ps((V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)
+#define KFR_mm256_roundnearest_ps(V) _mm256_round_ps((V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
+#define KFR_mm256_trunc_pd(V) _mm256_round_pd((V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)
+#define KFR_mm256_roundnearest_pd(V) _mm256_round_pd((V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
+
+#define KFR_mm512_trunc_ps(V) _mm512_roundscale_ps((V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)
+#define KFR_mm512_roundnearest_ps(V) _mm512_roundscale_ps((V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
+#define KFR_mm512_trunc_pd(V) _mm512_roundscale_pd((V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)
+#define KFR_mm512_roundnearest_pd(V) _mm512_roundscale_pd((V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
#if defined CMT_ARCH_SSE41 && defined KFR_NATIVE_INTRINSICS
@@ -81,6 +88,20 @@ KFR_SINTRIN f32avx fract(const f32avx& x) { return x - floor(x); }
KFR_SINTRIN f64avx fract(const f64avx& x) { return x - floor(x); }
#endif
+#if defined CMT_ARCH_AVX512
+
+KFR_SINTRIN f32avx512 floor(const f32avx512& value) { return _mm512_floor_ps(*value); }
+KFR_SINTRIN f32avx512 ceil(const f32avx512& value) { return _mm512_ceil_ps(*value); }
+KFR_SINTRIN f32avx512 trunc(const f32avx512& value) { return KFR_mm512_trunc_ps(*value); }
+KFR_SINTRIN f32avx512 round(const f32avx512& value) { return KFR_mm512_roundnearest_ps(*value); }
+KFR_SINTRIN f64avx512 floor(const f64avx512& value) { return _mm512_floor_pd(*value); }
+KFR_SINTRIN f64avx512 ceil(const f64avx512& value) { return _mm512_ceil_pd(*value); }
+KFR_SINTRIN f64avx512 trunc(const f64avx512& value) { return KFR_mm512_trunc_pd(*value); }
+KFR_SINTRIN f64avx512 round(const f64avx512& value) { return KFR_mm512_roundnearest_pd(*value); }
+KFR_SINTRIN f32avx512 fract(const f32avx512& x) { return x - floor(x); }
+KFR_SINTRIN f64avx512 fract(const f64avx512& x) { return x - floor(x); }
+#endif
+
KFR_HANDLE_ALL_SIZES_F_1(floor)
KFR_HANDLE_ALL_SIZES_F_1(ceil)
KFR_HANDLE_ALL_SIZES_F_1(round)
@@ -203,7 +224,7 @@ KFR_I_CONVERTER(ifloor)
KFR_I_CONVERTER(iceil)
KFR_I_CONVERTER(iround)
KFR_I_CONVERTER(itrunc)
-}
+} // namespace intrinsics
KFR_I_FN(floor)
KFR_I_FN(ceil)
KFR_I_FN(round)
@@ -339,7 +360,7 @@ CMT_INLINE vec<T, N> rem(const vec<T, N>& x, const vec<T, N>& y)
{
return fmod(x, y);
}
-}
+} // namespace kfr
#undef KFR_mm_trunc_ps
#undef KFR_mm_roundnearest_ps
diff --git a/include/kfr/base/saturation.hpp b/include/kfr/base/saturation.hpp
@@ -40,10 +40,10 @@ KFR_SINTRIN vec<T, N> saturated_signed_add(const vec<T, N>& a, const vec<T, N>&
{
using UT = utype<T>;
constexpr size_t shift = typebits<UT>::bits - 1;
- vec<UT, N> aa = bitcast<UT>(a);
- vec<UT, N> bb = bitcast<UT>(b);
- const vec<UT, N> sum = aa + bb;
- aa = (aa >> shift) + static_cast<UT>(std::numeric_limits<T>::max());
+ vec<UT, N> aa = bitcast<UT>(a);
+ vec<UT, N> bb = bitcast<UT>(b);
+ const vec<UT, N> sum = aa + bb;
+ aa = (aa >> shift) + static_cast<UT>(std::numeric_limits<T>::max());
return select(bitcast<T>((aa ^ bb) | ~(bb ^ sum)) >= 0, a, bitcast<T>(sum));
}
@@ -52,10 +52,10 @@ KFR_SINTRIN vec<T, N> saturated_signed_sub(const vec<T, N>& a, const vec<T, N>&
{
using UT = utype<T>;
constexpr size_t shift = typebits<UT>::bits - 1;
- vec<UT, N> aa = bitcast<UT>(a);
- vec<UT, N> bb = bitcast<UT>(b);
- const vec<UT, N> diff = aa - bb;
- aa = (aa >> shift) + static_cast<UT>(std::numeric_limits<T>::max());
+ vec<UT, N> aa = bitcast<UT>(a);
+ vec<UT, N> bb = bitcast<UT>(b);
+ const vec<UT, N> diff = aa - bb;
+ aa = (aa >> shift) + static_cast<UT>(std::numeric_limits<T>::max());
return select(bitcast<T>((aa ^ bb) & (aa ^ diff)) < 0, a, bitcast<T>(diff));
}
@@ -105,6 +105,17 @@ KFR_SINTRIN u16avx satsub(const u16avx& x, const u16avx& y) { return _mm256_subs
KFR_SINTRIN i16avx satsub(const i16avx& x, const i16avx& y) { return _mm256_subs_epi16(*x, *y); }
#endif
+#if defined CMT_ARCH_AVX512
+KFR_SINTRIN u8avx512 satadd(const u8avx512& x, const u8avx512& y) { return _mm512_adds_epu8(*x, *y); }
+KFR_SINTRIN i8avx512 satadd(const i8avx512& x, const i8avx512& y) { return _mm512_adds_epi8(*x, *y); }
+KFR_SINTRIN u16avx512 satadd(const u16avx512& x, const u16avx512& y) { return _mm512_adds_epu16(*x, *y); }
+KFR_SINTRIN i16avx512 satadd(const i16avx512& x, const i16avx512& y) { return _mm512_adds_epi16(*x, *y); }
+KFR_SINTRIN u8avx512 satsub(const u8avx512& x, const u8avx512& y) { return _mm512_subs_epu8(*x, *y); }
+KFR_SINTRIN i8avx512 satsub(const i8avx512& x, const i8avx512& y) { return _mm512_subs_epi8(*x, *y); }
+KFR_SINTRIN u16avx512 satsub(const u16avx512& x, const u16avx512& y) { return _mm512_subs_epu16(*x, *y); }
+KFR_SINTRIN i16avx512 satsub(const i16avx512& x, const i16avx512& y) { return _mm512_subs_epi16(*x, *y); }
+#endif
+
KFR_HANDLE_ALL_SIZES_2(satadd)
KFR_HANDLE_ALL_SIZES_2(satsub)
@@ -156,7 +167,7 @@ KFR_SINTRIN vec<T, N> satsub(const vec<T, N>& a, const vec<T, N>& b)
#endif
KFR_I_CONVERTER(satadd)
KFR_I_CONVERTER(satsub)
-}
+} // namespace intrinsics
KFR_I_FN(satadd)
KFR_I_FN(satsub)
@@ -189,4 +200,4 @@ KFR_INTRIN internal::expression_function<fn::satsub, E1, E2> satsub(E1&& x, E2&&
{
return { fn::satsub(), std::forward<E1>(x), std::forward<E2>(y) };
}
-}
+} // namespace kfr
diff --git a/include/kfr/base/select.hpp b/include/kfr/base/select.hpp
@@ -121,6 +121,49 @@ KFR_SINTRIN i64avx select(const maskfor<i64avx>& m, const i64avx& x, const i64av
}
#endif
+#if defined CMT_ARCH_AVX512
+KFR_SINTRIN f64avx512 select(const maskfor<f64avx512>& m, const f64avx512& x, const f64avx512& y)
+{
+ return _mm512_mask_blend_pd(_mm512_test_epi64_mask(*m, *m), *y, *x);
+}
+KFR_SINTRIN f32avx512 select(const maskfor<f32avx512>& m, const f32avx512& x, const f32avx512& y)
+{
+ return _mm512_mask_blend_ps(_mm512_test_epi32_mask(*m, *m), *y, *x);
+}
+KFR_SINTRIN u8avx512 select(const maskfor<u8avx512>& m, const u8avx512& x, const u8avx512& y)
+{
+ return _mm512_mask_blend_epi8(_mm512_test_epi8_mask(*m, *m), *y, *x);
+}
+KFR_SINTRIN u16avx512 select(const maskfor<u16avx512>& m, const u16avx512& x, const u16avx512& y)
+{
+ return _mm512_mask_blend_epi8(_mm512_test_epi16_mask(*m, *m), *y, *x);
+}
+KFR_SINTRIN u32avx512 select(const maskfor<u32avx512>& m, const u32avx512& x, const u32avx512& y)
+{
+ return _mm512_mask_blend_epi8(_mm512_test_epi32_mask(*m, *m), *y, *x);
+}
+KFR_SINTRIN u64avx512 select(const maskfor<u64avx512>& m, const u64avx512& x, const u64avx512& y)
+{
+ return _mm512_mask_blend_epi8(_mm512_test_epi64_mask(*m, *m), *y, *x);
+}
+KFR_SINTRIN i8avx512 select(const maskfor<i8avx512>& m, const i8avx512& x, const i8avx512& y)
+{
+ return _mm512_mask_blend_epi8(_mm512_test_epi8_mask(*m, *m), *y, *x);
+}
+KFR_SINTRIN i16avx512 select(const maskfor<i16avx512>& m, const i16avx512& x, const i16avx512& y)
+{
+ return _mm512_mask_blend_epi8(_mm512_test_epi16_mask(*m, *m), *y, *x);
+}
+KFR_SINTRIN i32avx512 select(const maskfor<i32avx512>& m, const i32avx512& x, const i32avx512& y)
+{
+ return _mm512_mask_blend_epi8(_mm512_test_epi32_mask(*m, *m), *y, *x);
+}
+KFR_SINTRIN i64avx512 select(const maskfor<i64avx512>& m, const i64avx512& x, const i64avx512& y)
+{
+ return _mm512_mask_blend_epi8(_mm512_test_epi64_mask(*m, *m), *y, *x);
+}
+#endif
+
template <typename T, size_t N, KFR_ENABLE_IF(N < platform<T>::vector_width)>
KFR_SINTRIN vec<T, N> select(const mask<T, N>& a, const vec<T, N>& b, const vec<T, N>& c)
{
@@ -211,7 +254,7 @@ KFR_SINTRIN vec<T, N> select(const vec<T, N>& m, const vec<T, N>& x, const vec<T
{
return select(m.asmask(), x, y);
}
-}
+} // namespace intrinsics
KFR_I_FN(select)
/**
@@ -238,4 +281,4 @@ KFR_INTRIN internal::expression_function<fn::select, E1, E2, E3> select(E1&& m,
{
return { fn::select(), std::forward<E1>(m), std::forward<E2>(x), std::forward<E3>(y) };
}
-}
+} // namespace kfr
diff --git a/include/kfr/base/simd_intrin.hpp b/include/kfr/base/simd_intrin.hpp
@@ -91,6 +91,19 @@ KFR_SIMD_SPEC_TYPE(f32, 8, __m256);
KFR_SIMD_SPEC_TYPE(f64, 4, __m256d);
#endif
+#ifdef CMT_ARCH_AVX512
+KFR_SIMD_SPEC_TYPE(u8, 64, __m512i);
+KFR_SIMD_SPEC_TYPE(u16, 32, __m512i);
+KFR_SIMD_SPEC_TYPE(u32, 16, __m512i);
+KFR_SIMD_SPEC_TYPE(u64, 8, __m512i);
+KFR_SIMD_SPEC_TYPE(i8, 64, __m512i);
+KFR_SIMD_SPEC_TYPE(i16, 32, __m512i);
+KFR_SIMD_SPEC_TYPE(i32, 16, __m512i);
+KFR_SIMD_SPEC_TYPE(i64, 8, __m512i);
+KFR_SIMD_SPEC_TYPE(f32, 16, __m512);
+KFR_SIMD_SPEC_TYPE(f64, 8, __m512d);
+#endif
+
#ifdef CMT_ARCH_NEON
KFR_SIMD_SPEC_TYPE(u8, 16, uint8x16_t);
KFR_SIMD_SPEC_TYPE(u16, 8, uint16x8_t);
@@ -118,17 +131,17 @@ struct raw_bytes
#define KFR_C_CYCLE(...) \
for (size_t i = 0; i < N; i++) \
- vs[i] = __VA_ARGS__
+ vs[i] = __VA_ARGS__
#define KFR_R_CYCLE(...) \
vec<T, N> result; \
- for (size_t i = 0; i < N; i++) \
+ for (size_t i = 0; i < N; i++) \
result.vs[i] = __VA_ARGS__; \
return result
#define KFR_B_CYCLE(...) \
vec<T, N> result; \
- for (size_t i = 0; i < N; i++) \
+ for (size_t i = 0; i < N; i++) \
result.vs[i] = (__VA_ARGS__) ? constants<value_type>::allones() : value_type(0); \
return result
@@ -282,13 +295,13 @@ struct alignas(const_min(platform<>::maximum_vector_alignment, sizeof(T) * next_
KFR_I_CE vec& operator++() noexcept { return *this = *this + vec(1); }
KFR_I_CE vec& operator--() noexcept { return *this = *this - vec(1); }
- KFR_I_CE vec operator++(int)noexcept
+ KFR_I_CE vec operator++(int) noexcept
{
const vec z = *this;
++*this;
return z;
}
- KFR_I_CE vec operator--(int)noexcept
+ KFR_I_CE vec operator--(int) noexcept
{
const vec z = *this;
--*this;
@@ -321,6 +334,7 @@ struct alignas(const_min(platform<>::maximum_vector_alignment, sizeof(T) * next_
const vec& flatten() const noexcept { return *this; }
simd_type operator*() const noexcept { return simd; }
simd_type& operator*() noexcept { return simd; }
+
protected:
template <typename, size_t>
friend struct vec;
@@ -366,13 +380,13 @@ CMT_INLINE vec<T, csum<size_t, N1, N2, Sizes...>()> concat_impl(const vec<T, N1>
{
return concat_impl(concat_impl(x, y), args...);
}
-}
+} // namespace internal
template <typename T, size_t... Ns>
constexpr inline vec<T, csum<size_t, Ns...>()> concat(const vec<T, Ns>&... vs) noexcept
{
return internal::concat_impl(vs...);
}
-}
+} // namespace kfr
CMT_PRAGMA_MSVC(warning(pop))
diff --git a/include/kfr/base/simd_x86.hpp b/include/kfr/base/simd_x86.hpp
@@ -181,4 +181,92 @@ KFR_I_CE CMT_INLINE vec<f64, 4> vec<f64, 4>::operator^(const vec<f64, 4>& y) con
#endif // CMT_ARCH_AVX
-} // namespace kf
+#ifdef CMT_ARCH_AVX512
+
+template <>
+KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator+(const vec<f32, 16>& y) const noexcept
+{
+ return _mm512_add_ps(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator-(const vec<f32, 16>& y) const noexcept
+{
+ return _mm512_sub_ps(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator*(const vec<f32, 16>& y) const noexcept
+{
+ return _mm512_mul_ps(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator/(const vec<f32, 16>& y) const noexcept
+{
+ return _mm512_div_ps(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator&(const vec<f32, 16>& y) const noexcept
+{
+ return _mm512_and_ps(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator|(const vec<f32, 16>& y) const noexcept
+{
+ return _mm512_or_ps(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator^(const vec<f32, 16>& y) const noexcept
+{
+ return _mm512_xor_ps(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator+(const vec<f64, 8>& y) const noexcept
+{
+ return _mm512_add_pd(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator-(const vec<f64, 8>& y) const noexcept
+{
+ return _mm512_sub_pd(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator*(const vec<f64, 8>& y) const noexcept
+{
+ return _mm512_mul_pd(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator/(const vec<f64, 8>& y) const noexcept
+{
+ return _mm512_div_pd(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator&(const vec<f64, 8>& y) const noexcept
+{
+ return _mm512_and_pd(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator|(const vec<f64, 8>& y) const noexcept
+{
+ return _mm512_or_pd(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator^(const vec<f64, 8>& y) const noexcept
+{
+ return _mm512_xor_pd(simd, y.simd);
+}
+
+#endif // CMT_ARCH_AVX
+
+} // namespace kfr
diff --git a/include/kfr/base/sqrt.hpp b/include/kfr/base/sqrt.hpp
@@ -48,6 +48,11 @@ KFR_SINTRIN f32avx sqrt(const f32avx& x) { return _mm256_sqrt_ps(*x); }
KFR_SINTRIN f64avx sqrt(const f64avx& x) { return _mm256_sqrt_pd(*x); }
#endif
+#if defined CMT_ARCH_AVX512
+KFR_SINTRIN f32avx512 sqrt(const f32avx512& x) { return _mm512_sqrt_ps(*x); }
+KFR_SINTRIN f64avx512 sqrt(const f64avx512& x) { return _mm512_sqrt_pd(*x); }
+#endif
+
KFR_HANDLE_ALL_SIZES_FLT_1(sqrt)
#else
diff --git a/include/kfr/base/types.hpp b/include/kfr/base/types.hpp
@@ -375,7 +375,7 @@ struct is_simd_type
template <typename T, size_t N>
struct vec_t
{
- static_assert(N > 0 && N <= 256, "Invalid vector size");
+ static_assert(N > 0 && N <= 1024, "Invalid vector size");
static_assert(is_simd_type<T>::value || !compound_type_traits<T>::is_scalar, "Invalid vector type");