kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

commit 956b4899cbfee4f0e5ab2725f0c61c1f42738ce1
parent c7cd3a9bcc73a10bbe0ba767d7bb6e0804ef8821
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date:   Tue, 13 Nov 2018 19:41:11 +0300

AVX512 support

Diffstat:
Minclude/kfr/base/abs.hpp | 15+++++++++++++--
Minclude/kfr/base/function.hpp | 15+++++++++++++--
Minclude/kfr/base/logical.hpp | 47++++++++++++++++++++++++++++++++++++++++-------
Minclude/kfr/base/min_max.hpp | 37+++++++++++++++++++++++++++++++------
Minclude/kfr/base/round.hpp | 51++++++++++++++++++++++++++++++++++++---------------
Minclude/kfr/base/saturation.hpp | 31+++++++++++++++++++++----------
Minclude/kfr/base/select.hpp | 47+++++++++++++++++++++++++++++++++++++++++++++--
Minclude/kfr/base/simd_intrin.hpp | 28+++++++++++++++++++++-------
Minclude/kfr/base/simd_x86.hpp | 90++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
Minclude/kfr/base/sqrt.hpp | 5+++++
Minclude/kfr/base/types.hpp | 2+-
11 files changed, 315 insertions(+), 53 deletions(-)

diff --git a/include/kfr/base/abs.hpp b/include/kfr/base/abs.hpp @@ -64,6 +64,17 @@ KFR_SINTRIN u16avx abs(const u16avx& x) { return x; } KFR_SINTRIN u8avx abs(const u8avx& x) { return x; } #endif +#if defined CMT_ARCH_AVX512 +KFR_SINTRIN i64avx512 abs(const i64avx512& x) { return select(x >= 0, x, -x); } +KFR_SINTRIN i32avx512 abs(const i32avx512& x) { return _mm512_abs_epi32(*x); } +KFR_SINTRIN i16avx512 abs(const i16avx512& x) { return _mm512_abs_epi16(*x); } +KFR_SINTRIN i8avx512 abs(const i8avx512& x) { return _mm512_abs_epi8(*x); } +KFR_SINTRIN u64avx512 abs(const u64avx512& x) { return x; } +KFR_SINTRIN u32avx512 abs(const u32avx512& x) { return x; } +KFR_SINTRIN u16avx512 abs(const u16avx512& x) { return x; } +KFR_SINTRIN u8avx512 abs(const u8avx512& x) { return x; } +#endif + KFR_HANDLE_ALL_SIZES_NOT_F_1(abs) #elif defined CMT_ARCH_NEON && defined KFR_NATIVE_INTRINSICS @@ -108,7 +119,7 @@ KFR_SINTRIN vec<T, N> abs(const vec<T, N>& x) } #endif KFR_I_CONVERTER(abs) -} +} // namespace intrinsics KFR_I_FN(abs) /** @@ -128,4 +139,4 @@ KFR_INTRIN internal::expression_function<fn::abs, E1> abs(E1&& x) { return { fn::abs(), std::forward<E1>(x) }; } -} +} // namespace kfr diff --git a/include/kfr/base/function.hpp b/include/kfr/base/function.hpp @@ -78,6 +78,17 @@ using u16avx = vec<u16, 16>; using u32avx = vec<u32, 8>; using u64avx = vec<u64, 4>; +using f32avx512 = vec<f32, 16>; +using f64avx512 = vec<f64, 8>; +using i8avx512 = vec<i8, 64>; +using i16avx512 = vec<i16, 32>; +using i32avx512 = vec<i32, 16>; +using i64avx512 = vec<i64, 8>; +using u8avx512 = vec<u8, 64>; +using u16avx512 = vec<u16, 32>; +using u32avx512 = vec<u32, 16>; +using u64avx512 = vec<u64, 8>; + #else using f32neon = vec<f32, 4>; using f64neon = vec<f64, 2>; @@ -252,6 +263,6 @@ inline T to_scalar(const vec<T, 1>& value) { return value[0]; } -} -} +} // namespace intrinsics +} // namespace kfr CMT_PRAGMA_GNU(GCC diagnostic pop) diff --git a/include/kfr/base/logical.hpp b/include/kfr/base/logical.hpp @@ -53,6 +53,7 @@ struct bitmask #if defined CMT_ARCH_SSE41 +// horizontal OR KFR_SINTRIN bool bittestany(const u8sse& x) { return !_mm_testz_si128(*x, *x); } KFR_SINTRIN bool bittestany(const u16sse& x) { return !_mm_testz_si128(*x, *x); } KFR_SINTRIN bool bittestany(const u32sse& x) { return !_mm_testz_si128(*x, *x); } @@ -62,6 +63,7 @@ KFR_SINTRIN bool bittestany(const i16sse& x) { return !_mm_testz_si128(*x, *x); KFR_SINTRIN bool bittestany(const i32sse& x) { return !_mm_testz_si128(*x, *x); } KFR_SINTRIN bool bittestany(const i64sse& x) { return !_mm_testz_si128(*x, *x); } +// horizontal AND KFR_SINTRIN bool bittestall(const u8sse& x) { return _mm_testc_si128(*x, *allonesvector(x)); } KFR_SINTRIN bool bittestall(const u16sse& x) { return _mm_testc_si128(*x, *allonesvector(x)); } KFR_SINTRIN bool bittestall(const u32sse& x) { return _mm_testc_si128(*x, *allonesvector(x)); } @@ -73,17 +75,13 @@ KFR_SINTRIN bool bittestall(const i64sse& x) { return _mm_testc_si128(*x, *allon #endif #if defined CMT_ARCH_AVX +// horizontal OR KFR_SINTRIN bool bittestany(const f32sse& x) { return !_mm_testz_ps(*x, *x); } KFR_SINTRIN bool bittestany(const f64sse& x) { return !_mm_testz_pd(*x, *x); } -KFR_SINTRIN bool bittestall(const f32sse& x) { return _mm_testc_ps(*x, *allonesvector(x)); } -KFR_SINTRIN bool bittestall(const f64sse& x) { return _mm_testc_pd(*x, *allonesvector(x)); } KFR_SINTRIN bool bittestany(const f32avx& x) { return !_mm256_testz_ps(*x, *x); } KFR_SINTRIN bool bittestany(const f64avx& x) { return !_mm256_testz_pd(*x, *x); } -KFR_SINTRIN bool bittestnall(const f32avx& x) { return _mm256_testc_ps(*x, *allonesvector(x)); } -KFR_SINTRIN bool bittestnall(const f64avx& x) { return _mm256_testc_pd(*x, *allonesvector(x)); } - KFR_SINTRIN bool bittestany(const u8avx& x) { return !_mm256_testz_si256(*x, *x); } KFR_SINTRIN bool bittestany(const u16avx& x) { return !_mm256_testz_si256(*x, *x); } KFR_SINTRIN bool bittestany(const u32avx& x) { return !_mm256_testz_si256(*x, *x); } @@ -93,6 +91,13 @@ KFR_SINTRIN bool bittestany(const i16avx& x) { return !_mm256_testz_si256(*x, *x KFR_SINTRIN bool bittestany(const i32avx& x) { return !_mm256_testz_si256(*x, *x); } KFR_SINTRIN bool bittestany(const i64avx& x) { return !_mm256_testz_si256(*x, *x); } +// horizontal AND +KFR_SINTRIN bool bittestall(const f32sse& x) { return _mm_testc_ps(*x, *allonesvector(x)); } +KFR_SINTRIN bool bittestall(const f64sse& x) { return _mm_testc_pd(*x, *allonesvector(x)); } + +KFR_SINTRIN bool bittestall(const f32avx& x) { return _mm256_testc_ps(*x, *allonesvector(x)); } +KFR_SINTRIN bool bittestall(const f64avx& x) { return _mm256_testc_pd(*x, *allonesvector(x)); } + KFR_SINTRIN bool bittestall(const u8avx& x) { return _mm256_testc_si256(*x, *allonesvector(x)); } KFR_SINTRIN bool bittestall(const u16avx& x) { return _mm256_testc_si256(*x, *allonesvector(x)); } KFR_SINTRIN bool bittestall(const u32avx& x) { return _mm256_testc_si256(*x, *allonesvector(x)); } @@ -101,6 +106,34 @@ KFR_SINTRIN bool bittestall(const i8avx& x) { return _mm256_testc_si256(*x, *all KFR_SINTRIN bool bittestall(const i16avx& x) { return _mm256_testc_si256(*x, *allonesvector(x)); } KFR_SINTRIN bool bittestall(const i32avx& x) { return _mm256_testc_si256(*x, *allonesvector(x)); } KFR_SINTRIN bool bittestall(const i64avx& x) { return _mm256_testc_si256(*x, *allonesvector(x)); } + +#if defined CMT_ARCH_AVX512 +// horizontal OR +KFR_SINTRIN bool bittestany(const f32avx512& x) { return !_mm512_test_epi32_mask(*x, *x); } +KFR_SINTRIN bool bittestany(const f64avx512& x) { return !_mm512_test_epi64_mask(*x, *x); } +KFR_SINTRIN bool bittestany(const u8avx512& x) { return !_mm512_test_epi8_mask(*x, *x); } +KFR_SINTRIN bool bittestany(const u16avx512& x) { return !_mm512_test_epi16_mask(*x, *x); } +KFR_SINTRIN bool bittestany(const u32avx512& x) { return !_mm512_test_epi32_mask(*x, *x); } +KFR_SINTRIN bool bittestany(const u64avx512& x) { return !_mm512_test_epi64_mask(*x, *x); } +KFR_SINTRIN bool bittestany(const i8avx512& x) { return !_mm512_test_epi8_mask(*x, *x); } +KFR_SINTRIN bool bittestany(const i16avx512& x) { return !_mm512_test_epi16_mask(*x, *x); } +KFR_SINTRIN bool bittestany(const i32avx512& x) { return !_mm512_test_epi32_mask(*x, *x); } +KFR_SINTRIN bool bittestany(const i64avx512& x) { return !_mm512_test_epi64_mask(*x, *x); } + +// horizontal AND +KFR_SINTRIN bool bittestall(const f32avx512& x) { return ~bittestany(~x); } +KFR_SINTRIN bool bittestall(const f64avx512& x) { return ~bittestany(~x); } +KFR_SINTRIN bool bittestall(const u8avx512& x) { return ~bittestany(~x); } +KFR_SINTRIN bool bittestall(const u16avx512& x) { return ~bittestany(~x); } +KFR_SINTRIN bool bittestall(const u32avx512& x) { return ~bittestany(~x); } +KFR_SINTRIN bool bittestall(const u64avx512& x) { return ~bittestany(~x); } +KFR_SINTRIN bool bittestall(const i8avx512& x) { return ~bittestany(~x); } +KFR_SINTRIN bool bittestall(const i16avx512& x) { return ~bittestany(~x); } +KFR_SINTRIN bool bittestall(const i32avx512& x) { return ~bittestany(~x); } +KFR_SINTRIN bool bittestall(const i64avx512& x) { return ~bittestany(~x); } + +#endif + #elif defined CMT_ARCH_SSE41 KFR_SINTRIN bool bittestany(const f32sse& x) { return !_mm_testz_si128(*bitcast<u8>(x), *bitcast<u8>(x)); } KFR_SINTRIN bool bittestany(const f64sse& x) { return !_mm_testz_si128(*bitcast<u8>(x), *bitcast<u8>(x)); } @@ -249,7 +282,7 @@ KFR_SINTRIN bool bittestall(const vec<T, N>& x, const vec<T, N>& y) return !bittestany(~x & y); } #endif -} +} // namespace intrinsics /** * @brief Returns x[0] && x[1] && ... && x[N-1] @@ -268,4 +301,4 @@ KFR_SINTRIN bool any(const mask<T, N>& x) { return intrinsics::bittestany(x.asvec()); } -} +} // namespace kfr diff --git a/include/kfr/base/min_max.hpp b/include/kfr/base/min_max.hpp @@ -42,15 +42,11 @@ KFR_SINTRIN f32sse min(const f32sse& x, const f32sse& y) { return _mm_min_ps(*x, KFR_SINTRIN f64sse min(const f64sse& x, const f64sse& y) { return _mm_min_pd(*x, *y); } KFR_SINTRIN u8sse min(const u8sse& x, const u8sse& y) { return _mm_min_epu8(*x, *y); } KFR_SINTRIN i16sse min(const i16sse& x, const i16sse& y) { return _mm_min_epi16(*x, *y); } -KFR_SINTRIN i64sse min(const i64sse& x, const i64sse& y) { return select(x < y, x, y); } -KFR_SINTRIN u64sse min(const u64sse& x, const u64sse& y) { return select(x < y, x, y); } KFR_SINTRIN f32sse max(const f32sse& x, const f32sse& y) { return _mm_max_ps(*x, *y); } KFR_SINTRIN f64sse max(const f64sse& x, const f64sse& y) { return _mm_max_pd(*x, *y); } KFR_SINTRIN u8sse max(const u8sse& x, const u8sse& y) { return _mm_max_epu8(*x, *y); } KFR_SINTRIN i16sse max(const i16sse& x, const i16sse& y) { return _mm_max_epi16(*x, *y); } -KFR_SINTRIN i64sse max(const i64sse& x, const i64sse& y) { return select(x > y, x, y); } -KFR_SINTRIN u64sse max(const u64sse& x, const u64sse& y) { return select(x > y, x, y); } #if defined CMT_ARCH_AVX2 KFR_SINTRIN u8avx min(const u8avx& x, const u8avx& y) { return _mm256_min_epu8(*x, *y); } @@ -67,6 +63,35 @@ KFR_SINTRIN u16avx max(const u16avx& x, const u16avx& y) { return _mm256_max_epu KFR_SINTRIN i32avx max(const i32avx& x, const i32avx& y) { return _mm256_max_epi32(*x, *y); } KFR_SINTRIN u32avx max(const u32avx& x, const u32avx& y) { return _mm256_max_epu32(*x, *y); } +#endif + +#if defined CMT_ARCH_AVX512 +KFR_SINTRIN u8avx512 min(const u8avx512& x, const u8avx512& y) { return _mm512_min_epu8(*x, *y); } +KFR_SINTRIN i16avx512 min(const i16avx512& x, const i16avx512& y) { return _mm512_min_epi16(*x, *y); } +KFR_SINTRIN i8avx512 min(const i8avx512& x, const i8avx512& y) { return _mm512_min_epi8(*x, *y); } +KFR_SINTRIN u16avx512 min(const u16avx512& x, const u16avx512& y) { return _mm512_min_epu16(*x, *y); } +KFR_SINTRIN i32avx512 min(const i32avx512& x, const i32avx512& y) { return _mm512_min_epi32(*x, *y); } +KFR_SINTRIN u32avx512 min(const u32avx512& x, const u32avx512& y) { return _mm512_min_epu32(*x, *y); } +KFR_SINTRIN u8avx512 max(const u8avx512& x, const u8avx512& y) { return _mm512_max_epu8(*x, *y); } +KFR_SINTRIN i16avx512 max(const i16avx512& x, const i16avx512& y) { return _mm512_max_epi16(*x, *y); } +KFR_SINTRIN i8avx512 max(const i8avx512& x, const i8avx512& y) { return _mm512_max_epi8(*x, *y); } +KFR_SINTRIN u16avx512 max(const u16avx512& x, const u16avx512& y) { return _mm512_max_epu16(*x, *y); } +KFR_SINTRIN i32avx512 max(const i32avx512& x, const i32avx512& y) { return _mm512_max_epi32(*x, *y); } +KFR_SINTRIN u32avx512 max(const u32avx512& x, const u32avx512& y) { return _mm512_max_epu32(*x, *y); } +KFR_SINTRIN i64avx512 min(const i64avx512& x, const i64avx512& y) { return _mm512_min_epi64(*x, *y); } +KFR_SINTRIN u64avx512 min(const u64avx512& x, const u64avx512& y) { return _mm512_min_epu64(*x, *y); } +KFR_SINTRIN i64avx512 max(const i64avx512& x, const i64avx512& y) { return _mm512_max_epi64(*x, *y); } +KFR_SINTRIN u64avx512 max(const u64avx512& x, const u64avx512& y) { return _mm512_max_epu64(*x, *y); } + +KFR_SINTRIN i64avx min(const i64avx& x, const i64avx& y) { return _mm256_min_epi64(*x, *y); } +KFR_SINTRIN u64avx min(const u64avx& x, const u64avx& y) { return _mm256_min_epu64(*x, *y); } +KFR_SINTRIN i64avx max(const i64avx& x, const i64avx& y) { return _mm256_max_epi64(*x, *y); } +KFR_SINTRIN u64avx max(const u64avx& x, const u64avx& y) { return _mm256_max_epu64(*x, *y); } +#else +KFR_SINTRIN i64sse min(const i64sse& x, const i64sse& y) { return select(x < y, x, y); } +KFR_SINTRIN u64sse min(const u64sse& x, const u64sse& y) { return select(x < y, x, y); } +KFR_SINTRIN i64sse max(const i64sse& x, const i64sse& y) { return select(x > y, x, y); } +KFR_SINTRIN u64sse max(const u64sse& x, const u64sse& y) { return select(x > y, x, y); } KFR_SINTRIN i64avx min(const i64avx& x, const i64avx& y) { return select(x < y, x, y); } KFR_SINTRIN u64avx min(const u64avx& x, const u64avx& y) { return select(x < y, x, y); } KFR_SINTRIN i64avx max(const i64avx& x, const i64avx& y) { return select(x > y, x, y); } @@ -193,7 +218,7 @@ KFR_I_CONVERTER(min) KFR_I_CONVERTER(max) KFR_I_CONVERTER(absmin) KFR_I_CONVERTER(absmax) -} +} // namespace intrinsics KFR_I_FN(min) KFR_I_FN(max) KFR_I_FN(absmin) @@ -274,4 +299,4 @@ KFR_INTRIN internal::expression_function<fn::absmax, E1, E2> absmax(E1&& x, E2&& { return { fn::absmax(), std::forward<E1>(x), std::forward<E2>(y) }; } -} +} // namespace kfr diff --git a/include/kfr/base/round.hpp b/include/kfr/base/round.hpp @@ -34,25 +34,32 @@ namespace kfr namespace intrinsics { -#define KFR_mm_trunc_ps(V) _mm_round_ps((V), _MM_FROUND_TRUNC) -#define KFR_mm_roundnearest_ps(V) _mm_round_ps((V), _MM_FROUND_NINT) -#define KFR_mm_trunc_pd(V) _mm_round_pd((V), _MM_FROUND_TRUNC) -#define KFR_mm_roundnearest_pd(V) _mm_round_pd((V), _MM_FROUND_NINT) - -#define KFR_mm_trunc_ss(V) _mm_round_ss(_mm_setzero_ps(), (V), _MM_FROUND_TRUNC) -#define KFR_mm_roundnearest_ss(V) _mm_round_ss(_mm_setzero_ps(), (V), _MM_FROUND_NINT) -#define KFR_mm_trunc_sd(V) _mm_round_sd(_mm_setzero_pd(), (V), _MM_FROUND_TRUNC) -#define KFR_mm_roundnearest_sd(V) _mm_round_sd(_mm_setzero_pd(), (V), _MM_FROUND_NINT) +#define KFR_mm_trunc_ps(V) _mm_round_ps((V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC) +#define KFR_mm_roundnearest_ps(V) _mm_round_ps((V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) +#define KFR_mm_trunc_pd(V) _mm_round_pd((V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC) +#define KFR_mm_roundnearest_pd(V) _mm_round_pd((V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) + +#define KFR_mm_trunc_ss(V) _mm_round_ss(_mm_setzero_ps(), (V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC) +#define KFR_mm_roundnearest_ss(V) \ + _mm_round_ss(_mm_setzero_ps(), (V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) +#define KFR_mm_trunc_sd(V) _mm_round_sd(_mm_setzero_pd(), (V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC) +#define KFR_mm_roundnearest_sd(V) \ + _mm_round_sd(_mm_setzero_pd(), (V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) #define KFR_mm_floor_ss(V) _mm_floor_ss(_mm_setzero_ps(), (V)) #define KFR_mm_floor_sd(V) _mm_floor_sd(_mm_setzero_pd(), (V)) #define KFR_mm_ceil_ss(V) _mm_ceil_ss(_mm_setzero_ps(), (V)) #define KFR_mm_ceil_sd(V) _mm_ceil_sd(_mm_setzero_pd(), (V)) -#define KFR_mm256_trunc_ps(V) _mm256_round_ps((V), _MM_FROUND_TRUNC) -#define KFR_mm256_roundnearest_ps(V) _mm256_round_ps((V), _MM_FROUND_NINT) -#define KFR_mm256_trunc_pd(V) _mm256_round_pd((V), _MM_FROUND_TRUNC) -#define KFR_mm256_roundnearest_pd(V) _mm256_round_pd((V), _MM_FROUND_NINT) +#define KFR_mm256_trunc_ps(V) _mm256_round_ps((V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC) +#define KFR_mm256_roundnearest_ps(V) _mm256_round_ps((V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) +#define KFR_mm256_trunc_pd(V) _mm256_round_pd((V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC) +#define KFR_mm256_roundnearest_pd(V) _mm256_round_pd((V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) + +#define KFR_mm512_trunc_ps(V) _mm512_roundscale_ps((V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC) +#define KFR_mm512_roundnearest_ps(V) _mm512_roundscale_ps((V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) +#define KFR_mm512_trunc_pd(V) _mm512_roundscale_pd((V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC) +#define KFR_mm512_roundnearest_pd(V) _mm512_roundscale_pd((V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) #if defined CMT_ARCH_SSE41 && defined KFR_NATIVE_INTRINSICS @@ -81,6 +88,20 @@ KFR_SINTRIN f32avx fract(const f32avx& x) { return x - floor(x); } KFR_SINTRIN f64avx fract(const f64avx& x) { return x - floor(x); } #endif +#if defined CMT_ARCH_AVX512 + +KFR_SINTRIN f32avx512 floor(const f32avx512& value) { return _mm512_floor_ps(*value); } +KFR_SINTRIN f32avx512 ceil(const f32avx512& value) { return _mm512_ceil_ps(*value); } +KFR_SINTRIN f32avx512 trunc(const f32avx512& value) { return KFR_mm512_trunc_ps(*value); } +KFR_SINTRIN f32avx512 round(const f32avx512& value) { return KFR_mm512_roundnearest_ps(*value); } +KFR_SINTRIN f64avx512 floor(const f64avx512& value) { return _mm512_floor_pd(*value); } +KFR_SINTRIN f64avx512 ceil(const f64avx512& value) { return _mm512_ceil_pd(*value); } +KFR_SINTRIN f64avx512 trunc(const f64avx512& value) { return KFR_mm512_trunc_pd(*value); } +KFR_SINTRIN f64avx512 round(const f64avx512& value) { return KFR_mm512_roundnearest_pd(*value); } +KFR_SINTRIN f32avx512 fract(const f32avx512& x) { return x - floor(x); } +KFR_SINTRIN f64avx512 fract(const f64avx512& x) { return x - floor(x); } +#endif + KFR_HANDLE_ALL_SIZES_F_1(floor) KFR_HANDLE_ALL_SIZES_F_1(ceil) KFR_HANDLE_ALL_SIZES_F_1(round) @@ -203,7 +224,7 @@ KFR_I_CONVERTER(ifloor) KFR_I_CONVERTER(iceil) KFR_I_CONVERTER(iround) KFR_I_CONVERTER(itrunc) -} +} // namespace intrinsics KFR_I_FN(floor) KFR_I_FN(ceil) KFR_I_FN(round) @@ -339,7 +360,7 @@ CMT_INLINE vec<T, N> rem(const vec<T, N>& x, const vec<T, N>& y) { return fmod(x, y); } -} +} // namespace kfr #undef KFR_mm_trunc_ps #undef KFR_mm_roundnearest_ps diff --git a/include/kfr/base/saturation.hpp b/include/kfr/base/saturation.hpp @@ -40,10 +40,10 @@ KFR_SINTRIN vec<T, N> saturated_signed_add(const vec<T, N>& a, const vec<T, N>& { using UT = utype<T>; constexpr size_t shift = typebits<UT>::bits - 1; - vec<UT, N> aa = bitcast<UT>(a); - vec<UT, N> bb = bitcast<UT>(b); - const vec<UT, N> sum = aa + bb; - aa = (aa >> shift) + static_cast<UT>(std::numeric_limits<T>::max()); + vec<UT, N> aa = bitcast<UT>(a); + vec<UT, N> bb = bitcast<UT>(b); + const vec<UT, N> sum = aa + bb; + aa = (aa >> shift) + static_cast<UT>(std::numeric_limits<T>::max()); return select(bitcast<T>((aa ^ bb) | ~(bb ^ sum)) >= 0, a, bitcast<T>(sum)); } @@ -52,10 +52,10 @@ KFR_SINTRIN vec<T, N> saturated_signed_sub(const vec<T, N>& a, const vec<T, N>& { using UT = utype<T>; constexpr size_t shift = typebits<UT>::bits - 1; - vec<UT, N> aa = bitcast<UT>(a); - vec<UT, N> bb = bitcast<UT>(b); - const vec<UT, N> diff = aa - bb; - aa = (aa >> shift) + static_cast<UT>(std::numeric_limits<T>::max()); + vec<UT, N> aa = bitcast<UT>(a); + vec<UT, N> bb = bitcast<UT>(b); + const vec<UT, N> diff = aa - bb; + aa = (aa >> shift) + static_cast<UT>(std::numeric_limits<T>::max()); return select(bitcast<T>((aa ^ bb) & (aa ^ diff)) < 0, a, bitcast<T>(diff)); } @@ -105,6 +105,17 @@ KFR_SINTRIN u16avx satsub(const u16avx& x, const u16avx& y) { return _mm256_subs KFR_SINTRIN i16avx satsub(const i16avx& x, const i16avx& y) { return _mm256_subs_epi16(*x, *y); } #endif +#if defined CMT_ARCH_AVX512 +KFR_SINTRIN u8avx512 satadd(const u8avx512& x, const u8avx512& y) { return _mm512_adds_epu8(*x, *y); } +KFR_SINTRIN i8avx512 satadd(const i8avx512& x, const i8avx512& y) { return _mm512_adds_epi8(*x, *y); } +KFR_SINTRIN u16avx512 satadd(const u16avx512& x, const u16avx512& y) { return _mm512_adds_epu16(*x, *y); } +KFR_SINTRIN i16avx512 satadd(const i16avx512& x, const i16avx512& y) { return _mm512_adds_epi16(*x, *y); } +KFR_SINTRIN u8avx512 satsub(const u8avx512& x, const u8avx512& y) { return _mm512_subs_epu8(*x, *y); } +KFR_SINTRIN i8avx512 satsub(const i8avx512& x, const i8avx512& y) { return _mm512_subs_epi8(*x, *y); } +KFR_SINTRIN u16avx512 satsub(const u16avx512& x, const u16avx512& y) { return _mm512_subs_epu16(*x, *y); } +KFR_SINTRIN i16avx512 satsub(const i16avx512& x, const i16avx512& y) { return _mm512_subs_epi16(*x, *y); } +#endif + KFR_HANDLE_ALL_SIZES_2(satadd) KFR_HANDLE_ALL_SIZES_2(satsub) @@ -156,7 +167,7 @@ KFR_SINTRIN vec<T, N> satsub(const vec<T, N>& a, const vec<T, N>& b) #endif KFR_I_CONVERTER(satadd) KFR_I_CONVERTER(satsub) -} +} // namespace intrinsics KFR_I_FN(satadd) KFR_I_FN(satsub) @@ -189,4 +200,4 @@ KFR_INTRIN internal::expression_function<fn::satsub, E1, E2> satsub(E1&& x, E2&& { return { fn::satsub(), std::forward<E1>(x), std::forward<E2>(y) }; } -} +} // namespace kfr diff --git a/include/kfr/base/select.hpp b/include/kfr/base/select.hpp @@ -121,6 +121,49 @@ KFR_SINTRIN i64avx select(const maskfor<i64avx>& m, const i64avx& x, const i64av } #endif +#if defined CMT_ARCH_AVX512 +KFR_SINTRIN f64avx512 select(const maskfor<f64avx512>& m, const f64avx512& x, const f64avx512& y) +{ + return _mm512_mask_blend_pd(_mm512_test_epi64_mask(*m, *m), *y, *x); +} +KFR_SINTRIN f32avx512 select(const maskfor<f32avx512>& m, const f32avx512& x, const f32avx512& y) +{ + return _mm512_mask_blend_ps(_mm512_test_epi32_mask(*m, *m), *y, *x); +} +KFR_SINTRIN u8avx512 select(const maskfor<u8avx512>& m, const u8avx512& x, const u8avx512& y) +{ + return _mm512_mask_blend_epi8(_mm512_test_epi8_mask(*m, *m), *y, *x); +} +KFR_SINTRIN u16avx512 select(const maskfor<u16avx512>& m, const u16avx512& x, const u16avx512& y) +{ + return _mm512_mask_blend_epi8(_mm512_test_epi16_mask(*m, *m), *y, *x); +} +KFR_SINTRIN u32avx512 select(const maskfor<u32avx512>& m, const u32avx512& x, const u32avx512& y) +{ + return _mm512_mask_blend_epi8(_mm512_test_epi32_mask(*m, *m), *y, *x); +} +KFR_SINTRIN u64avx512 select(const maskfor<u64avx512>& m, const u64avx512& x, const u64avx512& y) +{ + return _mm512_mask_blend_epi8(_mm512_test_epi64_mask(*m, *m), *y, *x); +} +KFR_SINTRIN i8avx512 select(const maskfor<i8avx512>& m, const i8avx512& x, const i8avx512& y) +{ + return _mm512_mask_blend_epi8(_mm512_test_epi8_mask(*m, *m), *y, *x); +} +KFR_SINTRIN i16avx512 select(const maskfor<i16avx512>& m, const i16avx512& x, const i16avx512& y) +{ + return _mm512_mask_blend_epi8(_mm512_test_epi16_mask(*m, *m), *y, *x); +} +KFR_SINTRIN i32avx512 select(const maskfor<i32avx512>& m, const i32avx512& x, const i32avx512& y) +{ + return _mm512_mask_blend_epi8(_mm512_test_epi32_mask(*m, *m), *y, *x); +} +KFR_SINTRIN i64avx512 select(const maskfor<i64avx512>& m, const i64avx512& x, const i64avx512& y) +{ + return _mm512_mask_blend_epi8(_mm512_test_epi64_mask(*m, *m), *y, *x); +} +#endif + template <typename T, size_t N, KFR_ENABLE_IF(N < platform<T>::vector_width)> KFR_SINTRIN vec<T, N> select(const mask<T, N>& a, const vec<T, N>& b, const vec<T, N>& c) { @@ -211,7 +254,7 @@ KFR_SINTRIN vec<T, N> select(const vec<T, N>& m, const vec<T, N>& x, const vec<T { return select(m.asmask(), x, y); } -} +} // namespace intrinsics KFR_I_FN(select) /** @@ -238,4 +281,4 @@ KFR_INTRIN internal::expression_function<fn::select, E1, E2, E3> select(E1&& m, { return { fn::select(), std::forward<E1>(m), std::forward<E2>(x), std::forward<E3>(y) }; } -} +} // namespace kfr diff --git a/include/kfr/base/simd_intrin.hpp b/include/kfr/base/simd_intrin.hpp @@ -91,6 +91,19 @@ KFR_SIMD_SPEC_TYPE(f32, 8, __m256); KFR_SIMD_SPEC_TYPE(f64, 4, __m256d); #endif +#ifdef CMT_ARCH_AVX512 +KFR_SIMD_SPEC_TYPE(u8, 64, __m512i); +KFR_SIMD_SPEC_TYPE(u16, 32, __m512i); +KFR_SIMD_SPEC_TYPE(u32, 16, __m512i); +KFR_SIMD_SPEC_TYPE(u64, 8, __m512i); +KFR_SIMD_SPEC_TYPE(i8, 64, __m512i); +KFR_SIMD_SPEC_TYPE(i16, 32, __m512i); +KFR_SIMD_SPEC_TYPE(i32, 16, __m512i); +KFR_SIMD_SPEC_TYPE(i64, 8, __m512i); +KFR_SIMD_SPEC_TYPE(f32, 16, __m512); +KFR_SIMD_SPEC_TYPE(f64, 8, __m512d); +#endif + #ifdef CMT_ARCH_NEON KFR_SIMD_SPEC_TYPE(u8, 16, uint8x16_t); KFR_SIMD_SPEC_TYPE(u16, 8, uint16x8_t); @@ -118,17 +131,17 @@ struct raw_bytes #define KFR_C_CYCLE(...) \ for (size_t i = 0; i < N; i++) \ - vs[i] = __VA_ARGS__ + vs[i] = __VA_ARGS__ #define KFR_R_CYCLE(...) \ vec<T, N> result; \ - for (size_t i = 0; i < N; i++) \ + for (size_t i = 0; i < N; i++) \ result.vs[i] = __VA_ARGS__; \ return result #define KFR_B_CYCLE(...) \ vec<T, N> result; \ - for (size_t i = 0; i < N; i++) \ + for (size_t i = 0; i < N; i++) \ result.vs[i] = (__VA_ARGS__) ? constants<value_type>::allones() : value_type(0); \ return result @@ -282,13 +295,13 @@ struct alignas(const_min(platform<>::maximum_vector_alignment, sizeof(T) * next_ KFR_I_CE vec& operator++() noexcept { return *this = *this + vec(1); } KFR_I_CE vec& operator--() noexcept { return *this = *this - vec(1); } - KFR_I_CE vec operator++(int)noexcept + KFR_I_CE vec operator++(int) noexcept { const vec z = *this; ++*this; return z; } - KFR_I_CE vec operator--(int)noexcept + KFR_I_CE vec operator--(int) noexcept { const vec z = *this; --*this; @@ -321,6 +334,7 @@ struct alignas(const_min(platform<>::maximum_vector_alignment, sizeof(T) * next_ const vec& flatten() const noexcept { return *this; } simd_type operator*() const noexcept { return simd; } simd_type& operator*() noexcept { return simd; } + protected: template <typename, size_t> friend struct vec; @@ -366,13 +380,13 @@ CMT_INLINE vec<T, csum<size_t, N1, N2, Sizes...>()> concat_impl(const vec<T, N1> { return concat_impl(concat_impl(x, y), args...); } -} +} // namespace internal template <typename T, size_t... Ns> constexpr inline vec<T, csum<size_t, Ns...>()> concat(const vec<T, Ns>&... vs) noexcept { return internal::concat_impl(vs...); } -} +} // namespace kfr CMT_PRAGMA_MSVC(warning(pop)) diff --git a/include/kfr/base/simd_x86.hpp b/include/kfr/base/simd_x86.hpp @@ -181,4 +181,92 @@ KFR_I_CE CMT_INLINE vec<f64, 4> vec<f64, 4>::operator^(const vec<f64, 4>& y) con #endif // CMT_ARCH_AVX -} // namespace kf +#ifdef CMT_ARCH_AVX512 + +template <> +KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator+(const vec<f32, 16>& y) const noexcept +{ + return _mm512_add_ps(simd, y.simd); +} + +template <> +KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator-(const vec<f32, 16>& y) const noexcept +{ + return _mm512_sub_ps(simd, y.simd); +} + +template <> +KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator*(const vec<f32, 16>& y) const noexcept +{ + return _mm512_mul_ps(simd, y.simd); +} + +template <> +KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator/(const vec<f32, 16>& y) const noexcept +{ + return _mm512_div_ps(simd, y.simd); +} + +template <> +KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator&(const vec<f32, 16>& y) const noexcept +{ + return _mm512_and_ps(simd, y.simd); +} + +template <> +KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator|(const vec<f32, 16>& y) const noexcept +{ + return _mm512_or_ps(simd, y.simd); +} + +template <> +KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator^(const vec<f32, 16>& y) const noexcept +{ + return _mm512_xor_ps(simd, y.simd); +} + +template <> +KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator+(const vec<f64, 8>& y) const noexcept +{ + return _mm512_add_pd(simd, y.simd); +} + +template <> +KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator-(const vec<f64, 8>& y) const noexcept +{ + return _mm512_sub_pd(simd, y.simd); +} + +template <> +KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator*(const vec<f64, 8>& y) const noexcept +{ + return _mm512_mul_pd(simd, y.simd); +} + +template <> +KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator/(const vec<f64, 8>& y) const noexcept +{ + return _mm512_div_pd(simd, y.simd); +} + +template <> +KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator&(const vec<f64, 8>& y) const noexcept +{ + return _mm512_and_pd(simd, y.simd); +} + +template <> +KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator|(const vec<f64, 8>& y) const noexcept +{ + return _mm512_or_pd(simd, y.simd); +} + +template <> +KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator^(const vec<f64, 8>& y) const noexcept +{ + return _mm512_xor_pd(simd, y.simd); +} + +#endif // CMT_ARCH_AVX + +} // namespace kfr diff --git a/include/kfr/base/sqrt.hpp b/include/kfr/base/sqrt.hpp @@ -48,6 +48,11 @@ KFR_SINTRIN f32avx sqrt(const f32avx& x) { return _mm256_sqrt_ps(*x); } KFR_SINTRIN f64avx sqrt(const f64avx& x) { return _mm256_sqrt_pd(*x); } #endif +#if defined CMT_ARCH_AVX512 +KFR_SINTRIN f32avx512 sqrt(const f32avx512& x) { return _mm512_sqrt_ps(*x); } +KFR_SINTRIN f64avx512 sqrt(const f64avx512& x) { return _mm512_sqrt_pd(*x); } +#endif + KFR_HANDLE_ALL_SIZES_FLT_1(sqrt) #else diff --git a/include/kfr/base/types.hpp b/include/kfr/base/types.hpp @@ -375,7 +375,7 @@ struct is_simd_type template <typename T, size_t N> struct vec_t { - static_assert(N > 0 && N <= 256, "Invalid vector size"); + static_assert(N > 0 && N <= 1024, "Invalid vector size"); static_assert(is_simd_type<T>::value || !compound_type_traits<T>::is_scalar, "Invalid vector type");