mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-02-05 13:53:23 +02:00
Compare commits
4 Commits
ci-android
...
b2916
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b43272afa2 | ||
|
|
0fc1e820a9 | ||
|
|
82ca83db3c | ||
|
|
f4bd8b3d26 |
@@ -227,20 +227,20 @@ effectiveStdenv.mkDerivation (
|
||||
)
|
||||
]
|
||||
++ optionals useRocm [
|
||||
(cmakeFeature "CMAKE_C_COMPILER" "hipcc")
|
||||
(cmakeFeature "CMAKE_CXX_COMPILER" "hipcc")
|
||||
|
||||
# Build all targets supported by rocBLAS. When updating search for TARGET_LIST_ROCM
|
||||
# in https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/CMakeLists.txt
|
||||
# and select the line that matches the current nixpkgs version of rocBLAS.
|
||||
# Should likely use `rocmPackages.clr.gpuTargets`.
|
||||
"-DAMDGPU_TARGETS=gfx803;gfx900;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102"
|
||||
(cmakeFeature "CMAKE_HIP_COMPILER" "${rocmPackages.llvm.clang}/bin/clang")
|
||||
(cmakeFeature "CMAKE_HIP_ARCHITECTURES" (builtins.concatStringsSep ";" rocmPackages.clr.gpuTargets))
|
||||
]
|
||||
++ optionals useMetalKit [
|
||||
(lib.cmakeFeature "CMAKE_C_FLAGS" "-D__ARM_FEATURE_DOTPROD=1")
|
||||
(cmakeBool "LLAMA_METAL_EMBED_LIBRARY" (!precompileMetalShaders))
|
||||
];
|
||||
|
||||
# Environment variables needed for ROCm
|
||||
env = optionals useRocm {
|
||||
ROCM_PATH = "${rocmPackages.clr}";
|
||||
HIP_DEVICE_LIB_PATH = "${rocmPackages.rocm-device-libs}/amdgcn/bitcode";
|
||||
};
|
||||
|
||||
# TODO(SomeoneSerge): It's better to add proper install targets at the CMake level,
|
||||
# if they haven't been added yet.
|
||||
postInstall = ''
|
||||
|
||||
58
.github/workflows/build.yml
vendored
58
.github/workflows/build.yml
vendored
@@ -392,6 +392,33 @@ jobs:
|
||||
cmake -DLLAMA_VULKAN=ON ..
|
||||
cmake --build . --config Release -j $(nproc)
|
||||
|
||||
ubuntu-22-cmake-hip:
|
||||
runs-on: ubuntu-22.04
|
||||
container: rocm/dev-ubuntu-22.04:6.0.2
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev
|
||||
|
||||
- name: Build with native CMake HIP support
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -S . -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" -DLLAMA_HIPBLAS=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
- name: Build with legacy HIP support
|
||||
id: cmake_build_legacy_hip
|
||||
run: |
|
||||
cmake -B build2 -S . -DCMAKE_C_COMPILER=hipcc -DCMAKE_CXX_COMPILER=hipcc -DLLAMA_HIPBLAS=ON
|
||||
cmake --build build2 --config Release -j $(nproc)
|
||||
|
||||
ubuntu-22-cmake-sycl:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
@@ -989,6 +1016,37 @@ jobs:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip
|
||||
name: llama-bin-win-sycl-x64.zip
|
||||
|
||||
windows-latest-cmake-hip:
|
||||
runs-on: windows-latest
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install
|
||||
id: depends
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
write-host "Downloading AMD HIP SDK Installer"
|
||||
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
|
||||
write-host "Installing AMD HIP SDK"
|
||||
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
|
||||
write-host "Completed AMD HIP SDK installation"
|
||||
|
||||
- name: Verify ROCm
|
||||
id: verify
|
||||
run: |
|
||||
& 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
|
||||
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
|
||||
cmake -G "Unix Makefiles" -B build -S . -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DLLAMA_HIPBLAS=ON
|
||||
cmake --build build --config Release
|
||||
|
||||
ios-xcode-build:
|
||||
runs-on: macos-latest
|
||||
|
||||
|
||||
@@ -555,16 +555,37 @@ if (LLAMA_VULKAN)
|
||||
endif()
|
||||
|
||||
if (LLAMA_HIPBLAS)
|
||||
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
|
||||
if ($ENV{ROCM_PATH})
|
||||
set(ROCM_PATH $ENV{ROCM_PATH})
|
||||
else()
|
||||
set(ROCM_PATH /opt/rocm)
|
||||
endif()
|
||||
list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
|
||||
|
||||
if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
|
||||
message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
|
||||
# CMake on Windows doesn't support the HIP language yet
|
||||
if(WIN32)
|
||||
set(CXX_IS_HIPCC TRUE)
|
||||
else()
|
||||
string(REGEX MATCH "hipcc(\.bat)?$" CXX_IS_HIPCC "${CMAKE_CXX_COMPILER}")
|
||||
endif()
|
||||
|
||||
if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
|
||||
message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
|
||||
endif()
|
||||
if(CXX_IS_HIPCC)
|
||||
if(LINUX)
|
||||
if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
|
||||
message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
|
||||
endif()
|
||||
|
||||
message(WARNING "Setting hipcc as the C++ compiler is legacy behavior."
|
||||
" Prefer setting the HIP compiler directly. See README for details.")
|
||||
endif()
|
||||
else()
|
||||
# Forward AMDGPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
|
||||
if(AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
|
||||
set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_ARGETS})
|
||||
endif()
|
||||
cmake_minimum_required(VERSION 3.21)
|
||||
enable_language(HIP)
|
||||
endif()
|
||||
find_package(hip REQUIRED)
|
||||
find_package(hipblas REQUIRED)
|
||||
find_package(rocblas REQUIRED)
|
||||
@@ -598,13 +619,18 @@ if (LLAMA_HIPBLAS)
|
||||
add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
|
||||
add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
|
||||
|
||||
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
|
||||
if (CXX_IS_HIPCC)
|
||||
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
|
||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} hip::device)
|
||||
else()
|
||||
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE HIP)
|
||||
endif()
|
||||
|
||||
if (LLAMA_STATIC)
|
||||
message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
|
||||
endif()
|
||||
|
||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
|
||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} PUBLIC hip::host roc::rocblas roc::hipblas)
|
||||
endif()
|
||||
|
||||
if (LLAMA_SYCL)
|
||||
|
||||
6
Makefile
6
Makefile
@@ -560,10 +560,10 @@ endif # LLAMA_VULKAN
|
||||
ifdef LLAMA_HIPBLAS
|
||||
ifeq ($(wildcard /opt/rocm),)
|
||||
ROCM_PATH ?= /usr
|
||||
GPU_TARGETS ?= $(shell $(shell which amdgpu-arch))
|
||||
AMDGPU_TARGETS ?= $(shell $(shell which amdgpu-arch))
|
||||
else
|
||||
ROCM_PATH ?= /opt/rocm
|
||||
GPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
|
||||
AMDGPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
|
||||
endif
|
||||
HIPCC ?= $(CCACHE) $(ROCM_PATH)/bin/hipcc
|
||||
LLAMA_CUDA_DMMV_X ?= 32
|
||||
@@ -575,7 +575,7 @@ ifdef LLAMA_HIP_UMA
|
||||
endif # LLAMA_HIP_UMA
|
||||
MK_LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
|
||||
MK_LDFLAGS += -lhipblas -lamdhip64 -lrocblas
|
||||
HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS))
|
||||
HIPFLAGS += $(addprefix --offload-arch=,$(AMDGPU_TARGETS))
|
||||
HIPFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
|
||||
HIPFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
|
||||
HIPFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
|
||||
|
||||
25
README.md
25
README.md
@@ -528,13 +528,28 @@ Building the program with BLAS support may lead to some performance improvements
|
||||
```
|
||||
- Using `CMake` for Linux (assuming a gfx1030-compatible AMD GPU):
|
||||
```bash
|
||||
CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ \
|
||||
cmake -B build -DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
|
||||
HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
|
||||
cmake -S . -B build -DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
|
||||
&& cmake --build build --config Release -- -j 16
|
||||
```
|
||||
On Linux it is also possible to use unified memory architecture (UMA) to share main memory between the CPU and integrated GPU by setting `-DLLAMA_HIP_UMA=ON`.
|
||||
However, this hurts performance for non-integrated GPUs (but enables working with integrated GPUs).
|
||||
|
||||
Note that if you get the following error:
|
||||
```
|
||||
clang: error: cannot find ROCm device library; provide its path via '--rocm-path' or '--rocm-device-lib-path', or pass '-nogpulib' to build without ROCm device library
|
||||
```
|
||||
Try searching for a directory under `HIP_PATH` that contains the file
|
||||
`oclc_abi_version_400.bc`. Then, add the following to the start of the
|
||||
command: `HIP_DEVICE_LIB_PATH=<directory-you-just-found>`, so something
|
||||
like:
|
||||
```bash
|
||||
HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -p)" \
|
||||
HIP_DEVICE_LIB_PATH=<directory-you-just-found> \
|
||||
cmake -S . -B build -DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
|
||||
&& cmake --build build -- -j 16
|
||||
```
|
||||
|
||||
- Using `make` (example for target gfx1030, build with 16 CPU threads):
|
||||
```bash
|
||||
make -j16 LLAMA_HIPBLAS=1 LLAMA_HIP_UMA=1 AMDGPU_TARGETS=gfx1030
|
||||
@@ -543,10 +558,8 @@ Building the program with BLAS support may lead to some performance improvements
|
||||
- Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS, and assuming a gfx1100-compatible AMD GPU):
|
||||
```bash
|
||||
set PATH=%HIP_PATH%\bin;%PATH%
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -G Ninja -DAMDGPU_TARGETS=gfx1100 -DLLAMA_HIPBLAS=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release ..
|
||||
cmake --build .
|
||||
cmake -S . -B build -G Ninja -DAMDGPU_TARGETS=gfx1100 -DLLAMA_HIPBLAS=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build build
|
||||
```
|
||||
Make sure that `AMDGPU_TARGETS` is set to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors)
|
||||
Find your gpu version string by matching the most significant version information from `rocminfo | grep gfx | head -1 | awk '{print $2}'` with the list of processors, e.g. `gfx1035` maps to `gfx1030`.
|
||||
|
||||
@@ -12,17 +12,15 @@ cmake_minimum_required(VERSION 3.22.1)
|
||||
# build script scope).
|
||||
project("llama-android")
|
||||
|
||||
#include(FetchContent)
|
||||
#FetchContent_Declare(
|
||||
# llama
|
||||
# GIT_REPOSITORY https://github.com/ggerganov/llama.cpp
|
||||
# GIT_TAG ci-android
|
||||
#)
|
||||
#
|
||||
## Also provides "common"
|
||||
#FetchContent_MakeAvailable(llama)
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(
|
||||
llama
|
||||
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp
|
||||
GIT_TAG master
|
||||
)
|
||||
|
||||
add_subdirectory(../../../../../../ please-work)
|
||||
# Also provides "common"
|
||||
FetchContent_MakeAvailable(llama)
|
||||
|
||||
# Creates and names a library, sets it as either STATIC
|
||||
# or SHARED, and provides the relative paths to its source code.
|
||||
|
||||
@@ -56,6 +56,10 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
|
||||
} else if (arg == "-h" || arg == "--help") {
|
||||
print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
} else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
||||
395
ggml-cuda/fattn-tile-f16.cu
Normal file
395
ggml-cuda/fattn-tile-f16.cu
Normal file
@@ -0,0 +1,395 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
#include "fattn-tile-f16.cuh"
|
||||
|
||||
#define FATTN_KQ_STRIDE_TILE_F16 64
|
||||
|
||||
template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void flash_attn_tile_ext_f16(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int ne03,
|
||||
const int ne10,
|
||||
const int ne11,
|
||||
const int ne12,
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int nb31,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#if FP16_AVAILABLE
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
||||
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + ne11*ic0;
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
half slopeh = __float2half(1.0f);
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const uint32_t h = blockIdx.y;
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slopeh = __float2half(powf(base, exph));
|
||||
}
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
|
||||
__shared__ half KQ[ncols*FATTN_KQ_STRIDE_TILE_F16];
|
||||
half2 * KQ2 = (half2 *) KQ;
|
||||
|
||||
__shared__ half2 KV_tmp[FATTN_KQ_STRIDE_TILE_F16][D/2 + 1]; // Pad D to avoid memory bank conflicts.
|
||||
|
||||
half kqmax[ncols/nwarps];
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
kqmax[j0/nwarps] = -HALF_MAX_HALF;
|
||||
}
|
||||
half2 kqsum[ncols/nwarps] = {{0.0f, 0.0f}};
|
||||
|
||||
half2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
|
||||
|
||||
// Convert Q to half2 and store in registers:
|
||||
__shared__ half2 Q_h2[ncols][D/2];
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
|
||||
Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
|
||||
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
half kqmax_new[ncols/nwarps];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/nwarps; ++j) {
|
||||
kqmax_new[j] = kqmax[j];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += nwarps) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
half2 sum2[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE][ncols/nwarps] = {{{0.0f, 0.0f}}};
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ = 0; k_KQ < D/2; ++k_KQ) {
|
||||
half2 K_k[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE];
|
||||
half2 Q_k[ncols/nwarps];
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||
|
||||
K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||
|
||||
Q_k[j_KQ_0/nwarps] = Q_h2[j_KQ][k_KQ];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE]*Q_k[j_KQ_0/nwarps];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||
|
||||
half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
||||
sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
|
||||
|
||||
kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
|
||||
|
||||
KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F16 + i_KQ] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
|
||||
const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]));
|
||||
kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F16/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
const half2 diff = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] - __half2half2(kqmax[j0/nwarps]);
|
||||
const half2 val = h2exp(diff);
|
||||
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + val;
|
||||
KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] = val;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += nwarps) {
|
||||
const int k = k0 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += 2) {
|
||||
half2 V_k[(D/2)/WARP_SIZE][2];
|
||||
half2 KQ_k[ncols/nwarps];
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
V_k[i0/WARP_SIZE][0] = KV_tmp[k0 + 0][i];
|
||||
V_k[i0/WARP_SIZE][1] = KV_tmp[k0 + 1][i];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
KQ_k[j0/nwarps] = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + k0/2];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][0]* __low2half2(KQ_k[j0/nwarps]);
|
||||
VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][1]*__high2half2(KQ_k[j0/nwarps]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
||||
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
||||
|
||||
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
|
||||
kqsum_j = warp_reduce_sum(kqsum_j);
|
||||
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
|
||||
const int i0 = i00 + 2*threadIdx.x;
|
||||
|
||||
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
|
||||
if (parallel_blocks == 1) {
|
||||
dst_val /= __half2half2(kqsum_j);
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
|
||||
}
|
||||
|
||||
if (parallel_blocks != 1 && threadIdx.x == 0) {
|
||||
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
}
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FP16_AVAILABLE
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f16(
|
||||
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
|
||||
ggml_cuda_pool & pool, cudaStream_t main_stream
|
||||
) {
|
||||
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
||||
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
||||
}
|
||||
|
||||
constexpr int nwarps = 8;
|
||||
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
||||
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
|
||||
const int shmem = 0;
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
const uint32_t n_head = Q->ne[2];
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data,
|
||||
(const char *) K->data,
|
||||
(const char *) V->data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, n_head_log2,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
if (parallel_blocks == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
|
||||
const int shmem_combine = 0;
|
||||
|
||||
flash_attn_combine_results<D, parallel_blocks>
|
||||
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
|
||||
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
ggml_tensor * KQV = dst;
|
||||
|
||||
const int32_t precision = KQV->op_params[2];
|
||||
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
|
||||
GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||
|
||||
if (Q->ne[1] <= 16) {
|
||||
constexpr int cols_per_block = 16;
|
||||
constexpr int parallel_blocks = 4;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||
break;
|
||||
case 128:
|
||||
launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32) {
|
||||
constexpr int cols_per_block = 32;
|
||||
constexpr int parallel_blocks = 4;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||
break;
|
||||
case 128:
|
||||
launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 32;
|
||||
constexpr int parallel_blocks = 1;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||
break;
|
||||
case 128:
|
||||
launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
}
|
||||
}
|
||||
3
ggml-cuda/fattn-tile-f16.cuh
Normal file
3
ggml-cuda/fattn-tile-f16.cuh
Normal file
@@ -0,0 +1,3 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
393
ggml-cuda/fattn-tile-f32.cu
Normal file
393
ggml-cuda/fattn-tile-f32.cu
Normal file
@@ -0,0 +1,393 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
#include "fattn-tile-f32.cuh"
|
||||
|
||||
#define FATTN_KQ_STRIDE_TILE_F32 32
|
||||
|
||||
template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void flash_attn_tile_ext_f32(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int ne03,
|
||||
const int ne10,
|
||||
const int ne11,
|
||||
const int ne12,
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int nb31,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
||||
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + ne11*ic0;
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const uint32_t h = blockIdx.y;
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = powf(base, exph);
|
||||
}
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
|
||||
__shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32];
|
||||
|
||||
__shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts.
|
||||
float2 * KV_tmp2 = (float2 *) KV_tmp;
|
||||
|
||||
float kqmax[ncols/nwarps];
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
kqmax[j0/nwarps] = -FLT_MAX/2.0f;
|
||||
}
|
||||
float kqsum[ncols/nwarps] = {0.0f};
|
||||
|
||||
float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
|
||||
|
||||
// Convert Q to half2 and store in registers:
|
||||
__shared__ float Q_f[ncols][D];
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
|
||||
float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x];
|
||||
Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
|
||||
Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F32;
|
||||
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
float kqmax_new[ncols/nwarps];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/nwarps; ++j) {
|
||||
kqmax_new[j] = kqmax[j];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += nwarps) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
|
||||
const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
|
||||
KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
|
||||
KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}};
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ = 0; k_KQ < D; ++k_KQ) {
|
||||
float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE];
|
||||
float Q_k[ncols/nwarps];
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||
|
||||
K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||
|
||||
Q_k[j_KQ_0/nwarps] = Q_f[j_KQ][k_KQ];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE] * Q_k[j_KQ_0/nwarps];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||
|
||||
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
|
||||
kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
||||
|
||||
KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F32 + i_KQ] = sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
|
||||
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]);
|
||||
kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
|
||||
|
||||
float kqsum_add = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F32; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
const float diff = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] - kqmax[j0/nwarps];
|
||||
const float val = expf(diff);
|
||||
kqsum_add += val;
|
||||
KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] = val;
|
||||
}
|
||||
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
|
||||
VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F32; k0 += nwarps) {
|
||||
const int k = k0 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
|
||||
KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) {
|
||||
float2 V_k[(D/2)/WARP_SIZE];
|
||||
float KQ_k[ncols/nwarps];
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
KQ_k[j0/nwarps] = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + k];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps];
|
||||
VKQ[j0/nwarps][i0/WARP_SIZE].y += V_k[i0/WARP_SIZE].y*KQ_k[j0/nwarps];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
||||
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
||||
|
||||
float kqsum_j = kqsum[j_VKQ_0/nwarps];
|
||||
kqsum_j = warp_reduce_sum(kqsum_j);
|
||||
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
|
||||
const int i0 = i00 + 2*threadIdx.x;
|
||||
|
||||
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
|
||||
if (parallel_blocks == 1) {
|
||||
dst_val.x /= kqsum_j;
|
||||
dst_val.y /= kqsum_j;
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
|
||||
}
|
||||
|
||||
if (parallel_blocks != 1 && threadIdx.x == 0) {
|
||||
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f32(
|
||||
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
|
||||
ggml_cuda_pool & pool, cudaStream_t main_stream
|
||||
) {
|
||||
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
||||
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
||||
}
|
||||
|
||||
constexpr int nwarps = 8;
|
||||
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
||||
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
|
||||
const int shmem = 0;
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
const uint32_t n_head = Q->ne[2];
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data,
|
||||
(const char *) K->data,
|
||||
(const char *) V->data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, n_head_log2,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
if (parallel_blocks == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
|
||||
const int shmem_combine = 0;
|
||||
|
||||
flash_attn_combine_results<D, parallel_blocks>
|
||||
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
|
||||
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
ggml_tensor * KQV = dst;
|
||||
|
||||
const int32_t precision = KQV->op_params[2];
|
||||
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
|
||||
GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||
|
||||
if (Q->ne[1] <= 16) {
|
||||
constexpr int cols_per_block = 16;
|
||||
constexpr int parallel_blocks = 4;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||
break;
|
||||
case 128:
|
||||
launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32) {
|
||||
constexpr int cols_per_block = 32;
|
||||
constexpr int parallel_blocks = 4;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||
break;
|
||||
case 128:
|
||||
launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 32;
|
||||
constexpr int parallel_blocks = 1;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||
break;
|
||||
case 128:
|
||||
launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
}
|
||||
}
|
||||
3
ggml-cuda/fattn-tile-f32.cuh
Normal file
3
ggml-cuda/fattn-tile-f32.cuh
Normal file
@@ -0,0 +1,3 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
@@ -57,7 +57,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const int h = blockIdx.y;
|
||||
const uint32_t h = blockIdx.y;
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
@@ -232,11 +232,8 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
||||
}
|
||||
|
||||
if (parallel_blocks != 1 && tid != 0) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
|
||||
}
|
||||
if (parallel_blocks != 1 && threadIdx.x < ncols) {
|
||||
dst_meta[(ic0 + threadIdx.x)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[threadIdx.x], kqsum[threadIdx.x]);
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
|
||||
@@ -56,7 +56,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const int h = blockIdx.y;
|
||||
const uint32_t h = blockIdx.y;
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
@@ -221,11 +221,8 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
||||
}
|
||||
|
||||
if (parallel_blocks != 1 && tid != 0) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
|
||||
}
|
||||
if (parallel_blocks != 1 && threadIdx.x < ncols) {
|
||||
dst_meta[(ic0 + threadIdx.x)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[threadIdx.x], kqsum[threadIdx.x]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
#include "fattn-tile-f16.cuh"
|
||||
#include "fattn-tile-f32.cuh"
|
||||
#include "fattn-vec-f16.cuh"
|
||||
#include "fattn-vec-f32.cuh"
|
||||
#include "fattn.cuh"
|
||||
@@ -88,7 +90,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const int h = blockIdx.y;
|
||||
const uint32_t h = blockIdx.y;
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
@@ -541,13 +543,31 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
|
||||
const int32_t precision = KQV->op_params[2];
|
||||
|
||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||
if (cc >= CC_OFFSET_AMD) {
|
||||
if (precision == GGML_PREC_DEFAULT) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (!fast_fp16_available(cc)) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
if (Q->ne[1] <= 8) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (!fp16_mma_available(cc)) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
|
||||
if (Q->ne[1] <= 8) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
13
ggml-rpc.cpp
13
ggml-rpc.cpp
@@ -134,7 +134,13 @@ static bool set_no_delay(sockfd_t sockfd) {
|
||||
int flag = 1;
|
||||
// set TCP_NODELAY to disable Nagle's algorithm
|
||||
int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
|
||||
return ret >= 0;
|
||||
return ret == 0;
|
||||
}
|
||||
|
||||
static bool set_reuse_addr(sockfd_t sockfd) {
|
||||
int flag = 1;
|
||||
int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int));
|
||||
return ret == 0;
|
||||
}
|
||||
|
||||
static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
|
||||
@@ -181,7 +187,10 @@ static std::shared_ptr<socket_t> create_server_socket(const char * host, int por
|
||||
if (sock == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!set_reuse_addr(sockfd)) {
|
||||
fprintf(stderr, "Failed to set SO_REUSEADDR\n");
|
||||
return nullptr;
|
||||
}
|
||||
struct sockaddr_in serv_addr;
|
||||
serv_addr.sin_family = AF_INET;
|
||||
serv_addr.sin_addr.s_addr = inet_addr(host);
|
||||
|
||||
6
ggml.c
6
ggml.c
@@ -2076,7 +2076,7 @@ inline static float ggml_silu_f32(float x) {
|
||||
return x/(1.0f + expf(-x));
|
||||
}
|
||||
|
||||
#if defined(__ARM_NEON) && defined(__aarch64__)
|
||||
#if defined(__ARM_NEON)
|
||||
|
||||
// adapted from arm limited optimized routine
|
||||
// the maximum error is 1.45358 plus 0.5 ulps
|
||||
@@ -2288,7 +2288,7 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
|
||||
for (; i + 3 < n; i += 4) {
|
||||
_mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
|
||||
}
|
||||
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
||||
#elif defined(__ARM_NEON)
|
||||
for (; i + 3 < n; i += 4) {
|
||||
vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
|
||||
}
|
||||
@@ -2335,7 +2335,7 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x,
|
||||
#endif
|
||||
sum += (ggml_float)_mm_cvtss_f32(val);
|
||||
}
|
||||
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
||||
#elif defined(__ARM_NEON)
|
||||
for (; i + 3 < n; i += 4) {
|
||||
float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
|
||||
vdupq_n_f32(max)));
|
||||
|
||||
@@ -12576,16 +12576,16 @@ struct llm_tokenizer_wpm {
|
||||
// to lowercase, pad chinese characters, pad punctuation
|
||||
std::string new_str = "";
|
||||
for (uint32_t code : cpts_nfd) {
|
||||
int type = unicode_cpt_type(code);
|
||||
if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) {
|
||||
const codepoint_flags flags = unicode_cpt_flags(code);
|
||||
if (flags.is_accent_mark || flags.is_control) {
|
||||
continue;
|
||||
}
|
||||
code = unicode_tolower(code);
|
||||
if (type == CODEPOINT_TYPE_SEPARATOR) {
|
||||
if (flags.is_separator || flags.is_whitespace) { //####FIXME: is_separator ?
|
||||
code = ' ';
|
||||
}
|
||||
std::string s = unicode_cpt_to_utf8(code);
|
||||
if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) {
|
||||
if (flags.is_punctuation || is_ascii_punct(code) || is_chinese_char(code)) {
|
||||
new_str += " ";
|
||||
new_str += s;
|
||||
new_str += " ";
|
||||
|
||||
@@ -1,64 +1,134 @@
|
||||
import regex
|
||||
import ctypes
|
||||
import unicodedata
|
||||
|
||||
|
||||
def get_matches(regex_expr):
|
||||
regex_expr_compiled = regex.compile(regex_expr)
|
||||
unicode_ranges = []
|
||||
current_range = None
|
||||
|
||||
for codepoint in range(0x110000):
|
||||
char = chr(codepoint)
|
||||
if regex_expr_compiled.match(char):
|
||||
if current_range is None:
|
||||
current_range = [codepoint, codepoint]
|
||||
else:
|
||||
current_range[1] = codepoint
|
||||
elif current_range is not None:
|
||||
unicode_ranges.append(tuple(current_range))
|
||||
current_range = None
|
||||
|
||||
if current_range is not None:
|
||||
unicode_ranges.append(tuple(current_range))
|
||||
|
||||
return unicode_ranges
|
||||
class CoodepointFlags (ctypes.Structure):
|
||||
_fields_ = [ # see definition in unicode.h
|
||||
("is_undefined", ctypes.c_uint16, 1),
|
||||
("is_number", ctypes.c_uint16, 1), # regex: \p{N}
|
||||
("is_letter", ctypes.c_uint16, 1), # regex: \p{L}
|
||||
("is_separator", ctypes.c_uint16, 1), # regex: \p{Z}
|
||||
("is_accent_mark", ctypes.c_uint16, 1), # regex: \p{M}
|
||||
("is_punctuation", ctypes.c_uint16, 1), # regex: \p{P}
|
||||
("is_symbol", ctypes.c_uint16, 1), # regex: \p{S}
|
||||
("is_control", ctypes.c_uint16, 1), # regex: \p{C}
|
||||
]
|
||||
|
||||
|
||||
def print_cat(mode, cat, ranges):
|
||||
if mode == "range":
|
||||
print("const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_{} = {{".format(cat)) # noqa: NP100
|
||||
if mode == "map":
|
||||
print("const std::map<uint32_t, uint32_t> unicode_map_{} = {{".format(cat)) # noqa: NP100
|
||||
for i, values in enumerate(ranges):
|
||||
end = ",\n" if (i % 4 == 3 or i + 1 == len(ranges)) else ", "
|
||||
values = ["0x%08X" % value for value in values]
|
||||
print("{" + ", ".join(values) + "}", end=end) # noqa: NP100
|
||||
print("};") # noqa: NP100
|
||||
print("") # noqa: NP100
|
||||
assert (ctypes.sizeof(CoodepointFlags) == 2)
|
||||
|
||||
|
||||
print_cat("range", "number", get_matches(r'\p{N}'))
|
||||
print_cat("range", "letter", get_matches(r'\p{L}'))
|
||||
print_cat("range", "separator", get_matches(r'\p{Z}'))
|
||||
print_cat("range", "accent_mark", get_matches(r'\p{M}'))
|
||||
print_cat("range", "punctuation", get_matches(r'\p{P}'))
|
||||
print_cat("range", "symbol", get_matches(r'\p{S}'))
|
||||
print_cat("range", "control", get_matches(r'\p{C}'))
|
||||
MAX_CODEPOINTS = 0x110000
|
||||
|
||||
print_cat("range", "whitespace", get_matches(r'\s'))
|
||||
regex_number = regex.compile(r'\p{N}')
|
||||
regex_letter = regex.compile(r'\p{L}')
|
||||
regex_separator = regex.compile(r'\p{Z}')
|
||||
regex_accent_mark = regex.compile(r'\p{M}')
|
||||
regex_punctuation = regex.compile(r'\p{P}')
|
||||
regex_symbol = regex.compile(r'\p{S}')
|
||||
regex_control = regex.compile(r'\p{C}')
|
||||
regex_whitespace = regex.compile(r'\s')
|
||||
|
||||
codepoint_flags = (CoodepointFlags * MAX_CODEPOINTS)()
|
||||
table_whitespace = []
|
||||
table_lowercase = []
|
||||
table_uppercase = []
|
||||
table_nfd = []
|
||||
|
||||
map_lowercase = []
|
||||
map_uppercase = []
|
||||
for codepoint in range(0x110000):
|
||||
for codepoint in range(MAX_CODEPOINTS):
|
||||
# convert codepoint to unicode character
|
||||
char = chr(codepoint)
|
||||
|
||||
# regex categories
|
||||
flags = codepoint_flags[codepoint]
|
||||
flags.is_number = bool(regex_number.match(char))
|
||||
flags.is_letter = bool(regex_letter.match(char))
|
||||
flags.is_separator = bool(regex_separator.match(char))
|
||||
flags.is_accent_mark = bool(regex_accent_mark.match(char))
|
||||
flags.is_punctuation = bool(regex_punctuation.match(char))
|
||||
flags.is_symbol = bool(regex_symbol.match(char))
|
||||
flags.is_control = bool(regex_control.match(char))
|
||||
flags.is_undefined = bytes(flags)[0] == 0
|
||||
assert (not flags.is_undefined)
|
||||
|
||||
# whitespaces
|
||||
if bool(regex_whitespace.match(char)):
|
||||
table_whitespace.append(codepoint)
|
||||
|
||||
# lowercase conversion
|
||||
lower = ord(char.lower()[0])
|
||||
upper = ord(char.upper()[0])
|
||||
if codepoint != lower:
|
||||
map_lowercase.append((codepoint, lower))
|
||||
table_lowercase.append((codepoint, lower))
|
||||
|
||||
# uppercase conversion
|
||||
upper = ord(char.upper()[0])
|
||||
if codepoint != upper:
|
||||
map_uppercase.append((codepoint, upper))
|
||||
print_cat("map", "lowercase", map_lowercase)
|
||||
print_cat("map", "uppercase", map_uppercase)
|
||||
table_uppercase.append((codepoint, upper))
|
||||
|
||||
# NFD normalization
|
||||
norm = ord(unicodedata.normalize('NFD', char)[0])
|
||||
if codepoint != norm:
|
||||
table_nfd.append((codepoint, norm))
|
||||
|
||||
|
||||
# TODO: generate unicode_map_nfd
|
||||
# group ranges with same flags
|
||||
ranges_flags = [(0, codepoint_flags[0])] # start, flags
|
||||
for codepoint, flags in enumerate(codepoint_flags):
|
||||
if bytes(flags) != bytes(ranges_flags[-1][1]):
|
||||
ranges_flags.append((codepoint, flags))
|
||||
ranges_flags.append((MAX_CODEPOINTS, CoodepointFlags()))
|
||||
|
||||
|
||||
# group ranges with same nfd
|
||||
ranges_nfd = [(0, 0, 0)] # start, last, nfd
|
||||
for codepoint, norm in table_nfd:
|
||||
start = ranges_nfd[-1][0]
|
||||
if ranges_nfd[-1] != (start, codepoint - 1, norm):
|
||||
ranges_nfd.append(None)
|
||||
start = codepoint
|
||||
ranges_nfd[-1] = (start, codepoint, norm)
|
||||
|
||||
|
||||
# Generate 'unicode-data.cpp'
|
||||
|
||||
|
||||
def out(line=""):
|
||||
print(line, end='\n') # noqa
|
||||
|
||||
|
||||
out("""\
|
||||
// generated with scripts/gen-unicode-data.py
|
||||
|
||||
#include "unicode-data.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
""")
|
||||
|
||||
out("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = { // start, flags // last=next_start-1")
|
||||
for codepoint, flags in ranges_flags:
|
||||
flags = int.from_bytes(bytes(flags), "little")
|
||||
out("{0x%06X, 0x%04X}," % (codepoint, flags))
|
||||
out("};\n")
|
||||
|
||||
out("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
|
||||
out(", ".join("0x%06X" % cpt for cpt in table_whitespace))
|
||||
out("};\n")
|
||||
|
||||
out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")
|
||||
for tuple in table_lowercase:
|
||||
out("{0x%06X, 0x%06X}," % tuple)
|
||||
out("};\n")
|
||||
|
||||
out("const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase = {")
|
||||
for tuple in table_uppercase:
|
||||
out("{0x%06X, 0x%06X}," % tuple)
|
||||
out("};\n")
|
||||
|
||||
out("const std::vector<range_nfd> unicode_ranges_nfd = { // start, last, nfd")
|
||||
for triple in ranges_nfd:
|
||||
out("{0x%06X, 0x%06X, 0x%06X}," % triple)
|
||||
out("};\n")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Test libllama tokenizer == AutoTokenizer.
|
||||
# Brute force random tokens/text generation.
|
||||
# Brute force random words/text generation.
|
||||
#
|
||||
# Sample usage:
|
||||
#
|
||||
@@ -12,10 +12,10 @@ import argparse
|
||||
import subprocess
|
||||
import random
|
||||
|
||||
from typing import Iterator
|
||||
from typing import Callable, Iterator
|
||||
|
||||
import cffi
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
logger = logging.getLogger("test-tokenizer-random-bpe")
|
||||
|
||||
@@ -145,28 +145,35 @@ def generator_custom_text() -> Iterator[str]:
|
||||
def generator_custom_text_edge_cases() -> Iterator[str]:
|
||||
"""Edge cases found while debugging"""
|
||||
yield from [
|
||||
'\x1f-a', # unicode_ranges_control, {0x00001C, 0x00001F}
|
||||
'¼-a', # unicode_ranges_digit, 0x00BC
|
||||
'½-a', # unicode_ranges_digit, 0x00BD
|
||||
'¾-a', # unicode_ranges_digit, 0x00BE
|
||||
'a 〇b', # unicode_ranges_digit, 0x3007
|
||||
'Ⅵ-a', # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms
|
||||
'\uFEFF//', # unicode_ranges_control, 0xFEFF (BOM)
|
||||
'<s>a' # TODO: Phi-3 fail
|
||||
'\x1f-a', # unicode_ranges_control, {0x00001C, 0x00001F}
|
||||
'¼-a', # unicode_ranges_digit, 0x00BC
|
||||
'½-a', # unicode_ranges_digit, 0x00BD
|
||||
'¾-a', # unicode_ranges_digit, 0x00BE
|
||||
'a 〇b', # unicode_ranges_digit, 0x3007
|
||||
'Ⅵ-a', # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms
|
||||
'\uFEFF//', # unicode_ranges_control, 0xFEFF (BOM)
|
||||
'Cửa Việt', # llama-3, ignore_merges = true
|
||||
'<s>a', # TODO: Phi-3 fail
|
||||
'a\na', # TODO: Bert fail
|
||||
]
|
||||
|
||||
|
||||
def generator_random_chars(iterations = 100) -> Iterator[str]:
|
||||
def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
|
||||
"""Brute force check all vocab words"""
|
||||
yield from vocab
|
||||
|
||||
|
||||
def generator_random_chars(iterations=100) -> Iterator[str]:
|
||||
"""Brute force random text with simple characters"""
|
||||
|
||||
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
|
||||
CHARS = list(set("""
|
||||
CHARS = list(sorted(set("""
|
||||
ABCDEFGHIJKLMNOPQRSTUVWXYZ
|
||||
abcdefghijklmnopqrstuvwxyz
|
||||
ÁÉÍÓÚÀÈÌÒÙÂÊÎÔÛÄËÏÖÜ
|
||||
áéíóúàèìòùâêîôûäëïöü
|
||||
.-,*/-+ª!"·$%&/()=?¿[]{}<>\\|@#~½¬~;:_
|
||||
"""))
|
||||
""")))
|
||||
|
||||
rand = random.Random()
|
||||
for m in range(iterations):
|
||||
@@ -181,13 +188,13 @@ def generator_random_chars(iterations = 100) -> Iterator[str]:
|
||||
yield "".join(text)
|
||||
|
||||
|
||||
def generator_random_vocab_chars(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
|
||||
def generator_random_vocab_chars(vocab: list[str], iterations=100) -> Iterator[str]:
|
||||
"""Brute force random text with vocab characters"""
|
||||
|
||||
vocab_ids = list(tokenizer.vocab.values())
|
||||
vocab_text = tokenizer.decode(vocab_ids, skip_special_tokens=True)
|
||||
vocab_chars = list(set(vocab_text))
|
||||
del vocab_ids, vocab_text
|
||||
vocab_chars = set()
|
||||
for word in vocab:
|
||||
vocab_chars.update(word)
|
||||
vocab_chars = list(sorted(vocab_chars))
|
||||
|
||||
rand = random.Random()
|
||||
for m in range(iterations):
|
||||
@@ -196,19 +203,11 @@ def generator_random_vocab_chars(tokenizer: PreTrainedTokenizerBase, iterations
|
||||
yield "".join(text)
|
||||
|
||||
|
||||
def generator_random_vocab_tokens(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
|
||||
"""Brute force random text from vocab tokens"""
|
||||
def generator_random_vocab_words(vocab: list[str], iterations=100) -> Iterator[str]:
|
||||
"""Brute force random text from vocab words"""
|
||||
|
||||
space_id = tokenizer.encode(" ", add_special_tokens=False)[0]
|
||||
vocab_ids = list(tokenizer.vocab.values())
|
||||
vocab_ids = list(sorted(vocab_ids + vocab_ids))
|
||||
for i in range(1, len(vocab_ids), 2):
|
||||
vocab_ids[i] = space_id
|
||||
vocab_tokens = tokenizer.decode(vocab_ids, skip_special_tokens=True)
|
||||
vocab_tokens = vocab_tokens.split(" ")
|
||||
del vocab_ids
|
||||
|
||||
yield from vocab_tokens
|
||||
vocab = [w.strip() for w in vocab]
|
||||
yield from vocab
|
||||
|
||||
rand = random.Random()
|
||||
for m in range(iterations):
|
||||
@@ -217,14 +216,13 @@ def generator_random_vocab_tokens(tokenizer: PreTrainedTokenizerBase, iterations
|
||||
num_words = rand.randint(300, 400)
|
||||
for i in range(num_words):
|
||||
k = rand.randint(1, 3)
|
||||
tokens = rand.choices(vocab_tokens, k=k)
|
||||
tokens = [t.strip(" \n\r\t") for t in tokens]
|
||||
words = rand.choices(vocab, k=k)
|
||||
sep = rand.choice(" \n\r\t")
|
||||
text.append("".join(tokens) + sep)
|
||||
text.append("".join(words) + sep)
|
||||
yield "".join(text)
|
||||
|
||||
|
||||
def generator_random_bytes(iterations = 100) -> Iterator[str]:
|
||||
def generator_random_bytes(iterations=100) -> Iterator[str]:
|
||||
"""Brute force random bytes"""
|
||||
|
||||
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
|
||||
@@ -242,10 +240,10 @@ def generator_random_bytes(iterations = 100) -> Iterator[str]:
|
||||
yield "".join(text)
|
||||
|
||||
|
||||
def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, generator: Iterator[str]):
|
||||
def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
|
||||
|
||||
def find_first_mismatch(ids1: list[int], ids2: list[int]):
|
||||
for i, (a,b) in enumerate(zip(ids1, ids2)):
|
||||
for i, (a, b) in enumerate(zip(ids1, ids2)):
|
||||
if a != b:
|
||||
return i
|
||||
if len(ids1) == len(ids2):
|
||||
@@ -255,15 +253,12 @@ def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerB
|
||||
t0 = time.perf_counter()
|
||||
logger.info("%s: %s" % (generator.__name__, "ini"))
|
||||
for text in generator:
|
||||
ids1 = model.tokenize(text, add_special=False, parse_special=False)
|
||||
ids2 = tokenizer.encode(text, add_special_tokens=False)
|
||||
ids1 = func_tokenize1(text)
|
||||
ids2 = func_tokenize2(text)
|
||||
if ids1 != ids2:
|
||||
i = find_first_mismatch(ids1, ids2)
|
||||
ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
|
||||
ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1]
|
||||
text2 = tokenizer.decode(ids2, skip_special_tokens=True)
|
||||
assert (text2 in text)
|
||||
logger.info(" Text: " + repr(text2))
|
||||
logger.info(" TokenIDs: " + str(ids1))
|
||||
logger.info(" Expected: " + str(ids2))
|
||||
raise Exception()
|
||||
@@ -271,25 +266,37 @@ def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerB
|
||||
logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def main(argv: list[str] = None):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
|
||||
parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
|
||||
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
||||
args = parser.parse_args()
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||
|
||||
model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=2048))
|
||||
|
||||
model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
|
||||
|
||||
test_compare_tokenizer(model, tokenizer, generator_custom_text())
|
||||
test_compare_tokenizer(model, tokenizer, generator_custom_text_edge_cases())
|
||||
test_compare_tokenizer(model, tokenizer, generator_random_chars(10_000))
|
||||
test_compare_tokenizer(model, tokenizer, generator_random_vocab_chars(tokenizer, 10_000))
|
||||
test_compare_tokenizer(model, tokenizer, generator_random_vocab_tokens(tokenizer, 10_000))
|
||||
# test_compare_tokenizer(model, tokenizer, generator_random_bytes(10_000)) # FAIL
|
||||
def func_tokenize2(text: str):
|
||||
return tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
parse_special = all(len(func_tokenize2(t)) == 1 for t in tokenizer.all_special_tokens)
|
||||
|
||||
def func_tokenize1(text: str):
|
||||
return model.tokenize(text, add_special=False, parse_special=parse_special)
|
||||
|
||||
vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
|
||||
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
|
||||
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
|
||||
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
|
||||
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
|
||||
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
|
||||
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 10_000))
|
||||
# test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_bytes(10_000)) # FAIL
|
||||
|
||||
model.free()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
9138
unicode-data.cpp
9138
unicode-data.cpp
File diff suppressed because it is too large
Load Diff
@@ -1,17 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_number;
|
||||
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_letter;
|
||||
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_separator;
|
||||
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace;
|
||||
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_accent_mark;
|
||||
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation;
|
||||
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
|
||||
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
|
||||
extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
|
||||
extern const std::map<char32_t, char32_t> unicode_map_lowercase;
|
||||
struct range_nfd {
|
||||
uint32_t first;
|
||||
uint32_t last;
|
||||
uint32_t nfd;
|
||||
};
|
||||
|
||||
static const uint32_t MAX_CODEPOINTS = 0x110000;
|
||||
|
||||
extern const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags;
|
||||
extern const std::unordered_set<uint32_t> unicode_set_whitespace;
|
||||
extern const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase;
|
||||
extern const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase;
|
||||
extern const std::vector<range_nfd> unicode_ranges_nfd;
|
||||
|
||||
200
unicode.cpp
200
unicode.cpp
@@ -1,4 +1,4 @@
|
||||
#include "unicode.h"
|
||||
#include "unicode.h"
|
||||
#include "unicode-data.h"
|
||||
|
||||
#include <cassert>
|
||||
@@ -109,57 +109,49 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset)
|
||||
// return result;
|
||||
//}
|
||||
|
||||
static std::unordered_map<uint32_t, int> unicode_cpt_type_map() {
|
||||
std::unordered_map<uint32_t, int> cpt_types;
|
||||
for (auto p : unicode_ranges_number) {
|
||||
for (auto i = p.first; i <= p.second; ++i) {
|
||||
cpt_types[i] = CODEPOINT_TYPE_NUMBER;
|
||||
static std::vector<codepoint_flags> unicode_cpt_flags_array() {
|
||||
std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
|
||||
|
||||
assert (unicode_ranges_flags.front().first == 0);
|
||||
assert (unicode_ranges_flags.back().first == MAX_CODEPOINTS);
|
||||
for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
|
||||
const auto range_ini = unicode_ranges_flags[i-1]; // codepoint_ini, flags
|
||||
const auto range_end = unicode_ranges_flags[i]; // codepoint_end, flags
|
||||
for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
|
||||
cpt_flags[cpt] = range_ini.second;
|
||||
}
|
||||
}
|
||||
for (auto p : unicode_ranges_letter) {
|
||||
for (auto i = p.first; i <= p.second; ++i) {
|
||||
cpt_types[i] = CODEPOINT_TYPE_LETTER;
|
||||
}
|
||||
|
||||
for (auto cpt : unicode_set_whitespace) {
|
||||
cpt_flags[cpt].is_whitespace = true;
|
||||
}
|
||||
for (auto p : unicode_ranges_separator) {
|
||||
for (auto i = p.first; i <= p.second; ++i) {
|
||||
cpt_types[i] = CODEPOINT_TYPE_SEPARATOR;
|
||||
}
|
||||
|
||||
for (auto p : unicode_map_lowercase) {
|
||||
cpt_flags[p.second].is_lowercase = true;
|
||||
}
|
||||
for (auto p : unicode_ranges_accent_mark) {
|
||||
for (auto i = p.first; i <= p.second; ++i) {
|
||||
cpt_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
|
||||
}
|
||||
|
||||
for (auto p : unicode_map_uppercase) {
|
||||
cpt_flags[p.second].is_uppercase = true;
|
||||
}
|
||||
for (auto p : unicode_ranges_punctuation) {
|
||||
for (auto i = p.first; i <= p.second; ++i) {
|
||||
cpt_types[i] = CODEPOINT_TYPE_PUNCTUATION;
|
||||
}
|
||||
|
||||
for (auto &range : unicode_ranges_nfd) { // start, last, nfd
|
||||
cpt_flags[range.nfd].is_nfd = true;
|
||||
}
|
||||
for (auto p : unicode_ranges_symbol) {
|
||||
for (auto i = p.first; i <= p.second; ++i) {
|
||||
cpt_types[i] = CODEPOINT_TYPE_SYMBOL;
|
||||
}
|
||||
}
|
||||
for (auto p : unicode_ranges_control) {
|
||||
for (auto i = p.first; i <= p.second; ++i) {
|
||||
cpt_types[i] = CODEPOINT_TYPE_CONTROL;
|
||||
}
|
||||
}
|
||||
return cpt_types;
|
||||
|
||||
return cpt_flags;
|
||||
}
|
||||
|
||||
static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
|
||||
std::unordered_map<uint8_t, std::string> map;
|
||||
for (int ch = u'!'; ch <= u'~'; ++ch) {
|
||||
for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
|
||||
assert(0 <= ch && ch < 256);
|
||||
map[ch] = unicode_cpt_to_utf8(ch);
|
||||
}
|
||||
for (int ch = u'¡'; ch <= u'¬'; ++ch) {
|
||||
for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
|
||||
assert(0 <= ch && ch < 256);
|
||||
map[ch] = unicode_cpt_to_utf8(ch);
|
||||
}
|
||||
for (int ch = u'®'; ch <= u'ÿ'; ++ch) {
|
||||
for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
|
||||
assert(0 <= ch && ch < 256);
|
||||
map[ch] = unicode_cpt_to_utf8(ch);
|
||||
}
|
||||
@@ -175,15 +167,15 @@ static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
|
||||
|
||||
static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
|
||||
std::unordered_map<std::string, uint8_t> map;
|
||||
for (int ch = u'!'; ch <= u'~'; ++ch) {
|
||||
for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
|
||||
assert(0 <= ch && ch < 256);
|
||||
map[unicode_cpt_to_utf8(ch)] = ch;
|
||||
}
|
||||
for (int ch = u'¡'; ch <= u'¬'; ++ch) {
|
||||
for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
|
||||
assert(0 <= ch && ch < 256);
|
||||
map[unicode_cpt_to_utf8(ch)] = ch;
|
||||
}
|
||||
for (int ch = u'®'; ch <= u'ÿ'; ++ch) {
|
||||
for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
|
||||
assert(0 <= ch && ch < 256);
|
||||
map[unicode_cpt_to_utf8(ch)] = ch;
|
||||
}
|
||||
@@ -238,8 +230,9 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
|
||||
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
|
||||
};
|
||||
|
||||
auto _get_cpt_type = [&] (const size_t pos) -> int {
|
||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
|
||||
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
|
||||
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
|
||||
};
|
||||
|
||||
size_t _prev_end = offset_ini;
|
||||
@@ -261,7 +254,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
|
||||
|
||||
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
||||
const char32_t cpt = _get_cpt(pos);
|
||||
const int cpt_type = _get_cpt_type(pos);
|
||||
const auto flags = _get_flags(pos);
|
||||
|
||||
// regex: 's|'t|'re|'ve|'m|'ll|'d
|
||||
if (cpt == '\'' && pos+1 < offset_end) {
|
||||
@@ -281,39 +274,37 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
|
||||
}
|
||||
}
|
||||
|
||||
char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
|
||||
int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
|
||||
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
|
||||
// regex: <space>?\p{L}+
|
||||
if (cpt2_type == CODEPOINT_TYPE_LETTER) {
|
||||
if (flags2.is_letter) {
|
||||
pos += (cpt == ' ');
|
||||
while (cpt2_type == CODEPOINT_TYPE_LETTER) {
|
||||
cpt2_type = _get_cpt_type(++pos);
|
||||
while (flags2.is_letter) {
|
||||
flags2 = _get_flags(++pos);
|
||||
}
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
// regex: <space>?\p{N}+
|
||||
if (cpt2_type == CODEPOINT_TYPE_NUMBER) {
|
||||
if (flags2.is_number) {
|
||||
pos += (cpt == ' ');
|
||||
while (cpt2_type == CODEPOINT_TYPE_NUMBER) {
|
||||
cpt2_type = _get_cpt_type(++pos);
|
||||
while (flags2.is_number) {
|
||||
flags2 = _get_flags(++pos);
|
||||
}
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
// regex: <space>?[^\s\p{L}\p{N}]+
|
||||
if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
|
||||
if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
|
||||
pos += (cpt == ' ');
|
||||
while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
|
||||
cpt2_type = _get_cpt_type(++pos);
|
||||
cpt2 = _get_cpt(pos);
|
||||
while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
|
||||
flags2 = _get_flags(++pos);
|
||||
}
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t num_whitespaces = 0;
|
||||
while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
|
||||
while (_get_flags(pos+num_whitespaces).is_whitespace) {
|
||||
num_whitespaces++;
|
||||
}
|
||||
|
||||
@@ -357,8 +348,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
||||
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
|
||||
};
|
||||
|
||||
auto _get_cpt_type = [&] (const size_t pos) -> int {
|
||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
|
||||
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
|
||||
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
|
||||
};
|
||||
|
||||
size_t _prev_end = offset_ini;
|
||||
@@ -380,7 +372,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
||||
|
||||
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
||||
const char32_t cpt = _get_cpt(pos);
|
||||
const int cpt_type = _get_cpt_type(pos);
|
||||
const auto flags = _get_flags(pos);
|
||||
|
||||
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
|
||||
if (cpt == '\'' && pos+1 < offset_end) {
|
||||
@@ -401,10 +393,10 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
||||
}
|
||||
|
||||
// regex: [^\r\n\p{L}\p{N}]?\p{L}+ //####FIXME: the first \p{L} is correct?
|
||||
if (cpt != '\r' && cpt != '\n' && /*cpt_type != CODEPOINT_TYPE_LETTER &&*/ cpt_type != CODEPOINT_TYPE_NUMBER) {
|
||||
if (cpt_type == CODEPOINT_TYPE_LETTER || _get_cpt_type(pos+1) == CODEPOINT_TYPE_LETTER) { // one or more letters
|
||||
if (!(cpt == '\r' || cpt == '\n' || /*flags.is_letter |*/ flags.is_number)) {
|
||||
if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
|
||||
pos++;
|
||||
while (_get_cpt_type(pos) == CODEPOINT_TYPE_LETTER) {
|
||||
while (_get_flags(pos).is_letter) {
|
||||
pos++;
|
||||
}
|
||||
_add_token(pos);
|
||||
@@ -413,9 +405,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
||||
}
|
||||
|
||||
// regex: \p{N}{1,3}
|
||||
if (cpt_type == CODEPOINT_TYPE_NUMBER) {
|
||||
if (flags.is_number) {
|
||||
size_t ini = pos;
|
||||
while (_get_cpt_type(pos) == CODEPOINT_TYPE_NUMBER) {
|
||||
while (_get_flags(pos).is_number) {
|
||||
if (++pos - ini >= 3 ) {
|
||||
_add_token(pos);
|
||||
ini = pos;
|
||||
@@ -426,14 +418,13 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
||||
}
|
||||
|
||||
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
|
||||
char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
|
||||
int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
|
||||
if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
|
||||
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
|
||||
if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
|
||||
pos += (cpt == ' ');
|
||||
while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
|
||||
cpt2_type = _get_cpt_type(++pos);
|
||||
cpt2 = _get_cpt(pos);
|
||||
while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
|
||||
flags2 = _get_flags(++pos);
|
||||
}
|
||||
char32_t cpt2 = _get_cpt(pos);
|
||||
while (cpt2 == '\r' || cpt2 == '\n') {
|
||||
cpt2 = _get_cpt(++pos);
|
||||
}
|
||||
@@ -443,7 +434,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
||||
|
||||
size_t num_whitespaces = 0;
|
||||
size_t last_end_r_or_n = 0;
|
||||
while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
|
||||
while (_get_flags(pos+num_whitespaces).is_whitespace) {
|
||||
char32_t cpt2 = _get_cpt(pos+num_whitespaces);
|
||||
if (cpt2 == '\r' || cpt2 == '\n') {
|
||||
last_end_r_or_n = pos + num_whitespaces + 1;
|
||||
@@ -589,15 +580,14 @@ std::string unicode_cpt_to_utf8(uint32_t cp) {
|
||||
}
|
||||
|
||||
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
|
||||
std::vector<uint32_t> result;
|
||||
result.reserve(cpts.size());
|
||||
auto comp = [] (const uint32_t cpt, const range_nfd & range) {
|
||||
return cpt < range.first;
|
||||
};
|
||||
std::vector<uint32_t> result(cpts.size());
|
||||
for (size_t i = 0; i < cpts.size(); ++i) {
|
||||
auto it = unicode_map_nfd.find(cpts[i]);
|
||||
if (it == unicode_map_nfd.end()) {
|
||||
result.push_back(cpts[i]);
|
||||
} else {
|
||||
result.push_back(it->second);
|
||||
}
|
||||
const uint32_t cpt = cpts[i];
|
||||
auto it = std::upper_bound(unicode_ranges_nfd.cbegin(), unicode_ranges_nfd.cend(), cpt, comp) - 1;
|
||||
result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -611,31 +601,19 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
|
||||
return result;
|
||||
}
|
||||
|
||||
int unicode_cpt_type(uint32_t cp) {
|
||||
static std::unordered_map<uint32_t, int> cpt_types = unicode_cpt_type_map();
|
||||
const auto it = cpt_types.find(cp);
|
||||
return it == cpt_types.end() ? CODEPOINT_TYPE_UNIDENTIFIED : it->second;
|
||||
codepoint_flags unicode_cpt_flags(const uint32_t cp) {
|
||||
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
||||
static const auto cpt_flags = unicode_cpt_flags_array();
|
||||
return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
|
||||
}
|
||||
|
||||
int unicode_cpt_type(const std::string & utf8) {
|
||||
if (utf8.length() == 0) {
|
||||
return CODEPOINT_TYPE_UNIDENTIFIED;
|
||||
codepoint_flags unicode_cpt_flags(const std::string & utf8) {
|
||||
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
||||
if (utf8.empty()) {
|
||||
return undef; // undefined
|
||||
}
|
||||
size_t offset = 0;
|
||||
return unicode_cpt_type(unicode_cpt_from_utf8(utf8, offset));
|
||||
}
|
||||
|
||||
bool unicode_cpt_is_whitespace(uint32_t cp) {
|
||||
static const std::unordered_set<uint32_t> is_whitespace = [] {
|
||||
std::unordered_set<uint32_t> is_whitespace;
|
||||
for (auto p : unicode_ranges_whitespace) {
|
||||
for (auto i = p.first; i <= p.second; ++i) {
|
||||
is_whitespace.insert(i);
|
||||
}
|
||||
}
|
||||
return is_whitespace;
|
||||
}();
|
||||
return (bool)is_whitespace.count(cp);
|
||||
return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
|
||||
}
|
||||
|
||||
std::string unicode_byte_to_utf8(uint8_t byte) {
|
||||
@@ -656,21 +634,21 @@ char32_t unicode_tolower(char32_t cp) {
|
||||
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
|
||||
// unicode categories
|
||||
static const std::map<std::string, int> k_ucat_enum = {
|
||||
{ "\\p{N}", CODEPOINT_TYPE_NUMBER },
|
||||
{ "\\p{L}", CODEPOINT_TYPE_LETTER },
|
||||
{ "\\p{P}", CODEPOINT_TYPE_PUNCTUATION },
|
||||
{ "\\p{N}", codepoint_flags::NUMBER },
|
||||
{ "\\p{L}", codepoint_flags::LETTER },
|
||||
{ "\\p{P}", codepoint_flags::PUNCTUATION },
|
||||
};
|
||||
|
||||
static const std::map<int, int> k_ucat_cpt = {
|
||||
{ CODEPOINT_TYPE_NUMBER, 0xD1 },
|
||||
{ CODEPOINT_TYPE_LETTER, 0xD2 },
|
||||
{ CODEPOINT_TYPE_PUNCTUATION, 0xD3 },
|
||||
{ codepoint_flags::NUMBER, 0xD1 },
|
||||
{ codepoint_flags::LETTER, 0xD2 },
|
||||
{ codepoint_flags::PUNCTUATION, 0xD3 },
|
||||
};
|
||||
|
||||
static const std::map<int, std::string> k_ucat_map = {
|
||||
{ CODEPOINT_TYPE_NUMBER, "\x30-\x39" }, // 0-9
|
||||
{ CODEPOINT_TYPE_LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
||||
{ CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
|
||||
{ codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
|
||||
{ codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
||||
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
|
||||
};
|
||||
|
||||
// compute collapsed codepoints only if needed by at least one regex
|
||||
@@ -701,10 +679,10 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
||||
continue;
|
||||
}
|
||||
|
||||
const int cpt_type = unicode_cpt_type(cpts[i]);
|
||||
const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag();
|
||||
|
||||
if (k_ucat_cpt.find(cpt_type) != k_ucat_cpt.end()) {
|
||||
text_collapsed[i] = k_ucat_cpt.at(cpt_type);
|
||||
if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) {
|
||||
text_collapsed[i] = k_ucat_cpt.at(cpt_flag);
|
||||
} else {
|
||||
text_collapsed[i] = (char) 0xD0; // fallback
|
||||
}
|
||||
|
||||
56
unicode.h
56
unicode.h
@@ -4,24 +4,56 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#define CODEPOINT_TYPE_UNIDENTIFIED 0
|
||||
#define CODEPOINT_TYPE_NUMBER 1
|
||||
#define CODEPOINT_TYPE_LETTER 2
|
||||
#define CODEPOINT_TYPE_SEPARATOR 3
|
||||
#define CODEPOINT_TYPE_ACCENT_MARK 4
|
||||
#define CODEPOINT_TYPE_PUNCTUATION 5
|
||||
#define CODEPOINT_TYPE_SYMBOL 6
|
||||
#define CODEPOINT_TYPE_CONTROL 7
|
||||
struct codepoint_flags {
|
||||
enum {
|
||||
UNDEFINED = 0x0001,
|
||||
NUMBER = 0x0002, // regex: \p{N}
|
||||
LETTER = 0x0004, // regex: \p{L}
|
||||
SEPARATOR = 0x0008, // regex: \p{Z}
|
||||
ACCENT_MARK = 0x0010, // regex: \p{M}
|
||||
PUNCTUATION = 0x0020, // regex: \p{P}
|
||||
SYMBOL = 0x0040, // regex: \p{S}
|
||||
CONTROL = 0x0080, // regex: \p{C}
|
||||
MASK_CATEGORIES = 0x00FF,
|
||||
};
|
||||
|
||||
// codepoint type
|
||||
uint16_t is_undefined : 1;
|
||||
uint16_t is_number : 1; // regex: \p{N}
|
||||
uint16_t is_letter : 1; // regex: \p{L}
|
||||
uint16_t is_separator : 1; // regex: \p{Z}
|
||||
uint16_t is_accent_mark : 1; // regex: \p{M}
|
||||
uint16_t is_punctuation : 1; // regex: \p{P}
|
||||
uint16_t is_symbol : 1; // regex: \p{S}
|
||||
uint16_t is_control : 1; // regex: \p{C}
|
||||
// helper flags
|
||||
uint16_t is_whitespace : 1; // regex: \s
|
||||
uint16_t is_lowercase : 1;
|
||||
uint16_t is_uppercase : 1;
|
||||
uint16_t is_nfd : 1;
|
||||
|
||||
// decode from uint16
|
||||
inline codepoint_flags(const uint16_t flags=0) {
|
||||
*reinterpret_cast<uint16_t*>(this) = flags;
|
||||
}
|
||||
|
||||
inline uint16_t as_uint() const {
|
||||
return *reinterpret_cast<const uint16_t*>(this);
|
||||
}
|
||||
|
||||
inline uint16_t category_flag() const {
|
||||
return this->as_uint() & MASK_CATEGORIES;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
std::string unicode_cpt_to_utf8(uint32_t cp);
|
||||
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
|
||||
|
||||
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
|
||||
|
||||
int unicode_cpt_type(uint32_t cp);
|
||||
int unicode_cpt_type(const std::string & utf8);
|
||||
|
||||
bool unicode_cpt_is_whitespace(uint32_t cp);
|
||||
codepoint_flags unicode_cpt_flags(const uint32_t cp);
|
||||
codepoint_flags unicode_cpt_flags(const std::string & utf8);
|
||||
|
||||
std::string unicode_byte_to_utf8(uint8_t byte);
|
||||
uint8_t unicode_utf8_to_byte(const std::string & utf8);
|
||||
|
||||
Reference in New Issue
Block a user