commit 2e9ef22bc9777f665c067ce25f99e5a02ca64d32
parent 9d34f60f1bdfbf0b52477248e286f1dccf4d900a
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Fri, 16 Nov 2018 04:42:17 +0300
Merge branch 'ci' into 3.0
Diffstat:
21 files changed, 594 insertions(+), 246 deletions(-)
diff --git a/.travis.yml b/.travis.yml
@@ -1,61 +0,0 @@
-language: cpp
-matrix:
- include:
- - os: linux
- compiler: clang
- sudo: required
- dist: trusty
- addons:
- apt:
- sources:
- - ubuntu-toolchain-r-test
- - llvm-toolchain-precise-3.8
- packages:
- - g++-5
- - clang-3.8
- - libmpfr-dev
- env:
- - TEST=LINUX-X86-64 CMAKEARGS="-DCMAKE_CXX_COMPILER=clang++-3.8 -DCMAKE_BUILD_TYPE=Release .."
- - os: linux
- compiler: clang
- sudo: required
- dist: trusty
- addons:
- apt:
- sources:
- - ubuntu-toolchain-r-test
- - llvm-toolchain-precise-3.8
- packages:
- - clang-3.8
- - qemu
- - g++-arm-linux-gnueabihf
- env:
- - TEST=LINUX-ARMV7 CMAKEARGS="-DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../cmake/arm.cmake -DARCH_FLAGS=-DLIBC_WORKAROUND_GETS=1 .."
- - os: osx
- osx_image: xcode8
- env:
- - TEST=XCODE8 CMAKEARGS="-DCMAKE_BUILD_TYPE=Release .."
- - os: osx
- osx_image: xcode9.1
- env:
- - TEST=XCODE9.1 CMAKEARGS="-DCMAKE_BUILD_TYPE=Release .."
- - os: osx
- osx_image: xcode9.4
- env:
- - TEST=XCODE9.4 CMAKEARGS="-DCMAKE_BUILD_TYPE=Release .."
- - os: osx
- osx_image: xcode10
- env:
- - TEST=XCODE10.0 CMAKEARGS="-DCMAKE_BUILD_TYPE=Release .."
-
-before_install:
- - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then cmake --version || brew install cmake; fi
- - cmake --version
-
-script:
- - mkdir build
- - cd build
- - cmake $CMAKEARGS
- - make -j4
- - cd tests
- - ctest -V
diff --git a/CMakeLists.txt b/CMakeLists.txt
@@ -36,6 +36,13 @@ if(NOT CMAKE_CXX_COMPILER)
# If clang is not found, leave default compiler (usually GCC)
if(NOT DISABLE_CLANG)
+ find_program(CLANG_CXX_PATH clang++-8)
+ find_program(CLANG_CXX_PATH clang++-7)
+ find_program(CLANG_CXX_PATH clang++-8.0)
+ find_program(CLANG_CXX_PATH clang++-7.0)
+ find_program(CLANG_CXX_PATH clang++-6.0)
+ find_program(CLANG_CXX_PATH clang++-5.0)
+ find_program(CLANG_CXX_PATH clang++-4.0)
find_program(CLANG_CXX_PATH clang++-4.0)
find_program(CLANG_CXX_PATH clang++-3.9)
find_program(CLANG_CXX_PATH clang++-3.8)
@@ -85,7 +92,7 @@ if (${CMAKE_GENERATOR} STREQUAL "MinGW Makefiles"
OR ${CMAKE_GENERATOR} STREQUAL "Unix Makefiles"
OR ${CMAKE_GENERATOR} STREQUAL "Ninja")
- if (CMAKE_CXX_COMPILER MATCHES "clang")
+ if (CMAKE_CXX_COMPILER MATCHES "clang" AND NOT CMAKE_CXX_COMPILER MATCHES "clang-cl")
if (WIN32)
# On windows, clang requires --target
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --target=x86_64-w64-windows-gnu" CACHE STRING "cxx compile flags" FORCE)
@@ -110,6 +117,12 @@ project(kfr CXX)
message(STATUS "C++ compiler: ${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION} ${CMAKE_CXX_COMPILER} ")
+if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
+ set(CLANG 1)
+else()
+ set(CLANG 0)
+endif()
+
# Binary output directories
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE ${PROJECT_BINARY_DIR}/bin)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${PROJECT_BINARY_DIR}/bin/Debug)
@@ -123,7 +136,11 @@ else ()
set(STD_LIB stdc++)
endif ()
-if (NOT MSVC)
+add_definitions(-D_ENABLE_EXTENDED_ALIGNED_STORAGE)
+
+add_definitions(-D_CRT_SECURE_NO_WARNINGS)
+
+if (NOT MSVC OR CLANG)
# Enable C++14, disable exceptions and rtti
add_compile_options(-std=c++1y -fno-exceptions -fno-rtti )
if (NOT ARCH_FLAGS)
@@ -134,7 +151,9 @@ if (NOT MSVC)
else ()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_FLAGS}")
endif ()
- link_libraries(${STD_LIB} pthread m)
+ if(NOT MSVC)
+ link_libraries(${STD_LIB} pthread m)
+ endif()
else ()
# Disable exceptions
add_compile_options(/EHsc /D_HAS_EXCEPTIONS=0 /D_CRT_SECURE_NO_WARNINGS=1)
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
@@ -0,0 +1,113 @@
+jobs:
+- job: Linux
+ pool:
+ vmImage: 'ubuntu-16.04'
+ steps:
+ - bash: |
+ set -e
+ sudo apt-get install -y ninja-build libmpfr-dev
+ mkdir build
+ cd build
+ cmake -GNinja -DCMAKE_BUILD_TYPE=Release ..
+ ninja
+ cd tests
+ ctest -V
+
+- job: Linux_ARM
+ pool:
+ vmImage: 'ubuntu-16.04'
+ steps:
+ - bash: |
+ set -e
+ sudo apt-get install -y ninja-build g++-arm-linux-gnueabihf qemu
+ mkdir build
+ cd build
+ cmake -GNinja -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../cmake/arm.cmake -DARCH_FLAGS=-DLIBC_WORKAROUND_GETS=1 ..
+ ninja
+ cd tests
+ ctest -V
+
+- job: macOS
+ strategy:
+ matrix:
+ xcode10.1:
+ XCODE_VER: 10.1
+ xcode10:
+ XCODE_VER: 10
+ xcode9.4.1:
+ XCODE_VER: 9.4.1
+ xcode9.0.1:
+ XCODE_VER: 9.0.1
+ xcode8.3.3:
+ XCODE_VER: 8.3.3
+ pool:
+ vmImage: 'macOS-10.13'
+ steps:
+ - bash: |
+ set -e
+ /bin/bash -c "sudo xcode-select -s /Applications/Xcode_$(XCODE_VER).app/Contents/Developer"
+ brew install ninja
+ mkdir build
+ cd build
+ cmake -GNinja -DCMAKE_BUILD_TYPE=Release ..
+ ninja
+ cd tests
+ ctest -V
+
+- job: Windows_MinGW
+ pool:
+ vmImage: 'vs2017-win2016'
+ steps:
+ - bash: |
+ set -e
+ choco install llvm ninja
+ mkdir build
+ cd build
+ cmake -GNinja -DCMAKE_CXX_COMPILER="C:/Program Files/LLVM/bin/clang++.exe" -DCMAKE_BUILD_TYPE=Release ..
+ ninja
+ cd tests
+ ctest -V
+
+- job: Windows_MSVC64
+ pool:
+ vmImage: 'vs2017-win2016'
+ steps:
+ - script: |
+ choco install llvm ninja
+ mkdir build
+ cd build
+ call "C:\Program Files (x86)\Microsoft Visual Studio\2017\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
+ set PATH=%PATH:C:\tools\mingw64\bin;=%
+ set PATH=%PATH:C:\Program Files\Git\mingw64\bin;=%
+ cmake -GNinja -DCMAKE_CXX_COMPILER="C:/Program Files/LLVM/bin/clang-cl.exe" -DARCH_FLAGS=-mavx -DCMAKE_CXX_FLAGS=-m64 -DCMAKE_BUILD_TYPE=Release ..
+ ninja
+ cd tests
+ ctest -V
+
+- job: Windows_MSVC32
+ pool:
+ vmImage: 'vs2017-win2016'
+ steps:
+ - script: |
+ choco install llvm ninja
+ mkdir build
+ cd build
+ call "C:\Program Files (x86)\Microsoft Visual Studio\2017\Enterprise\VC\Auxiliary\Build\vcvars32.bat"
+ set PATH=%PATH:C:\tools\mingw64\bin;=%
+ set PATH=%PATH:C:\Program Files\Git\mingw64\bin;=%
+ cmake -GNinja -DCMAKE_CXX_COMPILER="C:/Program Files/LLVM/bin/clang-cl.exe" -DARCH_FLAGS=-mavx -DCMAKE_CXX_FLAGS=-m32 -DCMAKE_BUILD_TYPE=Release ..
+ ninja
+ cd tests
+ ctest -V
+
+- job: Windows_AVX512
+ pool: WIN-AVX512
+ steps:
+ - script: |
+ mkdir build
+ cd build
+ call "C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvars64.bat"
+ cmake -GNinja -DCMAKE_CXX_COMPILER="C:/LLVM/bin/clang-cl.exe" -DARCH_FLAGS="-mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl" -DCMAKE_CXX_FLAGS=-m64 -DCMAKE_BUILD_TYPE=Release ..
+ ninja
+ cd tests
+ ctest -V
+\ No newline at end of file
diff --git a/cmake/arm.cmake b/cmake/arm.cmake
@@ -5,13 +5,13 @@ set (ARM True)
set (CMAKE_SYSTEM_PROCESSOR arm)
include (CMakeForceCompiler)
-CMAKE_FORCE_CXX_COMPILER (/usr/bin/clang++-3.8 Clang)
-CMAKE_FORCE_C_COMPILER (/usr/bin/clang-3.8 Clang)
+CMAKE_FORCE_CXX_COMPILER (/usr/bin/clang++ Clang)
+CMAKE_FORCE_C_COMPILER (/usr/bin/clang Clang)
set (CMAKE_CXX_COMPILER_WORKS TRUE)
set (CMAKE_C_COMPILER_WORKS TRUE)
set (ARM_ROOT "/usr/arm-linux-gnueabihf/include")
-set (GCC_VER 4.8.4)
+set (GCC_VER 5.4.0)
set (SYS_PATHS "-isystem ${ARM_ROOT}/c++/${GCC_VER} -isystem ${ARM_ROOT}/c++/${GCC_VER}/backward -isystem ${ARM_ROOT}/c++/${GCC_VER}/arm-linux-gnueabihf -isystem ${ARM_ROOT}")
set (ARM_COMMON_FLAGS "-target arm-linux-gnueabihf -mcpu=cortex-a15 -mfpu=neon-vfpv4 -mfloat-abi=hard -static")
diff --git a/include/kfr/base/abs.hpp b/include/kfr/base/abs.hpp
@@ -64,6 +64,17 @@ KFR_SINTRIN u16avx abs(const u16avx& x) { return x; }
KFR_SINTRIN u8avx abs(const u8avx& x) { return x; }
#endif
+#if defined CMT_ARCH_AVX512
+KFR_SINTRIN i64avx512 abs(const i64avx512& x) { return select(x >= 0, x, -x); }
+KFR_SINTRIN i32avx512 abs(const i32avx512& x) { return _mm512_abs_epi32(*x); }
+KFR_SINTRIN i16avx512 abs(const i16avx512& x) { return _mm512_abs_epi16(*x); }
+KFR_SINTRIN i8avx512 abs(const i8avx512& x) { return _mm512_abs_epi8(*x); }
+KFR_SINTRIN u64avx512 abs(const u64avx512& x) { return x; }
+KFR_SINTRIN u32avx512 abs(const u32avx512& x) { return x; }
+KFR_SINTRIN u16avx512 abs(const u16avx512& x) { return x; }
+KFR_SINTRIN u8avx512 abs(const u8avx512& x) { return x; }
+#endif
+
KFR_HANDLE_ALL_SIZES_NOT_F_1(abs)
#elif defined CMT_ARCH_NEON && defined KFR_NATIVE_INTRINSICS
@@ -108,7 +119,7 @@ KFR_SINTRIN vec<T, N> abs(const vec<T, N>& x)
}
#endif
KFR_I_CONVERTER(abs)
-}
+} // namespace intrinsics
KFR_I_FN(abs)
/**
@@ -128,4 +139,4 @@ KFR_INTRIN internal::expression_function<fn::abs, E1> abs(E1&& x)
{
return { fn::abs(), std::forward<E1>(x) };
}
-}
+} // namespace kfr
diff --git a/include/kfr/base/basic_expressions.hpp b/include/kfr/base/basic_expressions.hpp
@@ -75,7 +75,7 @@ struct expression_iterator
iterator end() const { return { *this, e1.size() }; }
E1 e1;
};
-}
+} // namespace internal
template <typename To, typename E>
CMT_INLINE internal::expression_convert<To, E> convert(E&& expr)
@@ -171,7 +171,7 @@ struct expression_writer
size_t m_position = 0;
E1 e1;
};
-}
+} // namespace internal
template <typename T, typename E1>
internal::expression_reader<T, E1> reader(E1&& e1)
@@ -350,13 +350,13 @@ struct expression_adjacent : expression_base<E>
{
const vec<T, N> in = this->argument_first(cinput, index, vec_t<T, N>());
const vec<T, N> delayed = insertleft(data, in);
- data = in[N - 1];
+ data = in[N - 1];
return this->fn(in, delayed);
}
Fn fn;
mutable value_type data = value_type(0);
};
-}
+} // namespace internal
template <typename E1>
CMT_INLINE internal::expression_slice<E1> slice(E1&& e1, size_t start, size_t size = infinite_size)
@@ -451,7 +451,7 @@ struct expression_padded : expression_base<E>
value_type fill_value;
const size_t input_size;
};
-}
+} // namespace internal
/**
* @brief Returns infinite template expression that pads e with fill_value (default value = 0)
@@ -535,7 +535,7 @@ private:
swallow{ (std::get<indices>(this->args)(coutput, index, xx[indices]), void(), 0)... };
}
};
-}
+} // namespace internal
template <typename... E, KFR_ENABLE_IF(is_output_expressions<E...>::value)>
internal::expression_unpack<E...> unpack(E&&... e)
@@ -564,9 +564,10 @@ struct task_partition
size_t count;
size_t operator()(size_t index)
{
- if (index > count)
+ if (index >= count)
return 0;
- return process(output, input, index * chunk_size, chunk_size);
+ return process(output, input, index * chunk_size,
+ index == count - 1 ? size - (count - 1) * chunk_size : chunk_size);
}
};
diff --git a/include/kfr/base/function.hpp b/include/kfr/base/function.hpp
@@ -78,6 +78,17 @@ using u16avx = vec<u16, 16>;
using u32avx = vec<u32, 8>;
using u64avx = vec<u64, 4>;
+using f32avx512 = vec<f32, 16>;
+using f64avx512 = vec<f64, 8>;
+using i8avx512 = vec<i8, 64>;
+using i16avx512 = vec<i16, 32>;
+using i32avx512 = vec<i32, 16>;
+using i64avx512 = vec<i64, 8>;
+using u8avx512 = vec<u8, 64>;
+using u16avx512 = vec<u16, 32>;
+using u32avx512 = vec<u32, 16>;
+using u64avx512 = vec<u64, 8>;
+
#else
using f32neon = vec<f32, 4>;
using f64neon = vec<f64, 2>;
@@ -252,6 +263,6 @@ inline T to_scalar(const vec<T, 1>& value)
{
return value[0];
}
-}
-}
+} // namespace intrinsics
+} // namespace kfr
CMT_PRAGMA_GNU(GCC diagnostic pop)
diff --git a/include/kfr/base/logical.hpp b/include/kfr/base/logical.hpp
@@ -53,6 +53,7 @@ struct bitmask
#if defined CMT_ARCH_SSE41
+// horizontal OR
KFR_SINTRIN bool bittestany(const u8sse& x) { return !_mm_testz_si128(*x, *x); }
KFR_SINTRIN bool bittestany(const u16sse& x) { return !_mm_testz_si128(*x, *x); }
KFR_SINTRIN bool bittestany(const u32sse& x) { return !_mm_testz_si128(*x, *x); }
@@ -62,6 +63,7 @@ KFR_SINTRIN bool bittestany(const i16sse& x) { return !_mm_testz_si128(*x, *x);
KFR_SINTRIN bool bittestany(const i32sse& x) { return !_mm_testz_si128(*x, *x); }
KFR_SINTRIN bool bittestany(const i64sse& x) { return !_mm_testz_si128(*x, *x); }
+// horizontal AND
KFR_SINTRIN bool bittestall(const u8sse& x) { return _mm_testc_si128(*x, *allonesvector(x)); }
KFR_SINTRIN bool bittestall(const u16sse& x) { return _mm_testc_si128(*x, *allonesvector(x)); }
KFR_SINTRIN bool bittestall(const u32sse& x) { return _mm_testc_si128(*x, *allonesvector(x)); }
@@ -73,17 +75,13 @@ KFR_SINTRIN bool bittestall(const i64sse& x) { return _mm_testc_si128(*x, *allon
#endif
#if defined CMT_ARCH_AVX
+// horizontal OR
KFR_SINTRIN bool bittestany(const f32sse& x) { return !_mm_testz_ps(*x, *x); }
KFR_SINTRIN bool bittestany(const f64sse& x) { return !_mm_testz_pd(*x, *x); }
-KFR_SINTRIN bool bittestall(const f32sse& x) { return _mm_testc_ps(*x, *allonesvector(x)); }
-KFR_SINTRIN bool bittestall(const f64sse& x) { return _mm_testc_pd(*x, *allonesvector(x)); }
KFR_SINTRIN bool bittestany(const f32avx& x) { return !_mm256_testz_ps(*x, *x); }
KFR_SINTRIN bool bittestany(const f64avx& x) { return !_mm256_testz_pd(*x, *x); }
-KFR_SINTRIN bool bittestnall(const f32avx& x) { return _mm256_testc_ps(*x, *allonesvector(x)); }
-KFR_SINTRIN bool bittestnall(const f64avx& x) { return _mm256_testc_pd(*x, *allonesvector(x)); }
-
KFR_SINTRIN bool bittestany(const u8avx& x) { return !_mm256_testz_si256(*x, *x); }
KFR_SINTRIN bool bittestany(const u16avx& x) { return !_mm256_testz_si256(*x, *x); }
KFR_SINTRIN bool bittestany(const u32avx& x) { return !_mm256_testz_si256(*x, *x); }
@@ -93,6 +91,13 @@ KFR_SINTRIN bool bittestany(const i16avx& x) { return !_mm256_testz_si256(*x, *x
KFR_SINTRIN bool bittestany(const i32avx& x) { return !_mm256_testz_si256(*x, *x); }
KFR_SINTRIN bool bittestany(const i64avx& x) { return !_mm256_testz_si256(*x, *x); }
+// horizontal AND
+KFR_SINTRIN bool bittestall(const f32sse& x) { return _mm_testc_ps(*x, *allonesvector(x)); }
+KFR_SINTRIN bool bittestall(const f64sse& x) { return _mm_testc_pd(*x, *allonesvector(x)); }
+
+KFR_SINTRIN bool bittestall(const f32avx& x) { return _mm256_testc_ps(*x, *allonesvector(x)); }
+KFR_SINTRIN bool bittestall(const f64avx& x) { return _mm256_testc_pd(*x, *allonesvector(x)); }
+
KFR_SINTRIN bool bittestall(const u8avx& x) { return _mm256_testc_si256(*x, *allonesvector(x)); }
KFR_SINTRIN bool bittestall(const u16avx& x) { return _mm256_testc_si256(*x, *allonesvector(x)); }
KFR_SINTRIN bool bittestall(const u32avx& x) { return _mm256_testc_si256(*x, *allonesvector(x)); }
@@ -101,6 +106,34 @@ KFR_SINTRIN bool bittestall(const i8avx& x) { return _mm256_testc_si256(*x, *all
KFR_SINTRIN bool bittestall(const i16avx& x) { return _mm256_testc_si256(*x, *allonesvector(x)); }
KFR_SINTRIN bool bittestall(const i32avx& x) { return _mm256_testc_si256(*x, *allonesvector(x)); }
KFR_SINTRIN bool bittestall(const i64avx& x) { return _mm256_testc_si256(*x, *allonesvector(x)); }
+
+#if defined CMT_ARCH_AVX512
+// horizontal OR
+KFR_SINTRIN bool bittestany(const f32avx512& x) { return _mm512_test_epi32_mask(*x, *x); }
+KFR_SINTRIN bool bittestany(const f64avx512& x) { return _mm512_test_epi64_mask(*x, *x); }
+KFR_SINTRIN bool bittestany(const u8avx512& x) { return _mm512_test_epi8_mask(*x, *x); }
+KFR_SINTRIN bool bittestany(const u16avx512& x) { return _mm512_test_epi16_mask(*x, *x); }
+KFR_SINTRIN bool bittestany(const u32avx512& x) { return _mm512_test_epi32_mask(*x, *x); }
+KFR_SINTRIN bool bittestany(const u64avx512& x) { return _mm512_test_epi64_mask(*x, *x); }
+KFR_SINTRIN bool bittestany(const i8avx512& x) { return _mm512_test_epi8_mask(*x, *x); }
+KFR_SINTRIN bool bittestany(const i16avx512& x) { return _mm512_test_epi16_mask(*x, *x); }
+KFR_SINTRIN bool bittestany(const i32avx512& x) { return _mm512_test_epi32_mask(*x, *x); }
+KFR_SINTRIN bool bittestany(const i64avx512& x) { return _mm512_test_epi64_mask(*x, *x); }
+
+// horizontal AND
+KFR_SINTRIN bool bittestall(const f32avx512& x) { return !bittestany(~x); }
+KFR_SINTRIN bool bittestall(const f64avx512& x) { return !bittestany(~x); }
+KFR_SINTRIN bool bittestall(const u8avx512& x) { return !bittestany(~x); }
+KFR_SINTRIN bool bittestall(const u16avx512& x) { return !bittestany(~x); }
+KFR_SINTRIN bool bittestall(const u32avx512& x) { return !bittestany(~x); }
+KFR_SINTRIN bool bittestall(const u64avx512& x) { return !bittestany(~x); }
+KFR_SINTRIN bool bittestall(const i8avx512& x) { return !bittestany(~x); }
+KFR_SINTRIN bool bittestall(const i16avx512& x) { return !bittestany(~x); }
+KFR_SINTRIN bool bittestall(const i32avx512& x) { return !bittestany(~x); }
+KFR_SINTRIN bool bittestall(const i64avx512& x) { return !bittestany(~x); }
+
+#endif
+
#elif defined CMT_ARCH_SSE41
KFR_SINTRIN bool bittestany(const f32sse& x) { return !_mm_testz_si128(*bitcast<u8>(x), *bitcast<u8>(x)); }
KFR_SINTRIN bool bittestany(const f64sse& x) { return !_mm_testz_si128(*bitcast<u8>(x), *bitcast<u8>(x)); }
@@ -249,7 +282,7 @@ KFR_SINTRIN bool bittestall(const vec<T, N>& x, const vec<T, N>& y)
return !bittestany(~x & y);
}
#endif
-}
+} // namespace intrinsics
/**
* @brief Returns x[0] && x[1] && ... && x[N-1]
@@ -268,4 +301,4 @@ KFR_SINTRIN bool any(const mask<T, N>& x)
{
return intrinsics::bittestany(x.asvec());
}
-}
+} // namespace kfr
diff --git a/include/kfr/base/min_max.hpp b/include/kfr/base/min_max.hpp
@@ -42,15 +42,11 @@ KFR_SINTRIN f32sse min(const f32sse& x, const f32sse& y) { return _mm_min_ps(*x,
KFR_SINTRIN f64sse min(const f64sse& x, const f64sse& y) { return _mm_min_pd(*x, *y); }
KFR_SINTRIN u8sse min(const u8sse& x, const u8sse& y) { return _mm_min_epu8(*x, *y); }
KFR_SINTRIN i16sse min(const i16sse& x, const i16sse& y) { return _mm_min_epi16(*x, *y); }
-KFR_SINTRIN i64sse min(const i64sse& x, const i64sse& y) { return select(x < y, x, y); }
-KFR_SINTRIN u64sse min(const u64sse& x, const u64sse& y) { return select(x < y, x, y); }
KFR_SINTRIN f32sse max(const f32sse& x, const f32sse& y) { return _mm_max_ps(*x, *y); }
KFR_SINTRIN f64sse max(const f64sse& x, const f64sse& y) { return _mm_max_pd(*x, *y); }
KFR_SINTRIN u8sse max(const u8sse& x, const u8sse& y) { return _mm_max_epu8(*x, *y); }
KFR_SINTRIN i16sse max(const i16sse& x, const i16sse& y) { return _mm_max_epi16(*x, *y); }
-KFR_SINTRIN i64sse max(const i64sse& x, const i64sse& y) { return select(x > y, x, y); }
-KFR_SINTRIN u64sse max(const u64sse& x, const u64sse& y) { return select(x > y, x, y); }
#if defined CMT_ARCH_AVX2
KFR_SINTRIN u8avx min(const u8avx& x, const u8avx& y) { return _mm256_min_epu8(*x, *y); }
@@ -67,6 +63,40 @@ KFR_SINTRIN u16avx max(const u16avx& x, const u16avx& y) { return _mm256_max_epu
KFR_SINTRIN i32avx max(const i32avx& x, const i32avx& y) { return _mm256_max_epi32(*x, *y); }
KFR_SINTRIN u32avx max(const u32avx& x, const u32avx& y) { return _mm256_max_epu32(*x, *y); }
+#endif
+
+#if defined CMT_ARCH_AVX512
+KFR_SINTRIN u8avx512 min(const u8avx512& x, const u8avx512& y) { return _mm512_min_epu8(*x, *y); }
+KFR_SINTRIN i16avx512 min(const i16avx512& x, const i16avx512& y) { return _mm512_min_epi16(*x, *y); }
+KFR_SINTRIN i8avx512 min(const i8avx512& x, const i8avx512& y) { return _mm512_min_epi8(*x, *y); }
+KFR_SINTRIN u16avx512 min(const u16avx512& x, const u16avx512& y) { return _mm512_min_epu16(*x, *y); }
+KFR_SINTRIN i32avx512 min(const i32avx512& x, const i32avx512& y) { return _mm512_min_epi32(*x, *y); }
+KFR_SINTRIN u32avx512 min(const u32avx512& x, const u32avx512& y) { return _mm512_min_epu32(*x, *y); }
+KFR_SINTRIN u8avx512 max(const u8avx512& x, const u8avx512& y) { return _mm512_max_epu8(*x, *y); }
+KFR_SINTRIN i16avx512 max(const i16avx512& x, const i16avx512& y) { return _mm512_max_epi16(*x, *y); }
+KFR_SINTRIN i8avx512 max(const i8avx512& x, const i8avx512& y) { return _mm512_max_epi8(*x, *y); }
+KFR_SINTRIN u16avx512 max(const u16avx512& x, const u16avx512& y) { return _mm512_max_epu16(*x, *y); }
+KFR_SINTRIN i32avx512 max(const i32avx512& x, const i32avx512& y) { return _mm512_max_epi32(*x, *y); }
+KFR_SINTRIN u32avx512 max(const u32avx512& x, const u32avx512& y) { return _mm512_max_epu32(*x, *y); }
+KFR_SINTRIN i64avx512 min(const i64avx512& x, const i64avx512& y) { return _mm512_min_epi64(*x, *y); }
+KFR_SINTRIN u64avx512 min(const u64avx512& x, const u64avx512& y) { return _mm512_min_epu64(*x, *y); }
+KFR_SINTRIN i64avx512 max(const i64avx512& x, const i64avx512& y) { return _mm512_max_epi64(*x, *y); }
+KFR_SINTRIN u64avx512 max(const u64avx512& x, const u64avx512& y) { return _mm512_max_epu64(*x, *y); }
+
+KFR_SINTRIN i64avx min(const i64avx& x, const i64avx& y) { return _mm256_min_epi64(*x, *y); }
+KFR_SINTRIN u64avx min(const u64avx& x, const u64avx& y) { return _mm256_min_epu64(*x, *y); }
+KFR_SINTRIN i64avx max(const i64avx& x, const i64avx& y) { return _mm256_max_epi64(*x, *y); }
+KFR_SINTRIN u64avx max(const u64avx& x, const u64avx& y) { return _mm256_max_epu64(*x, *y); }
+
+KFR_SINTRIN i64sse min(const i64sse& x, const i64sse& y) { return _mm_min_epi64(*x, *y); }
+KFR_SINTRIN u64sse min(const u64sse& x, const u64sse& y) { return _mm_min_epu64(*x, *y); }
+KFR_SINTRIN i64sse max(const i64sse& x, const i64sse& y) { return _mm_max_epi64(*x, *y); }
+KFR_SINTRIN u64sse max(const u64sse& x, const u64sse& y) { return _mm_max_epu64(*x, *y); }
+#else
+KFR_SINTRIN i64sse min(const i64sse& x, const i64sse& y) { return select(x < y, x, y); }
+KFR_SINTRIN u64sse min(const u64sse& x, const u64sse& y) { return select(x < y, x, y); }
+KFR_SINTRIN i64sse max(const i64sse& x, const i64sse& y) { return select(x > y, x, y); }
+KFR_SINTRIN u64sse max(const u64sse& x, const u64sse& y) { return select(x > y, x, y); }
KFR_SINTRIN i64avx min(const i64avx& x, const i64avx& y) { return select(x < y, x, y); }
KFR_SINTRIN u64avx min(const u64avx& x, const u64avx& y) { return select(x < y, x, y); }
KFR_SINTRIN i64avx max(const i64avx& x, const i64avx& y) { return select(x > y, x, y); }
@@ -193,7 +223,7 @@ KFR_I_CONVERTER(min)
KFR_I_CONVERTER(max)
KFR_I_CONVERTER(absmin)
KFR_I_CONVERTER(absmax)
-}
+} // namespace intrinsics
KFR_I_FN(min)
KFR_I_FN(max)
KFR_I_FN(absmin)
@@ -274,4 +304,4 @@ KFR_INTRIN internal::expression_function<fn::absmax, E1, E2> absmax(E1&& x, E2&&
{
return { fn::absmax(), std::forward<E1>(x), std::forward<E2>(y) };
}
-}
+} // namespace kfr
diff --git a/include/kfr/base/pointer.hpp b/include/kfr/base/pointer.hpp
@@ -116,7 +116,7 @@ struct expression_resource
};
template <typename E>
-struct alignas(const_max(size_t(8), alignof(E))) expression_resource_impl : expression_resource
+struct expression_resource_impl : expression_resource
{
expression_resource_impl(E&& e) noexcept : e(std::move(e)) {}
virtual ~expression_resource_impl() {}
diff --git a/include/kfr/base/round.hpp b/include/kfr/base/round.hpp
@@ -34,25 +34,32 @@ namespace kfr
namespace intrinsics
{
-#define KFR_mm_trunc_ps(V) _mm_round_ps((V), _MM_FROUND_TRUNC)
-#define KFR_mm_roundnearest_ps(V) _mm_round_ps((V), _MM_FROUND_NINT)
-#define KFR_mm_trunc_pd(V) _mm_round_pd((V), _MM_FROUND_TRUNC)
-#define KFR_mm_roundnearest_pd(V) _mm_round_pd((V), _MM_FROUND_NINT)
-
-#define KFR_mm_trunc_ss(V) _mm_round_ss(_mm_setzero_ps(), (V), _MM_FROUND_TRUNC)
-#define KFR_mm_roundnearest_ss(V) _mm_round_ss(_mm_setzero_ps(), (V), _MM_FROUND_NINT)
-#define KFR_mm_trunc_sd(V) _mm_round_sd(_mm_setzero_pd(), (V), _MM_FROUND_TRUNC)
-#define KFR_mm_roundnearest_sd(V) _mm_round_sd(_mm_setzero_pd(), (V), _MM_FROUND_NINT)
+#define KFR_mm_trunc_ps(V) _mm_round_ps((V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)
+#define KFR_mm_roundnearest_ps(V) _mm_round_ps((V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
+#define KFR_mm_trunc_pd(V) _mm_round_pd((V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)
+#define KFR_mm_roundnearest_pd(V) _mm_round_pd((V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
+
+#define KFR_mm_trunc_ss(V) _mm_round_ss(_mm_setzero_ps(), (V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)
+#define KFR_mm_roundnearest_ss(V) \
+ _mm_round_ss(_mm_setzero_ps(), (V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
+#define KFR_mm_trunc_sd(V) _mm_round_sd(_mm_setzero_pd(), (V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)
+#define KFR_mm_roundnearest_sd(V) \
+ _mm_round_sd(_mm_setzero_pd(), (V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
#define KFR_mm_floor_ss(V) _mm_floor_ss(_mm_setzero_ps(), (V))
#define KFR_mm_floor_sd(V) _mm_floor_sd(_mm_setzero_pd(), (V))
#define KFR_mm_ceil_ss(V) _mm_ceil_ss(_mm_setzero_ps(), (V))
#define KFR_mm_ceil_sd(V) _mm_ceil_sd(_mm_setzero_pd(), (V))
-#define KFR_mm256_trunc_ps(V) _mm256_round_ps((V), _MM_FROUND_TRUNC)
-#define KFR_mm256_roundnearest_ps(V) _mm256_round_ps((V), _MM_FROUND_NINT)
-#define KFR_mm256_trunc_pd(V) _mm256_round_pd((V), _MM_FROUND_TRUNC)
-#define KFR_mm256_roundnearest_pd(V) _mm256_round_pd((V), _MM_FROUND_NINT)
+#define KFR_mm256_trunc_ps(V) _mm256_round_ps((V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)
+#define KFR_mm256_roundnearest_ps(V) _mm256_round_ps((V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
+#define KFR_mm256_trunc_pd(V) _mm256_round_pd((V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)
+#define KFR_mm256_roundnearest_pd(V) _mm256_round_pd((V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
+
+#define KFR_mm512_trunc_ps(V) _mm512_roundscale_ps((V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)
+#define KFR_mm512_roundnearest_ps(V) _mm512_roundscale_ps((V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
+#define KFR_mm512_trunc_pd(V) _mm512_roundscale_pd((V), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)
+#define KFR_mm512_roundnearest_pd(V) _mm512_roundscale_pd((V), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
#if defined CMT_ARCH_SSE41 && defined KFR_NATIVE_INTRINSICS
@@ -81,6 +88,20 @@ KFR_SINTRIN f32avx fract(const f32avx& x) { return x - floor(x); }
KFR_SINTRIN f64avx fract(const f64avx& x) { return x - floor(x); }
#endif
+#if defined CMT_ARCH_AVX512
+
+KFR_SINTRIN f32avx512 floor(const f32avx512& value) { return _mm512_floor_ps(*value); }
+KFR_SINTRIN f32avx512 ceil(const f32avx512& value) { return _mm512_ceil_ps(*value); }
+KFR_SINTRIN f32avx512 trunc(const f32avx512& value) { return KFR_mm512_trunc_ps(*value); }
+KFR_SINTRIN f32avx512 round(const f32avx512& value) { return KFR_mm512_roundnearest_ps(*value); }
+KFR_SINTRIN f64avx512 floor(const f64avx512& value) { return _mm512_floor_pd(*value); }
+KFR_SINTRIN f64avx512 ceil(const f64avx512& value) { return _mm512_ceil_pd(*value); }
+KFR_SINTRIN f64avx512 trunc(const f64avx512& value) { return KFR_mm512_trunc_pd(*value); }
+KFR_SINTRIN f64avx512 round(const f64avx512& value) { return KFR_mm512_roundnearest_pd(*value); }
+KFR_SINTRIN f32avx512 fract(const f32avx512& x) { return x - floor(x); }
+KFR_SINTRIN f64avx512 fract(const f64avx512& x) { return x - floor(x); }
+#endif
+
KFR_HANDLE_ALL_SIZES_F_1(floor)
KFR_HANDLE_ALL_SIZES_F_1(ceil)
KFR_HANDLE_ALL_SIZES_F_1(round)
@@ -203,7 +224,7 @@ KFR_I_CONVERTER(ifloor)
KFR_I_CONVERTER(iceil)
KFR_I_CONVERTER(iround)
KFR_I_CONVERTER(itrunc)
-}
+} // namespace intrinsics
KFR_I_FN(floor)
KFR_I_FN(ceil)
KFR_I_FN(round)
@@ -339,7 +360,7 @@ CMT_INLINE vec<T, N> rem(const vec<T, N>& x, const vec<T, N>& y)
{
return fmod(x, y);
}
-}
+} // namespace kfr
#undef KFR_mm_trunc_ps
#undef KFR_mm_roundnearest_ps
diff --git a/include/kfr/base/saturation.hpp b/include/kfr/base/saturation.hpp
@@ -40,10 +40,10 @@ KFR_SINTRIN vec<T, N> saturated_signed_add(const vec<T, N>& a, const vec<T, N>&
{
using UT = utype<T>;
constexpr size_t shift = typebits<UT>::bits - 1;
- vec<UT, N> aa = bitcast<UT>(a);
- vec<UT, N> bb = bitcast<UT>(b);
- const vec<UT, N> sum = aa + bb;
- aa = (aa >> shift) + static_cast<UT>(std::numeric_limits<T>::max());
+ vec<UT, N> aa = bitcast<UT>(a);
+ vec<UT, N> bb = bitcast<UT>(b);
+ const vec<UT, N> sum = aa + bb;
+ aa = (aa >> shift) + static_cast<UT>(std::numeric_limits<T>::max());
return select(bitcast<T>((aa ^ bb) | ~(bb ^ sum)) >= 0, a, bitcast<T>(sum));
}
@@ -52,10 +52,10 @@ KFR_SINTRIN vec<T, N> saturated_signed_sub(const vec<T, N>& a, const vec<T, N>&
{
using UT = utype<T>;
constexpr size_t shift = typebits<UT>::bits - 1;
- vec<UT, N> aa = bitcast<UT>(a);
- vec<UT, N> bb = bitcast<UT>(b);
- const vec<UT, N> diff = aa - bb;
- aa = (aa >> shift) + static_cast<UT>(std::numeric_limits<T>::max());
+ vec<UT, N> aa = bitcast<UT>(a);
+ vec<UT, N> bb = bitcast<UT>(b);
+ const vec<UT, N> diff = aa - bb;
+ aa = (aa >> shift) + static_cast<UT>(std::numeric_limits<T>::max());
return select(bitcast<T>((aa ^ bb) & (aa ^ diff)) < 0, a, bitcast<T>(diff));
}
@@ -103,6 +103,36 @@ KFR_SINTRIN u8avx satsub(const u8avx& x, const u8avx& y) { return _mm256_subs_ep
KFR_SINTRIN i8avx satsub(const i8avx& x, const i8avx& y) { return _mm256_subs_epi8(*x, *y); }
KFR_SINTRIN u16avx satsub(const u16avx& x, const u16avx& y) { return _mm256_subs_epu16(*x, *y); }
KFR_SINTRIN i16avx satsub(const i16avx& x, const i16avx& y) { return _mm256_subs_epi16(*x, *y); }
+
+KFR_SINTRIN i32avx satadd(const i32avx& a, const i32avx& b) { return saturated_signed_add(a, b); }
+KFR_SINTRIN i64avx satadd(const i64avx& a, const i64avx& b) { return saturated_signed_add(a, b); }
+KFR_SINTRIN u32avx satadd(const u32avx& a, const u32avx& b) { return saturated_unsigned_add(a, b); }
+KFR_SINTRIN u64avx satadd(const u64avx& a, const u64avx& b) { return saturated_unsigned_add(a, b); }
+
+KFR_SINTRIN i32avx satsub(const i32avx& a, const i32avx& b) { return saturated_signed_sub(a, b); }
+KFR_SINTRIN i64avx satsub(const i64avx& a, const i64avx& b) { return saturated_signed_sub(a, b); }
+KFR_SINTRIN u32avx satsub(const u32avx& a, const u32avx& b) { return saturated_unsigned_sub(a, b); }
+KFR_SINTRIN u64avx satsub(const u64avx& a, const u64avx& b) { return saturated_unsigned_sub(a, b); }
+#endif
+
+#if defined CMT_ARCH_AVX512
+KFR_SINTRIN u8avx512 satadd(const u8avx512& x, const u8avx512& y) { return _mm512_adds_epu8(*x, *y); }
+KFR_SINTRIN i8avx512 satadd(const i8avx512& x, const i8avx512& y) { return _mm512_adds_epi8(*x, *y); }
+KFR_SINTRIN u16avx512 satadd(const u16avx512& x, const u16avx512& y) { return _mm512_adds_epu16(*x, *y); }
+KFR_SINTRIN i16avx512 satadd(const i16avx512& x, const i16avx512& y) { return _mm512_adds_epi16(*x, *y); }
+KFR_SINTRIN u8avx512 satsub(const u8avx512& x, const u8avx512& y) { return _mm512_subs_epu8(*x, *y); }
+KFR_SINTRIN i8avx512 satsub(const i8avx512& x, const i8avx512& y) { return _mm512_subs_epi8(*x, *y); }
+KFR_SINTRIN u16avx512 satsub(const u16avx512& x, const u16avx512& y) { return _mm512_subs_epu16(*x, *y); }
+KFR_SINTRIN i16avx512 satsub(const i16avx512& x, const i16avx512& y) { return _mm512_subs_epi16(*x, *y); }
+
+KFR_SINTRIN i32avx512 satadd(const i32avx512& a, const i32avx512& b) { return saturated_signed_add(a, b); }
+KFR_SINTRIN i64avx512 satadd(const i64avx512& a, const i64avx512& b) { return saturated_signed_add(a, b); }
+KFR_SINTRIN u32avx512 satadd(const u32avx512& a, const u32avx512& b) { return saturated_unsigned_add(a, b); }
+KFR_SINTRIN u64avx512 satadd(const u64avx512& a, const u64avx512& b) { return saturated_unsigned_add(a, b); }
+KFR_SINTRIN i32avx512 satsub(const i32avx512& a, const i32avx512& b) { return saturated_signed_sub(a, b); }
+KFR_SINTRIN i64avx512 satsub(const i64avx512& a, const i64avx512& b) { return saturated_signed_sub(a, b); }
+KFR_SINTRIN u32avx512 satsub(const u32avx512& a, const u32avx512& b) { return saturated_unsigned_sub(a, b); }
+KFR_SINTRIN u64avx512 satsub(const u64avx512& a, const u64avx512& b) { return saturated_unsigned_sub(a, b); }
#endif
KFR_HANDLE_ALL_SIZES_2(satadd)
@@ -156,7 +186,7 @@ KFR_SINTRIN vec<T, N> satsub(const vec<T, N>& a, const vec<T, N>& b)
#endif
KFR_I_CONVERTER(satadd)
KFR_I_CONVERTER(satsub)
-}
+} // namespace intrinsics
KFR_I_FN(satadd)
KFR_I_FN(satsub)
@@ -189,4 +219,4 @@ KFR_INTRIN internal::expression_function<fn::satsub, E1, E2> satsub(E1&& x, E2&&
{
return { fn::satsub(), std::forward<E1>(x), std::forward<E2>(y) };
}
-}
+} // namespace kfr
diff --git a/include/kfr/base/select.hpp b/include/kfr/base/select.hpp
@@ -121,6 +121,49 @@ KFR_SINTRIN i64avx select(const maskfor<i64avx>& m, const i64avx& x, const i64av
}
#endif
+#if defined CMT_ARCH_AVX512
+KFR_SINTRIN f64avx512 select(const maskfor<f64avx512>& m, const f64avx512& x, const f64avx512& y)
+{
+ return _mm512_mask_blend_pd(_mm512_test_epi64_mask(*m, *m), *y, *x);
+}
+KFR_SINTRIN f32avx512 select(const maskfor<f32avx512>& m, const f32avx512& x, const f32avx512& y)
+{
+ return _mm512_mask_blend_ps(_mm512_test_epi32_mask(*m, *m), *y, *x);
+}
+KFR_SINTRIN u8avx512 select(const maskfor<u8avx512>& m, const u8avx512& x, const u8avx512& y)
+{
+ return _mm512_mask_blend_epi8(_mm512_test_epi8_mask(*m, *m), *y, *x);
+}
+KFR_SINTRIN u16avx512 select(const maskfor<u16avx512>& m, const u16avx512& x, const u16avx512& y)
+{
+ return _mm512_mask_blend_epi16(_mm512_test_epi16_mask(*m, *m), *y, *x);
+}
+KFR_SINTRIN u32avx512 select(const maskfor<u32avx512>& m, const u32avx512& x, const u32avx512& y)
+{
+ return _mm512_mask_blend_epi32(_mm512_test_epi32_mask(*m, *m), *y, *x);
+}
+KFR_SINTRIN u64avx512 select(const maskfor<u64avx512>& m, const u64avx512& x, const u64avx512& y)
+{
+ return _mm512_mask_blend_epi64(_mm512_test_epi64_mask(*m, *m), *y, *x);
+}
+KFR_SINTRIN i8avx512 select(const maskfor<i8avx512>& m, const i8avx512& x, const i8avx512& y)
+{
+ return _mm512_mask_blend_epi8(_mm512_test_epi8_mask(*m, *m), *y, *x);
+}
+KFR_SINTRIN i16avx512 select(const maskfor<i16avx512>& m, const i16avx512& x, const i16avx512& y)
+{
+ return _mm512_mask_blend_epi16(_mm512_test_epi16_mask(*m, *m), *y, *x);
+}
+KFR_SINTRIN i32avx512 select(const maskfor<i32avx512>& m, const i32avx512& x, const i32avx512& y)
+{
+ return _mm512_mask_blend_epi32(_mm512_test_epi32_mask(*m, *m), *y, *x);
+}
+KFR_SINTRIN i64avx512 select(const maskfor<i64avx512>& m, const i64avx512& x, const i64avx512& y)
+{
+ return _mm512_mask_blend_epi64(_mm512_test_epi64_mask(*m, *m), *y, *x);
+}
+#endif
+
template <typename T, size_t N, KFR_ENABLE_IF(N < platform<T>::vector_width)>
KFR_SINTRIN vec<T, N> select(const mask<T, N>& a, const vec<T, N>& b, const vec<T, N>& c)
{
@@ -211,7 +254,7 @@ KFR_SINTRIN vec<T, N> select(const vec<T, N>& m, const vec<T, N>& x, const vec<T
{
return select(m.asmask(), x, y);
}
-}
+} // namespace intrinsics
KFR_I_FN(select)
/**
@@ -238,4 +281,4 @@ KFR_INTRIN internal::expression_function<fn::select, E1, E2, E3> select(E1&& m,
{
return { fn::select(), std::forward<E1>(m), std::forward<E2>(x), std::forward<E3>(y) };
}
-}
+} // namespace kfr
diff --git a/include/kfr/base/simd_intrin.hpp b/include/kfr/base/simd_intrin.hpp
@@ -91,6 +91,19 @@ KFR_SIMD_SPEC_TYPE(f32, 8, __m256);
KFR_SIMD_SPEC_TYPE(f64, 4, __m256d);
#endif
+#ifdef CMT_ARCH_AVX512
+KFR_SIMD_SPEC_TYPE(u8, 64, __m512i);
+KFR_SIMD_SPEC_TYPE(u16, 32, __m512i);
+KFR_SIMD_SPEC_TYPE(u32, 16, __m512i);
+KFR_SIMD_SPEC_TYPE(u64, 8, __m512i);
+KFR_SIMD_SPEC_TYPE(i8, 64, __m512i);
+KFR_SIMD_SPEC_TYPE(i16, 32, __m512i);
+KFR_SIMD_SPEC_TYPE(i32, 16, __m512i);
+KFR_SIMD_SPEC_TYPE(i64, 8, __m512i);
+KFR_SIMD_SPEC_TYPE(f32, 16, __m512);
+KFR_SIMD_SPEC_TYPE(f64, 8, __m512d);
+#endif
+
#ifdef CMT_ARCH_NEON
KFR_SIMD_SPEC_TYPE(u8, 16, uint8x16_t);
KFR_SIMD_SPEC_TYPE(u16, 8, uint16x8_t);
@@ -118,17 +131,17 @@ struct raw_bytes
#define KFR_C_CYCLE(...) \
for (size_t i = 0; i < N; i++) \
- vs[i] = __VA_ARGS__
+ vs[i] = __VA_ARGS__
#define KFR_R_CYCLE(...) \
vec<T, N> result; \
- for (size_t i = 0; i < N; i++) \
+ for (size_t i = 0; i < N; i++) \
result.vs[i] = __VA_ARGS__; \
return result
#define KFR_B_CYCLE(...) \
vec<T, N> result; \
- for (size_t i = 0; i < N; i++) \
+ for (size_t i = 0; i < N; i++) \
result.vs[i] = (__VA_ARGS__) ? constants<value_type>::allones() : value_type(0); \
return result
@@ -282,13 +295,13 @@ struct alignas(const_min(platform<>::maximum_vector_alignment, sizeof(T) * next_
KFR_I_CE vec& operator++() noexcept { return *this = *this + vec(1); }
KFR_I_CE vec& operator--() noexcept { return *this = *this - vec(1); }
- KFR_I_CE vec operator++(int)noexcept
+ KFR_I_CE vec operator++(int) noexcept
{
const vec z = *this;
++*this;
return z;
}
- KFR_I_CE vec operator--(int)noexcept
+ KFR_I_CE vec operator--(int) noexcept
{
const vec z = *this;
--*this;
@@ -321,6 +334,7 @@ struct alignas(const_min(platform<>::maximum_vector_alignment, sizeof(T) * next_
const vec& flatten() const noexcept { return *this; }
simd_type operator*() const noexcept { return simd; }
simd_type& operator*() noexcept { return simd; }
+
protected:
template <typename, size_t>
friend struct vec;
@@ -366,13 +380,13 @@ CMT_INLINE vec<T, csum<size_t, N1, N2, Sizes...>()> concat_impl(const vec<T, N1>
{
return concat_impl(concat_impl(x, y), args...);
}
-}
+} // namespace internal
template <typename T, size_t... Ns>
constexpr inline vec<T, csum<size_t, Ns...>()> concat(const vec<T, Ns>&... vs) noexcept
{
return internal::concat_impl(vs...);
}
-}
+} // namespace kfr
CMT_PRAGMA_MSVC(warning(pop))
diff --git a/include/kfr/base/simd_x86.hpp b/include/kfr/base/simd_x86.hpp
@@ -181,4 +181,92 @@ KFR_I_CE CMT_INLINE vec<f64, 4> vec<f64, 4>::operator^(const vec<f64, 4>& y) con
#endif // CMT_ARCH_AVX
-} // namespace kf
+#ifdef CMT_ARCH_AVX512
+
+template <>
+KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator+(const vec<f32, 16>& y) const noexcept
+{
+ return _mm512_add_ps(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator-(const vec<f32, 16>& y) const noexcept
+{
+ return _mm512_sub_ps(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator*(const vec<f32, 16>& y) const noexcept
+{
+ return _mm512_mul_ps(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator/(const vec<f32, 16>& y) const noexcept
+{
+ return _mm512_div_ps(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator&(const vec<f32, 16>& y) const noexcept
+{
+ return _mm512_and_ps(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator|(const vec<f32, 16>& y) const noexcept
+{
+ return _mm512_or_ps(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f32, 16> vec<f32, 16>::operator^(const vec<f32, 16>& y) const noexcept
+{
+ return _mm512_xor_ps(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator+(const vec<f64, 8>& y) const noexcept
+{
+ return _mm512_add_pd(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator-(const vec<f64, 8>& y) const noexcept
+{
+ return _mm512_sub_pd(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator*(const vec<f64, 8>& y) const noexcept
+{
+ return _mm512_mul_pd(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator/(const vec<f64, 8>& y) const noexcept
+{
+ return _mm512_div_pd(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator&(const vec<f64, 8>& y) const noexcept
+{
+ return _mm512_and_pd(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator|(const vec<f64, 8>& y) const noexcept
+{
+ return _mm512_or_pd(simd, y.simd);
+}
+
+template <>
+KFR_I_CE CMT_INLINE vec<f64, 8> vec<f64, 8>::operator^(const vec<f64, 8>& y) const noexcept
+{
+ return _mm512_xor_pd(simd, y.simd);
+}
+
+#endif // CMT_ARCH_AVX
+
+} // namespace kfr
diff --git a/include/kfr/base/sqrt.hpp b/include/kfr/base/sqrt.hpp
@@ -48,6 +48,11 @@ KFR_SINTRIN f32avx sqrt(const f32avx& x) { return _mm256_sqrt_ps(*x); }
KFR_SINTRIN f64avx sqrt(const f64avx& x) { return _mm256_sqrt_pd(*x); }
#endif
+#if defined CMT_ARCH_AVX512
+KFR_SINTRIN f32avx512 sqrt(const f32avx512& x) { return _mm512_sqrt_ps(*x); }
+KFR_SINTRIN f64avx512 sqrt(const f64avx512& x) { return _mm512_sqrt_pd(*x); }
+#endif
+
KFR_HANDLE_ALL_SIZES_FLT_1(sqrt)
#else
diff --git a/include/kfr/base/types.hpp b/include/kfr/base/types.hpp
@@ -375,7 +375,7 @@ struct is_simd_type
template <typename T, size_t N>
struct vec_t
{
- static_assert(N > 0 && N <= 256, "Invalid vector size");
+ static_assert(N > 0 && N <= 1024, "Invalid vector size");
static_assert(is_simd_type<T>::value || !compound_type_traits<T>::is_scalar, "Invalid vector type");
diff --git a/include/kfr/dft/fft.hpp b/include/kfr/dft/fft.hpp
@@ -31,6 +31,7 @@
#include "../base/read_write.hpp"
#include "../base/small_buffer.hpp"
#include "../base/vec.hpp"
+#include "../testo/assert.hpp"
#include "bitrev.hpp"
#include "ft.hpp"
@@ -46,6 +47,11 @@ CMT_PRAGMA_MSVC(warning(disable : 4100))
namespace kfr
{
+#define DFT_ASSERT TESTO_ASSERT_ACTIVE
+
+template <typename T>
+constexpr size_t fft_vector_width = platform<T>::vector_width;
+
template <typename T>
struct dft_stage
{
@@ -83,11 +89,11 @@ KFR_SINTRIN cvec<T, width> radix4_apply_twiddle(csize_t<width>, cfalse_t /*split
cvec<T, width> ww = w;
cvec<T, width> tw_ = tw;
cvec<T, width> b1 = ww * dupeven(tw_);
- ww = swap<2>(ww);
+ ww = swap<2>(ww);
if (inverse)
tw_ = -(tw_);
- ww = subadd(b1, ww * dupodd(tw_));
+ ww = subadd(b1, ww * dupodd(tw_));
return ww;
}
@@ -235,8 +241,8 @@ KFR_SINTRIN void initialize_twiddles_impl(complex<T>*& twiddle, size_t nn, size_
for (size_t i = 0; i < width; i++)
{
const cvec<T, 1> r = calculate_twiddle<T>(nn + nnstep * i, size);
- result[i * 2] = r[0];
- result[i * 2 + 1] = r[1];
+ result[i * 2] = r[0];
+ result[i * 2 + 1] = r[1];
}
if (split_format)
ref_cast<cvec<T, width>>(twiddle[0]) = splitpairs(result);
@@ -248,9 +254,11 @@ KFR_SINTRIN void initialize_twiddles_impl(complex<T>*& twiddle, size_t nn, size_
template <typename T, size_t width>
CMT_NOINLINE void initialize_twiddles(complex<T>*& twiddle, size_t stage_size, size_t size, bool split_format)
{
+ const size_t count = stage_size / 4;
+ // DFT_ASSERT(width <= count);
size_t nnstep = size / stage_size;
CMT_LOOP_NOUNROLL
- for (size_t n = 0; n < stage_size / 4; n += width)
+ for (size_t n = 0; n < count; n += width)
{
initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 1, nnstep * 1, size, split_format);
initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 2, nnstep * 2, size, split_format);
@@ -295,6 +303,7 @@ KFR_SINTRIN cfalse_t radix4_pass(Ntype N, size_t blocks, csize_t<width>, cbool_t
CMT_ASSUME(blocks > 0);
CMT_ASSUME(N > 0);
CMT_ASSUME(N4 > 0);
+ DFT_ASSERT(width <= N4);
CMT_LOOP_NOUNROLL for (size_t b = 0; b < blocks; b++)
{
CMT_PRAGMA_CLANG(clang loop unroll_count(2))
@@ -359,6 +368,7 @@ KFR_SINTRIN ctrue_t radix4_pass(csize_t<8>, size_t blocks, csize_t<width>, cfals
{
CMT_ASSUME(blocks > 0);
constexpr static size_t prefetch_offset = width * 16;
+ DFT_ASSERT(2 <= blocks);
for (size_t b = 0; b < blocks; b += 2)
{
if (prefetch)
@@ -384,6 +394,7 @@ KFR_SINTRIN ctrue_t radix4_pass(csize_t<16>, size_t blocks, csize_t<width>, cfal
{
CMT_ASSUME(blocks > 0);
constexpr static size_t prefetch_offset = width * 4;
+ DFT_ASSERT(2 <= blocks);
CMT_PRAGMA_CLANG(clang loop unroll_count(2))
for (size_t b = 0; b < blocks; b += 2)
{
@@ -415,6 +426,7 @@ KFR_SINTRIN ctrue_t radix4_pass(csize_t<4>, size_t blocks, csize_t<width>, cfals
{
constexpr static size_t prefetch_offset = width * 4;
CMT_ASSUME(blocks > 0);
+ DFT_ASSERT(4 <= blocks);
CMT_LOOP_NOUNROLL
for (size_t b = 0; b < blocks; b += 4)
{
@@ -445,7 +457,7 @@ struct fft_stage_impl : dft_stage<T>
protected:
constexpr static bool prefetch = true;
constexpr static bool aligned = false;
- constexpr static size_t width = platform<T>::vector_width;
+ constexpr static size_t width = fft_vector_width<T>;
virtual void do_initialize(size_t size) override final
{
@@ -457,7 +469,7 @@ protected:
{
const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
if (splitin)
- in = out;
+ in = out;
const size_t stg_size = this->stage_size;
CMT_ASSUME(stg_size >= 2048);
CMT_ASSUME(stg_size % 2048 == 0);
@@ -479,77 +491,77 @@ struct fft_final_stage_impl : dft_stage<T>
}
protected:
- constexpr static size_t width = platform<T>::vector_width;
- constexpr static bool is_even = cometa::is_even(ilog2(size));
- constexpr static bool use_br2 = !is_even;
- constexpr static bool aligned = false;
- constexpr static bool prefetch = splitin;
+ constexpr static size_t width = fft_vector_width<T>;
+ constexpr static bool is_even = cometa::is_even(ilog2(size));
+ constexpr static bool use_br2 = !is_even;
+ constexpr static bool aligned = false;
+ constexpr static bool prefetch = splitin;
- virtual void do_initialize(size_t total_size) override final
+ KFR_INTRIN void init_twiddles(csize_t<8>, size_t, cfalse_t, complex<T>*&) {}
+ KFR_INTRIN void init_twiddles(csize_t<4>, size_t, cfalse_t, complex<T>*&) {}
+
+ template <size_t N, bool pass_splitin>
+ KFR_INTRIN void init_twiddles(csize_t<N>, size_t total_size, cbool_t<pass_splitin>, complex<T>*& twiddle)
{
- complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
- size_t stg_size = this->stage_size;
- while (stg_size > 4)
- {
- initialize_twiddles<T, width>(twiddle, stg_size, total_size, true);
- stg_size /= 4;
- }
+ constexpr bool pass_split = N / 4 > 8 && N / 4 / 4 >= width;
+ constexpr size_t pass_width = const_min(width, N / 4);
+ initialize_twiddles<T, pass_width>(twiddle, N, total_size, pass_split || pass_splitin);
+ init_twiddles(csize<N / 4>, total_size, cbool<pass_split>, twiddle);
}
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final
+ virtual void do_initialize(size_t total_size) override final
{
- constexpr bool is_double = sizeof(T) == 8;
- constexpr size_t final_size = is_even ? (is_double ? 4 : 16) : (is_double ? 8 : 32);
- const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
- final_pass(csize_t<final_size>(), out, in, twiddle);
+ complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
+ init_twiddles(csize<size>, total_size, cbool<splitin>, twiddle);
}
- KFR_INTRIN void final_pass(csize_t<8>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle)
+ virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override
{
- radix4_pass(512, 1, csize_t<width>(), ctrue, cbool_t<splitin>(), cbool_t<use_br2>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle);
- radix4_pass(128, 4, csize_t<width>(), ctrue, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(32, 16, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(csize_t<8>(), 64, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
+ const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
+ final_stage(csize<size>, 1, cbool<splitin>, out, in, twiddle);
}
- KFR_INTRIN void final_pass(csize_t<32>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle)
+ // KFR_INTRIN void final_stage(csize_t<32>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
+ // const complex<T>*& twiddle)
+ // {
+ // radix4_pass(csize_t<32>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
+ // cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
+ // }
+ //
+ // KFR_INTRIN void final_stage(csize_t<16>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
+ // const complex<T>*& twiddle)
+ // {
+ // radix4_pass(csize_t<16>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
+ // cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
+ // }
+
+ KFR_INTRIN void final_stage(csize_t<8>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
+ const complex<T>*& twiddle)
{
- radix4_pass(512, 1, csize_t<width>(), ctrue, cbool_t<splitin>(), cbool_t<use_br2>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle);
- radix4_pass(128, 4, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(csize_t<32>(), 16, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
+ radix4_pass(csize_t<8>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
}
- KFR_INTRIN void final_pass(csize_t<4>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle)
+ KFR_INTRIN void final_stage(csize_t<4>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
+ const complex<T>*& twiddle)
{
- radix4_pass(1024, 1, csize_t<width>(), ctrue, cbool_t<splitin>(), cbool_t<use_br2>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle);
- radix4_pass(256, 4, csize_t<width>(), ctrue, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(64, 16, csize_t<width>(), ctrue, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(16, 64, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(csize_t<4>(), 256, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
+ radix4_pass(csize_t<4>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
}
- KFR_INTRIN void final_pass(csize_t<16>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle)
+ template <size_t N, bool pass_splitin>
+ KFR_INTRIN void final_stage(csize_t<N>, size_t invN, cbool_t<pass_splitin>, complex<T>* out,
+ const complex<T>* in, const complex<T>*& twiddle)
{
- radix4_pass(1024, 1, csize_t<width>(), ctrue, cbool_t<splitin>(), cbool_t<use_br2>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle);
- radix4_pass(256, 4, csize_t<width>(), ctrue, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(64, 16, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(csize_t<16>(), 64, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
+ static_assert(N > 8, "");
+ constexpr bool pass_split = N / 4 > 8 && N / 4 / 4 >= width;
+ constexpr size_t pass_width = const_min(width, N / 4);
+ static_assert(pass_width == width || (pass_split == pass_splitin), "");
+ static_assert(pass_width <= N / 4, "");
+ radix4_pass(N, invN, csize_t<pass_width>(), cbool<pass_split>, cbool_t<pass_splitin>(),
+ cbool_t<use_br2>(), cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in,
+ twiddle);
+ final_stage(csize<N / 4>, invN * 4, cbool<pass_split>, out, out, twiddle);
}
};
@@ -581,6 +593,7 @@ template <typename T, bool inverse>
struct fft_specialization<T, 1, inverse> : dft_stage<T>
{
fft_specialization(size_t) {}
+
protected:
constexpr static bool aligned = false;
virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
@@ -595,6 +608,7 @@ template <typename T, bool inverse>
struct fft_specialization<T, 2, inverse> : dft_stage<T>
{
fft_specialization(size_t) {}
+
protected:
constexpr static bool aligned = false;
virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
@@ -610,6 +624,7 @@ template <typename T, bool inverse>
struct fft_specialization<T, 3, inverse> : dft_stage<T>
{
fft_specialization(size_t) {}
+
protected:
constexpr static bool aligned = false;
virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
@@ -624,6 +639,7 @@ template <typename T, bool inverse>
struct fft_specialization<T, 4, inverse> : dft_stage<T>
{
fft_specialization(size_t) {}
+
protected:
constexpr static bool aligned = false;
virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
@@ -638,6 +654,7 @@ template <typename T, bool inverse>
struct fft_specialization<T, 5, inverse> : dft_stage<T>
{
fft_specialization(size_t) {}
+
protected:
constexpr static bool aligned = false;
virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
@@ -652,6 +669,7 @@ template <typename T, bool inverse>
struct fft_specialization<T, 6, inverse> : dft_stage<T>
{
fft_specialization(size_t) {}
+
protected:
constexpr static bool aligned = false;
virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
@@ -671,7 +689,7 @@ struct fft_specialization<T, 7, inverse> : dft_stage<T>
protected:
constexpr static bool aligned = false;
- constexpr static size_t width = platform<T>::vector_width;
+ constexpr static size_t width = fft_vector_width<T>;
constexpr static bool use_br2 = true;
constexpr static bool prefetch = false;
constexpr static bool is_double = sizeof(T) == 8;
@@ -716,6 +734,7 @@ template <bool inverse>
struct fft_specialization<float, 8, inverse> : dft_stage<float>
{
fft_specialization(size_t) { this->temp_size = sizeof(complex<float>) * 256; }
+
protected:
virtual void do_execute(complex<float>* out, const complex<float>* in, u8* temp) override final
{
@@ -748,48 +767,16 @@ protected:
};
template <bool inverse>
-struct fft_specialization<double, 8, inverse> : dft_stage<double>
+struct fft_specialization<double, 8, inverse> : fft_final_stage_impl<double, false, 256, inverse>
{
using T = double;
- fft_specialization(size_t)
- {
- this->stage_size = 256;
- this->data_size = align_up(sizeof(complex<T>) * 256 * 3 / 2, platform<>::native_cache_alignment);
- }
-
-protected:
- constexpr static bool aligned = false;
- constexpr static size_t width = platform<T>::vector_width;
- constexpr static bool use_br2 = false;
- constexpr static bool prefetch = false;
- constexpr static size_t split_format = true;
-
- virtual void do_initialize(size_t total_size) override final
- {
- complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
- initialize_twiddles<T, width>(twiddle, 256, total_size, split_format);
- initialize_twiddles<T, width>(twiddle, 64, total_size, split_format);
- initialize_twiddles<T, width>(twiddle, 16, total_size, split_format);
- }
+ using fft_final_stage_impl<double, false, 256, inverse>::fft_final_stage_impl;
virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final
{
- const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
- final_pass(csize_t<4>(), out, in, twiddle);
+ fft_final_stage_impl<double, false, 256, inverse>::do_execute(out, in, nullptr);
fft_reorder(out, csize_t<8>());
}
-
- KFR_INTRIN void final_pass(csize_t<4>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle)
- {
- radix4_pass(256, 1, csize_t<width>(), ctrue, cfalse, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle);
- radix4_pass(64, 4, csize_t<width>(), ctrue, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(16, 16, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(csize_t<4>(), 64, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- }
};
template <typename T, bool splitin, bool is_even>
@@ -816,14 +803,14 @@ struct fft_specialization_t
template <bool inverse>
using type = internal::fft_specialization<T, log2n, inverse>;
};
-}
+} // namespace internal
namespace dft_type
{
constexpr cbools_t<true, true> both{};
constexpr cbools_t<true, false> direct{};
constexpr cbools_t<false, true> inverse{};
-}
+} // namespace dft_type
template <typename T>
struct dft_plan
@@ -1029,7 +1016,7 @@ struct dft_plan_real : dft_plan<T>
{
using namespace internal;
- constexpr size_t width = platform<T>::vector_width * 2;
+ constexpr size_t width = fft_vector_width<T> * 2;
block_process(size / 4, csizes_t<width, 1>(), [=](size_t i, auto w) {
constexpr size_t width = val_of(decltype(w)());
@@ -1078,13 +1065,13 @@ private:
size_t csize =
this->size / 2; // const size_t causes internal compiler error: in tsubst_copy in GCC 5.2
- constexpr size_t width = platform<T>::vector_width * 2;
- const cvec<T, 1> dc = cread<1>(out);
- const size_t count = csize / 2;
+ constexpr size_t width = fft_vector_width<T> * 2;
+ const cvec<T, 1> dc = cread<1>(out);
+ const size_t count = csize / 2;
block_process(count, csizes_t<width, 1>(), [=](size_t i, auto w) {
- constexpr size_t width = val_of(decltype(w)());
- constexpr size_t widthm1 = width - 1;
+ constexpr size_t width = val_of(decltype(w)());
+ constexpr size_t widthm1 = width - 1;
const cvec<T, width> tw = cread<width>(rtwiddle.data() + i);
const cvec<T, width> fpk = cread<width>(out + i);
const cvec<T, width> fpnk = reverse<2>(negodd(cread<width>(out + csize - i - widthm1)));
@@ -1097,7 +1084,7 @@ private:
});
{
- size_t k = csize / 2;
+ size_t k = csize / 2;
const cvec<T, 1> fpk = cread<1>(out + k);
const cvec<T, 1> fpnk = negodd(fpk);
cwrite<1>(out + k, fpnk);
@@ -1129,13 +1116,13 @@ private:
dc = pack(in[0].real() + in[0].imag(), in[0].real() - in[0].imag());
}
- constexpr size_t width = platform<T>::vector_width * 2;
+ constexpr size_t width = fft_vector_width<T> * 2;
const size_t count = csize / 2;
block_process(count, csizes_t<width, 1>(), [=](size_t i, auto w) {
i++;
- constexpr size_t width = val_of(decltype(w)());
- constexpr size_t widthm1 = width - 1;
+ constexpr size_t width = val_of(decltype(w)());
+ constexpr size_t widthm1 = width - 1;
const cvec<T, width> tw = cread<width>(rtwiddle.data() + i);
const cvec<T, width> fpk = cread<width>(in + i);
const cvec<T, width> fpnk = reverse<2>(negodd(cread<width>(in + csize - i - widthm1)));
@@ -1148,7 +1135,7 @@ private:
});
{
- size_t k = csize / 2;
+ size_t k = csize / 2;
const cvec<T, 1> fpk = cread<1>(in + k);
const cvec<T, 1> fpnk = 2 * negodd(fpk);
cwrite<1>(out + k, fpnk);
@@ -1195,7 +1182,7 @@ void fft_multiply_accumulate(univector<complex<T>, Tag1>& dest, const univector<
if (fmt == dft_pack_format::Perm)
dest[0] = f0;
}
-}
+} // namespace kfr
CMT_PRAGMA_GNU(GCC diagnostic pop)
diff --git a/include/kfr/testo/assert.hpp b/include/kfr/testo/assert.hpp
@@ -50,9 +50,7 @@ bool check_assertion(const half_comparison<L>& comparison, const char* expr, con
return result;
}
-#if defined(TESTO_ASSERTION_ON) || !(defined(NDEBUG) || defined(TESTO_ASSERTION_OFF))
-
-#define TESTO_ASSERT(...) \
+#define TESTO_ASSERT_ACTIVE(...) \
do \
{ \
if (!::testo::check_assertion(::testo::make_comparison() <= __VA_ARGS__, #__VA_ARGS__, __FILE__, \
@@ -60,6 +58,10 @@ bool check_assertion(const half_comparison<L>& comparison, const char* expr, con
TESTO_BREAKPOINT; \
} while (0)
+#if defined(TESTO_ASSERTION_ON) || !(defined(NDEBUG) || defined(TESTO_ASSERTION_OFF))
+
+#define TESTO_ASSERT TESTO_ASSERT_ACTIVE
+
#else
#define TESTO_ASSERT(...) \
do \
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
@@ -25,7 +25,7 @@ add_definitions(-DKFR_TESTING=1)
include_directories(../include)
if (NOT ARM)
- if(MSVC)
+ if(MSVC AND NOT CLANG)
add_executable(multiarch multiarch.cpp multiarch_fir_sse2.cpp multiarch_fir_avx.cpp ${KFR_SRC})
set_source_files_properties(multiarch_fir_sse2.cpp PROPERTIES COMPILE_FLAGS /arch:SSE2)
set_source_files_properties(multiarch_fir_avx.cpp PROPERTIES COMPILE_FLAGS /arch:AVX)
diff --git a/tests/expression_test.cpp b/tests/expression_test.cpp
@@ -114,7 +114,7 @@ constexpr inline size_t fast_range_sum(size_t stop) { return stop * (stop + 1) /
TEST(partition)
{
{
- univector<double, 400> output = zeros();
+ univector<double, 385> output = zeros();
auto result = partition(output, counter(), 5, 1);
CHECK(result.count == 5);
CHECK(result.chunk_size == 80);
@@ -128,11 +128,11 @@ TEST(partition)
result(3);
CHECK(sum(output) >= fast_range_sum(320 - 1));
result(4);
- CHECK(sum(output) == fast_range_sum(400 - 1));
+ CHECK(sum(output) == fast_range_sum(385 - 1));
}
{
- univector<double, 400> output = zeros();
+ univector<double, 385> output = zeros();
auto result = partition(output, counter(), 5, 160);
CHECK(result.count == 3);
CHECK(result.chunk_size == 160);
@@ -142,7 +142,7 @@ TEST(partition)
result(1);
CHECK(sum(output) >= fast_range_sum(320 - 1));
result(2);
- CHECK(sum(output) == fast_range_sum(400 - 1));
+ CHECK(sum(output) == fast_range_sum(385 - 1));
}
}