mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-02-05 13:53:23 +02:00
Compare commits
188 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
00c361fe53 | ||
|
|
ec18edfcba | ||
|
|
7733409734 | ||
|
|
cd3c118908 | ||
|
|
649495c9d9 | ||
|
|
90c72a614a | ||
|
|
6eea666912 | ||
|
|
ff90508d68 | ||
|
|
0a4aeb927d | ||
|
|
2ba719519d | ||
|
|
7f8ef50cce | ||
|
|
3c136b21a3 | ||
|
|
beb1f0c503 | ||
|
|
def5404f26 | ||
|
|
fa0465954f | ||
|
|
5a6241feb0 | ||
|
|
c7af376c29 | ||
|
|
00425e2ed1 | ||
|
|
385c3da5e6 | ||
|
|
ab49f094d2 | ||
|
|
8c32d9d96d | ||
|
|
0874693b44 | ||
|
|
7d2add51d8 | ||
|
|
f698a79c63 | ||
|
|
47a268ea50 | ||
|
|
59d8d4e963 | ||
|
|
d82b7a7c1d | ||
|
|
03914c7ef8 | ||
|
|
3ce7a65c2f | ||
|
|
e072b2052e | ||
|
|
c6f7a423c8 | ||
|
|
2e7ef98f18 | ||
|
|
ddf9f94389 | ||
|
|
ff55414c42 | ||
|
|
73955f7d2a | ||
|
|
35cf8887e1 | ||
|
|
15d2b46b4d | ||
|
|
6bca76ff5e | ||
|
|
cd0e3a7a3b | ||
|
|
efaaccdd69 | ||
|
|
4abef75f2c | ||
|
|
c386114922 | ||
|
|
6783b11fb0 | ||
|
|
909072abcf | ||
|
|
cd8370b408 | ||
|
|
d21a76ac38 | ||
|
|
4fcd87cf7c | ||
|
|
b78db3bd50 | ||
|
|
142df17c9c | ||
|
|
e509411cf1 | ||
|
|
7cba58bbea | ||
|
|
5449367b21 | ||
|
|
1d594c295c | ||
|
|
eec1e33a9e | ||
|
|
879d673759 | ||
|
|
6ab4e50d9c | ||
|
|
2336cc4784 | ||
|
|
e6923caaec | ||
|
|
3e18dba9fd | ||
|
|
eeb5605de2 | ||
|
|
f3a848a3b1 | ||
|
|
b3b03a7baf | ||
|
|
583cb83416 | ||
|
|
05872ac885 | ||
|
|
55ab25caf5 | ||
|
|
064c90d843 | ||
|
|
b1846f1c8e | ||
|
|
d414db02d3 | ||
|
|
877566d512 | ||
|
|
3d07caa99b | ||
|
|
134e6940ca | ||
|
|
0543f928a3 | ||
|
|
b61de2b2df | ||
|
|
b8372eecd9 | ||
|
|
6ab8eacddf | ||
|
|
2d50b9d8cb | ||
|
|
697edfeead | ||
|
|
dbb852b549 | ||
|
|
5f55c385cb | ||
|
|
4902eebe33 | ||
|
|
923ae3c619 | ||
|
|
01ad35e6d6 | ||
|
|
fcb013847c | ||
|
|
d5bc1ad110 | ||
|
|
0c7220db56 | ||
|
|
96ac5a2329 | ||
|
|
bc809e9c53 | ||
|
|
54d83bbe85 | ||
|
|
4949ac0f18 | ||
|
|
3f3a4fb9c3 | ||
|
|
028f93ef98 | ||
|
|
8e9ddba610 | ||
|
|
23bc779a6e | ||
|
|
28175f857d | ||
|
|
9cc4080441 | ||
|
|
f1ffbba68e | ||
|
|
2370665e56 | ||
|
|
21d31e0810 | ||
|
|
dd0f321941 | ||
|
|
054a45c3d3 | ||
|
|
4c91f2633f | ||
|
|
92c0b387a9 | ||
|
|
2286a360ff | ||
|
|
1d321e592b | ||
|
|
196f5083ef | ||
|
|
5088b435d4 | ||
|
|
845f200b28 | ||
|
|
a7784a8b1d | ||
|
|
79bb743512 | ||
|
|
3ae282a06f | ||
|
|
5be353ec4a | ||
|
|
7d77f07325 | ||
|
|
1fa4551af0 | ||
|
|
2eba631b81 | ||
|
|
99c53d6558 | ||
|
|
07b0e7a5ac | ||
|
|
fd7353d5eb | ||
|
|
6fd4f95367 | ||
|
|
980b7cd17e | ||
|
|
c49daff5ba | ||
|
|
10e9780154 | ||
|
|
a045492088 | ||
|
|
1920345c3b | ||
|
|
561a3e2788 | ||
|
|
f40a2e5f11 | ||
|
|
bc4064cfea | ||
|
|
97cb3fd5ae | ||
|
|
ffa277a54c | ||
|
|
da95bf2a85 | ||
|
|
0de8878c96 | ||
|
|
38e2c1b412 | ||
|
|
cb44fc84e8 | ||
|
|
cb623de3fc | ||
|
|
7aaeedc098 | ||
|
|
3347e6d904 | ||
|
|
1a139644a8 | ||
|
|
2376b7758c | ||
|
|
dbed61294a | ||
|
|
80deff3648 | ||
|
|
8b1c339bd2 | ||
|
|
416e7c7f47 | ||
|
|
5b2093becc | ||
|
|
52e5d421f1 | ||
|
|
4db5641210 | ||
|
|
72bd7321a7 | ||
|
|
22e1ce2f81 | ||
|
|
1411d9275a | ||
|
|
662192e1dc | ||
|
|
24dc769f1b | ||
|
|
4dca015b7e | ||
|
|
9a8860cf5d | ||
|
|
9d3ef4809f | ||
|
|
c7b7db0445 | ||
|
|
1568d13c2c | ||
|
|
439342ea0b | ||
|
|
234ae7d7bd | ||
|
|
38eaf32af1 | ||
|
|
9b17d74ab7 | ||
|
|
e1fcf8b09b | ||
|
|
6cd0cf72ce | ||
|
|
d396b43748 | ||
|
|
45c6ef7307 | ||
|
|
2606b0adab | ||
|
|
307772fcda | ||
|
|
f1bad23f88 | ||
|
|
becc4816dd | ||
|
|
c4abcb2457 | ||
|
|
389ac78b26 | ||
|
|
a19bd6f7ce | ||
|
|
dd091e52f8 | ||
|
|
1215dde7b0 | ||
|
|
0cfb19166b | ||
|
|
2776db6c81 | ||
|
|
879dec341a | ||
|
|
97d5117217 | ||
|
|
a90eb94ca9 | ||
|
|
07751f8d44 | ||
|
|
ffb6f3d921 | ||
|
|
5d6838b74f | ||
|
|
92bb442ad9 | ||
|
|
374fe09cdd | ||
|
|
8e878f0cb4 | ||
|
|
00c94083b3 | ||
|
|
017eceed61 | ||
|
|
ee8dd5c658 | ||
|
|
1c398dc9ec | ||
|
|
52cf111b31 | ||
|
|
78010a0d52 |
@@ -3,7 +3,8 @@
|
||||
# ==============================================================================
|
||||
|
||||
# Define the CANN base image for easier version updates later
|
||||
ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.1.rc1-910b-openeuler22.03-py3.10
|
||||
ARG CHIP_TYPE=910b
|
||||
ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.3.rc1.alpha001-${CHIP_TYPE}-openeuler22.03-py3.11
|
||||
|
||||
# ==============================================================================
|
||||
# BUILD STAGE
|
||||
@@ -11,9 +12,6 @@ ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.1.rc1-910b-openeuler22.03-py3.10
|
||||
# ==============================================================================
|
||||
FROM ${CANN_BASE_IMAGE} AS build
|
||||
|
||||
# Define the Ascend chip model for compilation. Default is Ascend910B3
|
||||
ARG ASCEND_SOC_TYPE=Ascend910B3
|
||||
|
||||
# -- Install build dependencies --
|
||||
RUN yum install -y gcc g++ cmake make git libcurl-devel python3 python3-pip && \
|
||||
yum clean all && \
|
||||
@@ -36,20 +34,21 @@ ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH
|
||||
# For brevity, only core variables are listed here. You can paste the original ENV list here.
|
||||
|
||||
# -- Build llama.cpp --
|
||||
# Use the passed ASCEND_SOC_TYPE argument and add general build options
|
||||
# Use the passed CHIP_TYPE argument and add general build options
|
||||
ARG CHIP_TYPE
|
||||
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh --force \
|
||||
&& \
|
||||
cmake -B build \
|
||||
-DGGML_CANN=ON \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DSOC_TYPE=${ASCEND_SOC_TYPE} \
|
||||
-DSOC_TYPE=ascend${CHIP_TYPE} \
|
||||
. && \
|
||||
cmake --build build --config Release -j$(nproc)
|
||||
|
||||
# -- Organize build artifacts for copying in later stages --
|
||||
# Create a lib directory to store all .so files
|
||||
RUN mkdir -p /app/lib && \
|
||||
find build -name "*.so" -exec cp {} /app/lib \;
|
||||
find build -name "*.so*" -exec cp -P {} /app/lib \;
|
||||
|
||||
# Create a full directory to store all executables and Python scripts
|
||||
RUN mkdir -p /app/full && \
|
||||
|
||||
@@ -20,7 +20,7 @@ RUN if [ "$TARGETARCH" = "amd64" ] || [ "$TARGETARCH" = "arm64" ]; then \
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
RUN mkdir -p /app/lib && \
|
||||
find build -name "*.so" -exec cp {} /app/lib \;
|
||||
find build -name "*.so*" -exec cp -P {} /app/lib \;
|
||||
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
|
||||
@@ -25,7 +25,7 @@ RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \
|
||||
cmake --build build --config Release -j$(nproc)
|
||||
|
||||
RUN mkdir -p /app/lib && \
|
||||
find build -name "*.so" -exec cp {} /app/lib \;
|
||||
find build -name "*.so*" -exec cp -P {} /app/lib \;
|
||||
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
|
||||
@@ -21,7 +21,7 @@ RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \
|
||||
cmake --build build --config Release -j$(nproc)
|
||||
|
||||
RUN mkdir -p /app/lib && \
|
||||
find build -name "*.so" -exec cp {} /app/lib \;
|
||||
find build -name "*.so*" -exec cp -P {} /app/lib \;
|
||||
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
|
||||
@@ -32,7 +32,7 @@ RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \
|
||||
cmake --build build --config Release -j$(nproc)
|
||||
|
||||
RUN mkdir -p /app/lib && \
|
||||
find build -name "*.so" -exec cp {} /app/lib \;
|
||||
find build -name "*.so*" -exec cp -P {} /app/lib \;
|
||||
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
|
||||
@@ -45,7 +45,7 @@ RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
|
||||
&& cmake --build build --config Release -j$(nproc)
|
||||
|
||||
RUN mkdir -p /app/lib \
|
||||
&& find build -name "*.so" -exec cp {} /app/lib \;
|
||||
&& find build -name "*.so*" -exec cp -P {} /app/lib \;
|
||||
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
ARG UBUNTU_VERSION=25.10
|
||||
ARG UBUNTU_VERSION=26.04
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
# Ref: https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html
|
||||
|
||||
# Install build tools
|
||||
RUN apt update && apt install -y git build-essential cmake wget xz-utils
|
||||
|
||||
@@ -20,7 +18,7 @@ RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=ON -DLLAMA_BUILD_TESTS=OFF -D
|
||||
cmake --build build --config Release -j$(nproc)
|
||||
|
||||
RUN mkdir -p /app/lib && \
|
||||
find build -name "*.so" -exec cp {} /app/lib \;
|
||||
find build -name "*.so*" -exec cp -P {} /app/lib \;
|
||||
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
@@ -52,6 +50,7 @@ WORKDIR /app
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y \
|
||||
build-essential \
|
||||
git \
|
||||
python3 \
|
||||
python3-pip \
|
||||
|
||||
2
.github/copilot-instructions.md
vendored
2
.github/copilot-instructions.md
vendored
@@ -9,7 +9,7 @@ llama.cpp is a large-scale C/C++ project for efficient LLM (Large Language Model
|
||||
- **Size**: ~200k+ lines of code across 1000+ files
|
||||
- **Architecture**: Modular design with main library (`libllama`) and 40+ executable tools/examples
|
||||
- **Core dependency**: ggml tensor library (vendored in `ggml/` directory)
|
||||
- **Backends supported**: CPU (AVX/NEON optimized), CUDA, Metal, Vulkan, SYCL, ROCm, MUSA
|
||||
- **Backends supported**: CPU (AVX/NEON/RVV optimized), CUDA, Metal, Vulkan, SYCL, ROCm, MUSA
|
||||
- **License**: MIT
|
||||
|
||||
## Build Instructions
|
||||
|
||||
52
.github/workflows/build-amd.yml
vendored
52
.github/workflows/build-amd.yml
vendored
@@ -1,52 +0,0 @@
|
||||
name: CI (AMD)
|
||||
|
||||
on:
|
||||
workflow_dispatch: # allows manual triggering
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths: [
|
||||
'.github/workflows/build-amd.yml',
|
||||
'**/CMakeLists.txt',
|
||||
'**/.cmake',
|
||||
'**/*.h',
|
||||
'**/*.hpp',
|
||||
'**/*.c',
|
||||
'**/*.cpp',
|
||||
'**/*.cu',
|
||||
'**/*.cuh',
|
||||
'**/*.comp'
|
||||
]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
ggml-ci-x64-amd-vulkan:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
vulkaninfo --summary
|
||||
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-x64-amd-rocm:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
amd-smi static
|
||||
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
179
.github/workflows/build.yml
vendored
179
.github/workflows/build.yml
vendored
@@ -69,13 +69,6 @@ jobs:
|
||||
key: macOS-latest-cmake-arm64
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
brew install curl
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
@@ -83,6 +76,8 @@ jobs:
|
||||
cmake -B build \
|
||||
-DCMAKE_BUILD_RPATH="@loader_path" \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_BUILD_BORINGSSL=ON \
|
||||
-DGGML_METAL_USE_BF16=ON \
|
||||
-DGGML_METAL_EMBED_LIBRARY=OFF \
|
||||
-DGGML_METAL_SHADER_DEBUG=ON \
|
||||
@@ -110,13 +105,6 @@ jobs:
|
||||
key: macOS-latest-cmake-x64
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
brew install curl
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
@@ -126,6 +114,8 @@ jobs:
|
||||
cmake -B build \
|
||||
-DCMAKE_BUILD_RPATH="@loader_path" \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_BUILD_BORINGSSL=ON \
|
||||
-DGGML_METAL=OFF \
|
||||
-DGGML_RPC=ON \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=13.3
|
||||
@@ -151,13 +141,6 @@ jobs:
|
||||
key: macOS-latest-cmake-arm64-webgpu
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
brew install curl
|
||||
|
||||
- name: Dawn Dependency
|
||||
id: dawn-depends
|
||||
run: |
|
||||
@@ -217,7 +200,7 @@ jobs:
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y --no-install-recommends \
|
||||
python3 python3-pip python3-dev \
|
||||
libjpeg-dev build-essential libcurl4-openssl-dev \
|
||||
libjpeg-dev build-essential libssl-dev \
|
||||
git-lfs
|
||||
|
||||
- name: Python Dependencies
|
||||
@@ -238,6 +221,8 @@ jobs:
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DGGML_RPC=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
@@ -294,13 +279,15 @@ jobs:
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential libcurl4-openssl-dev
|
||||
sudo apt-get install build-essential libssl-dev
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
if: ${{ matrix.sanitizer != 'THREAD' }}
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
|
||||
@@ -311,6 +298,8 @@ jobs:
|
||||
if: ${{ matrix.sanitizer == 'THREAD' }}
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
||||
@@ -335,7 +324,7 @@ jobs:
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential libcurl4-openssl-dev
|
||||
sudo apt-get install build-essential libssl-dev
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
@@ -343,6 +332,8 @@ jobs:
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_LLGUIDANCE=ON
|
||||
cmake --build . --config Release -j $(nproc)
|
||||
@@ -373,12 +364,14 @@ jobs:
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential libcurl4-openssl-dev
|
||||
sudo apt-get install build-essential libssl-dev
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DGGML_RPC=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
@@ -405,12 +398,14 @@ jobs:
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get install -y glslc libvulkan-dev libcurl4-openssl-dev
|
||||
sudo apt-get install -y glslc libvulkan-dev libssl-dev
|
||||
|
||||
- name: Configure
|
||||
id: cmake_configure
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
|
||||
-DGGML_BACKEND_DL=ON \
|
||||
-DGGML_CPU_ALL_VARIANTS=ON \
|
||||
@@ -440,7 +435,7 @@ jobs:
|
||||
run: |
|
||||
sudo add-apt-repository -y ppa:kisak/kisak-mesa
|
||||
sudo apt-get update -y
|
||||
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libcurl4-openssl-dev
|
||||
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libssl-dev
|
||||
|
||||
- name: Get latest Vulkan SDK version
|
||||
id: vulkan_sdk_version
|
||||
@@ -466,6 +461,8 @@ jobs:
|
||||
run: |
|
||||
source ./vulkan_sdk/setup-env.sh
|
||||
cmake -B build \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DGGML_VULKAN=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
@@ -497,7 +494,7 @@ jobs:
|
||||
run: |
|
||||
sudo add-apt-repository -y ppa:kisak/kisak-mesa
|
||||
sudo apt-get update -y
|
||||
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libcurl4-openssl-dev
|
||||
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libssl-dev
|
||||
|
||||
- name: Get latest Vulkan SDK version
|
||||
id: vulkan_sdk_version
|
||||
@@ -537,7 +534,10 @@ jobs:
|
||||
id: cmake_build
|
||||
run: |
|
||||
export Dawn_DIR=dawn/lib64/cmake/Dawn
|
||||
cmake -B build -DGGML_WEBGPU=ON
|
||||
cmake -B build \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DGGML_WEBGPU=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
- name: Test
|
||||
@@ -560,7 +560,7 @@ jobs:
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libcurl4-openssl-dev rocwmma-dev
|
||||
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libssl-dev rocwmma-dev
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -572,6 +572,8 @@ jobs:
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -S . \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
|
||||
-DGGML_HIP_ROCWMMA_FATTN=ON \
|
||||
-DGGML_HIP=ON
|
||||
@@ -590,7 +592,7 @@ jobs:
|
||||
id: depends
|
||||
run: |
|
||||
apt-get update
|
||||
apt-get install -y build-essential git cmake libcurl4-openssl-dev
|
||||
apt-get install -y build-essential git cmake libssl-dev
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -602,6 +604,8 @@ jobs:
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -S . \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DGGML_MUSA=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
@@ -626,7 +630,7 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo apt install intel-oneapi-compiler-dpcpp-cpp libcurl4-openssl-dev
|
||||
sudo apt install intel-oneapi-compiler-dpcpp-cpp libssl-dev
|
||||
|
||||
- name: install oneAPI MKL library
|
||||
shell: bash
|
||||
@@ -648,6 +652,8 @@ jobs:
|
||||
run: |
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
cmake -B build \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx
|
||||
@@ -674,7 +680,7 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo apt install intel-oneapi-compiler-dpcpp-cpp libcurl4-openssl-dev
|
||||
sudo apt install intel-oneapi-compiler-dpcpp-cpp libssl-dev
|
||||
|
||||
- name: install oneAPI MKL library
|
||||
shell: bash
|
||||
@@ -696,6 +702,8 @@ jobs:
|
||||
run: |
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
cmake -B build \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx \
|
||||
@@ -722,12 +730,6 @@ jobs:
|
||||
key: macOS-latest-cmake-ios
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
@@ -759,12 +761,6 @@ jobs:
|
||||
key: macOS-latest-cmake-tvos
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
@@ -790,12 +786,6 @@ jobs:
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
@@ -838,12 +828,6 @@ jobs:
|
||||
name: llama-xcframework
|
||||
path: build-apple/llama.xcframework/
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
|
||||
- name: Build llama.cpp with CMake
|
||||
id: cmake_build
|
||||
run: |
|
||||
@@ -995,21 +979,12 @@ jobs:
|
||||
-DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release"
|
||||
cmake --build build-arm64-release --target install --config release
|
||||
|
||||
- name: libCURL
|
||||
id: get_libcurl
|
||||
uses: ./.github/actions/windows-setup-curl
|
||||
with:
|
||||
architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }}
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
env:
|
||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
||||
run: |
|
||||
cmake -S . -B build ${{ matrix.defines }} `
|
||||
-DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include"
|
||||
-DLLAMA_CURL=OFF -DLLAMA_BUILD_BORINGSSL=ON
|
||||
cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS}
|
||||
cp $env:CURL_PATH/bin/libcurl-*.dll build/bin/Release
|
||||
|
||||
- name: Add libopenblas.dll
|
||||
id: add_libopenblas_dll
|
||||
@@ -1053,7 +1028,7 @@ jobs:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
run: |
|
||||
apt update
|
||||
apt install -y cmake build-essential ninja-build libgomp1 git libcurl4-openssl-dev
|
||||
apt install -y cmake build-essential ninja-build libgomp1 git libssl-dev
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1064,10 +1039,12 @@ jobs:
|
||||
- name: Build with CMake
|
||||
run: |
|
||||
cmake -S . -B build -G Ninja \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DCMAKE_CUDA_ARCHITECTURES=89-real \
|
||||
-DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DGGML_CUDA=ON
|
||||
cmake --build build
|
||||
@@ -1101,25 +1078,20 @@ jobs:
|
||||
run: |
|
||||
choco install ninja
|
||||
|
||||
- name: libCURL
|
||||
id: get_libcurl
|
||||
uses: ./.github/actions/windows-setup-curl
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
shell: cmd
|
||||
env:
|
||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
||||
run: |
|
||||
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64
|
||||
cmake -S . -B build -G "Ninja Multi-Config" ^
|
||||
-DLLAMA_BUILD_SERVER=ON ^
|
||||
-DLLAMA_CURL=OFF ^
|
||||
-DLLAMA_BUILD_BORINGSSL=ON ^
|
||||
-DGGML_NATIVE=OFF ^
|
||||
-DGGML_BACKEND_DL=ON ^
|
||||
-DGGML_CPU_ALL_VARIANTS=ON ^
|
||||
-DGGML_CUDA=ON ^
|
||||
-DGGML_RPC=ON ^
|
||||
-DCURL_LIBRARY="%CURL_PATH%/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="%CURL_PATH%/include"
|
||||
-DGGML_RPC=ON
|
||||
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
|
||||
cmake --build build --config Release -j %NINJA_JOBS% -t ggml
|
||||
cmake --build build --config Release
|
||||
@@ -1151,7 +1123,7 @@ jobs:
|
||||
run: |
|
||||
scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL
|
||||
|
||||
# TODO: add libcurl support ; we will also need to modify win-build-sycl.bat to accept user-specified args
|
||||
# TODO: add ssl support ; we will also need to modify win-build-sycl.bat to accept user-specified args
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
@@ -1208,14 +1180,8 @@ jobs:
|
||||
key: ${{ github.job }}
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: libCURL
|
||||
id: get_libcurl
|
||||
uses: ./.github/actions/windows-setup-curl
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
env:
|
||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
||||
run: |
|
||||
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
|
||||
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
|
||||
@@ -1224,11 +1190,12 @@ jobs:
|
||||
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
|
||||
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-${{ env.ROCM_VERSION }}/include/" `
|
||||
-DCMAKE_BUILD_TYPE=Release `
|
||||
-DLLAMA_CURL=OFF `
|
||||
-DLLAMA_BUILD_BORINGSSL=ON `
|
||||
-DROCM_DIR="${env:HIP_PATH}" `
|
||||
-DGGML_HIP=ON `
|
||||
-DGGML_HIP_ROCWMMA_FATTN=ON `
|
||||
-DGGML_RPC=ON `
|
||||
-DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include"
|
||||
-DGGML_RPC=ON
|
||||
cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
|
||||
|
||||
ios-xcode-build:
|
||||
@@ -1390,14 +1357,10 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
arch: [x86, aarch64]
|
||||
cann:
|
||||
- '8.1.RC1.alpha001-910b-openeuler22.03-py3.10'
|
||||
device:
|
||||
- 'ascend910b3'
|
||||
build:
|
||||
- 'Release'
|
||||
chip_type: ['910b', '310p']
|
||||
build: ['Release']
|
||||
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||
container: ascendai/cann:${{ matrix.cann }}
|
||||
container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -1414,7 +1377,7 @@ jobs:
|
||||
cmake -S . -B build \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }} \
|
||||
-DGGML_CANN=on \
|
||||
-DSOC_TYPE=${{ matrix.device }}
|
||||
-DSOC_TYPE=ascend${{ matrix.chip_type }}
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
# TODO: simplify the following workflows using a matrix
|
||||
@@ -1599,6 +1562,34 @@ jobs:
|
||||
run: |
|
||||
bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-x64-amd-vulkan:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
vulkaninfo --summary
|
||||
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-x64-amd-rocm:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
amd-smi static
|
||||
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-mac-metal:
|
||||
runs-on: [self-hosted, macOS, ARM64]
|
||||
|
||||
|
||||
52
.github/workflows/check-vendor.yml
vendored
Normal file
52
.github/workflows/check-vendor.yml
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
name: Check vendor
|
||||
|
||||
on:
|
||||
workflow_dispatch: # allows manual triggering
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths: [
|
||||
'vendor/**',
|
||||
'scripts/sync_vendor.py'
|
||||
]
|
||||
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
paths: [
|
||||
'vendor/**',
|
||||
'scripts/sync_vendor.py'
|
||||
]
|
||||
|
||||
jobs:
|
||||
check-vendor:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.x'
|
||||
|
||||
- name: Run vendor sync
|
||||
run: |
|
||||
set -euo pipefail
|
||||
python3 scripts/sync_vendor.py
|
||||
|
||||
- name: Check for changes
|
||||
run: |
|
||||
set -euo pipefail
|
||||
# detect modified or untracked files
|
||||
changed=$(git status --porcelain --untracked-files=all || true)
|
||||
if [ -n "$changed" ]; then
|
||||
echo "Vendor sync modified files:"
|
||||
echo "$changed" | awk '{ print $2 }' | sed '/^$/d'
|
||||
echo "Failing because vendor files mismatch. Please update scripts/sync_vendor.py"
|
||||
exit 1
|
||||
else
|
||||
echo "Vendor files are up-to-date."
|
||||
fi
|
||||
46
.github/workflows/release.yml
vendored
46
.github/workflows/release.yml
vendored
@@ -693,6 +693,51 @@ jobs:
|
||||
path: llama-${{ steps.tag.outputs.name }}-xcframework.zip
|
||||
name: llama-${{ steps.tag.outputs.name }}-xcframework
|
||||
|
||||
openEuler-cann:
|
||||
strategy:
|
||||
matrix:
|
||||
arch: [x86, aarch64]
|
||||
chip_type: ['910b', '310p']
|
||||
build: ['Release']
|
||||
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||
container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
yum update -y
|
||||
yum install -y git gcc gcc-c++ make cmake libcurl-devel
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH}
|
||||
|
||||
cmake -S . -B build \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }} \
|
||||
-DGGML_CANN=on \
|
||||
-DSOC_TYPE=ascend${{ matrix.chip_type }}
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
- name: Determine tag name
|
||||
id: tag
|
||||
uses: ./.github/actions/get-tag-name
|
||||
|
||||
- name: Pack artifacts
|
||||
run: |
|
||||
cp LICENSE ./build/bin/
|
||||
zip -r llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip ./build/bin/*
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
|
||||
name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
|
||||
|
||||
release:
|
||||
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
|
||||
|
||||
@@ -714,6 +759,7 @@ jobs:
|
||||
- macOS-arm64
|
||||
- macOS-x64
|
||||
- ios-xcode-build
|
||||
- openEuler-cann
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
||||
27
.github/workflows/server.yml
vendored
27
.github/workflows/server.yml
vendored
@@ -56,7 +56,7 @@ jobs:
|
||||
curl \
|
||||
wget \
|
||||
language-pack-en \
|
||||
libcurl4-openssl-dev
|
||||
libssl-dev
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
@@ -209,7 +209,7 @@ jobs:
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Run UI tests
|
||||
run: npm run test:ui
|
||||
run: npm run test:ui -- --testTimeout=60000
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Run E2E tests
|
||||
@@ -242,7 +242,7 @@ jobs:
|
||||
curl \
|
||||
wget \
|
||||
language-pack-en \
|
||||
libcurl4-openssl-dev
|
||||
libssl-dev
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
@@ -283,6 +283,8 @@ jobs:
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DLLAMA_BUILD_SERVER=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||
@@ -295,6 +297,8 @@ jobs:
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DLLAMA_BUILD_SERVER=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ;
|
||||
@@ -306,6 +310,8 @@ jobs:
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DLLAMA_BUILD_SERVER=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ;
|
||||
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
|
||||
@@ -345,16 +351,10 @@ jobs:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: libCURL
|
||||
id: get_libcurl
|
||||
uses: ./.github/actions/windows-setup-curl
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
env:
|
||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
||||
run: |
|
||||
cmake -B build -DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include"
|
||||
cmake -B build -DLLAMA_CURL=OFF -DLLAMA_BUILD_BORINGSSL=ON
|
||||
cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
|
||||
|
||||
- name: Python setup
|
||||
@@ -368,13 +368,6 @@ jobs:
|
||||
run: |
|
||||
pip install -r tools/server/tests/requirements.txt
|
||||
|
||||
- name: Copy Libcurl
|
||||
id: prepare_libcurl
|
||||
env:
|
||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
||||
run: |
|
||||
cp $env:CURL_PATH/bin/libcurl-x64.dll ./build/bin/Release/libcurl-x64.dll
|
||||
|
||||
- name: Tests
|
||||
id: server_integration_tests
|
||||
if: ${{ !matrix.disabled_on_pr || !github.event.pull_request }}
|
||||
|
||||
108
.gitignore
vendored
108
.gitignore
vendored
@@ -20,52 +20,40 @@
|
||||
*.so
|
||||
*.swp
|
||||
*.tmp
|
||||
*.DS_Store
|
||||
|
||||
# IDE / OS
|
||||
|
||||
.cache/
|
||||
.ccls-cache/
|
||||
.direnv/
|
||||
.DS_Store
|
||||
.envrc
|
||||
.idea/
|
||||
.swiftpm
|
||||
.vs/
|
||||
.vscode/
|
||||
nppBackup
|
||||
/.cache/
|
||||
/.ccls-cache/
|
||||
/.direnv/
|
||||
/.envrc
|
||||
/.idea/
|
||||
/.swiftpm
|
||||
/.vs/
|
||||
/.vscode/
|
||||
/nppBackup
|
||||
|
||||
|
||||
# Coverage
|
||||
|
||||
gcovr-report/
|
||||
lcov-report/
|
||||
/gcovr-report/
|
||||
/lcov-report/
|
||||
|
||||
# Build Artifacts
|
||||
|
||||
tags
|
||||
.build/
|
||||
build*
|
||||
release
|
||||
debug
|
||||
!build-info.cmake
|
||||
!build-info.cpp.in
|
||||
!build-info.sh
|
||||
!build.zig
|
||||
!docs/build.md
|
||||
/tags
|
||||
/.build/
|
||||
/build*
|
||||
/release
|
||||
/debug
|
||||
/libllama.so
|
||||
/llama-*
|
||||
/vulkan-shaders-gen
|
||||
android-ndk-*
|
||||
arm_neon.h
|
||||
cmake-build-*
|
||||
CMakeSettings.json
|
||||
compile_commands.json
|
||||
ggml-metal-embed.metal
|
||||
llama-batched-swift
|
||||
/rpc-server
|
||||
out/
|
||||
tmp/
|
||||
autogen-*.md
|
||||
/out/
|
||||
/tmp/
|
||||
/autogen-*.md
|
||||
|
||||
# Deprecated
|
||||
|
||||
@@ -74,44 +62,38 @@ autogen-*.md
|
||||
|
||||
# CI
|
||||
|
||||
!.github/workflows/*.yml
|
||||
!/.github/workflows/*.yml
|
||||
|
||||
# Models
|
||||
|
||||
models/*
|
||||
models-mnt
|
||||
!models/.editorconfig
|
||||
!models/ggml-vocab-*.gguf*
|
||||
!models/templates
|
||||
/models/*
|
||||
/models-mnt
|
||||
!/models/.editorconfig
|
||||
!/models/ggml-vocab-*.gguf*
|
||||
!/models/templates
|
||||
|
||||
# Zig
|
||||
zig-out/
|
||||
zig-cache/
|
||||
|
||||
# Logs
|
||||
|
||||
ppl-*.txt
|
||||
qnt-*.txt
|
||||
perf-*.txt
|
||||
/zig-out/
|
||||
/zig-cache/
|
||||
|
||||
# Examples
|
||||
|
||||
examples/jeopardy/results.txt
|
||||
tools/server/*.css.hpp
|
||||
tools/server/*.html.hpp
|
||||
tools/server/*.js.hpp
|
||||
tools/server/*.mjs.hpp
|
||||
tools/server/*.gz.hpp
|
||||
!build_64.sh
|
||||
!examples/*.bat
|
||||
!examples/*/*.kts
|
||||
!examples/*/*/*.kts
|
||||
!examples/sycl/*.bat
|
||||
!examples/sycl/*.sh
|
||||
/examples/jeopardy/results.txt
|
||||
/tools/server/*.css.hpp
|
||||
/tools/server/*.html.hpp
|
||||
/tools/server/*.js.hpp
|
||||
/tools/server/*.mjs.hpp
|
||||
/tools/server/*.gz.hpp
|
||||
!/build_64.sh
|
||||
!/examples/*.bat
|
||||
!/examples/*/*.kts
|
||||
!/examples/*/*/*.kts
|
||||
!/examples/sycl/*.bat
|
||||
!/examples/sycl/*.sh
|
||||
|
||||
# Server Web UI temporary files
|
||||
node_modules
|
||||
tools/server/webui/dist
|
||||
/tools/server/webui/node_modules
|
||||
/tools/server/webui/dist
|
||||
|
||||
# Python
|
||||
|
||||
@@ -147,8 +129,8 @@ poetry.toml
|
||||
# Local scripts
|
||||
/run-vim.sh
|
||||
/run-chat.sh
|
||||
.ccache/
|
||||
/.ccache/
|
||||
|
||||
# IDE
|
||||
*.code-workspace
|
||||
.windsurf/
|
||||
/*.code-workspace
|
||||
/.windsurf/
|
||||
|
||||
@@ -92,6 +92,7 @@ option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_
|
||||
|
||||
# 3rd party libs
|
||||
option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON)
|
||||
option(LLAMA_HTTPLIB "llama: if libcurl is disabled, use httplib to download model from an URL" ON)
|
||||
option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" OFF)
|
||||
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
|
||||
|
||||
@@ -200,7 +201,9 @@ endif()
|
||||
|
||||
if (LLAMA_BUILD_COMMON)
|
||||
add_subdirectory(common)
|
||||
add_subdirectory(vendor/cpp-httplib)
|
||||
if (LLAMA_HTTPLIB)
|
||||
add_subdirectory(vendor/cpp-httplib)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
|
||||
|
||||
31
CODEOWNERS
31
CODEOWNERS
@@ -2,10 +2,8 @@
|
||||
# multiplie collaborators per item can be specified
|
||||
|
||||
/.devops/*.Dockerfile @ngxson
|
||||
/.github/actions/ @slaren @CISC
|
||||
/.github/actions/ @CISC
|
||||
/.github/workflows/ @CISC
|
||||
/.github/workflows/release.yml @slaren
|
||||
/.github/workflows/winget.yml @slaren
|
||||
/ci/ @ggerganov
|
||||
/cmake/ @ggerganov
|
||||
/common/CMakeLists.txt @ggerganov
|
||||
@@ -40,21 +38,14 @@
|
||||
/examples/passkey/ @ggerganov
|
||||
/examples/retrieval/ @ggerganov
|
||||
/examples/save-load-state/ @ggerganov
|
||||
/examples/simple-chat/ @slaren
|
||||
/examples/simple/ @slaren
|
||||
/examples/speculative-simple/ @ggerganov
|
||||
/examples/speculative/ @ggerganov
|
||||
/ggml/cmake/ @ggerganov
|
||||
/ggml/include/ @ggerganov @slaren
|
||||
/ggml/src/ggml-alloc.c @slaren
|
||||
/ggml/src/ggml-backend* @slaren
|
||||
/ggml/src/ggml-blas/ @slaren
|
||||
/ggml/src/ggml-common.h @ggerganov @slaren
|
||||
/ggml/src/ggml-cpu/ @ggerganov @slaren
|
||||
/ggml/include/ @ggerganov
|
||||
/ggml/src/ggml-common.h @ggerganov
|
||||
/ggml/src/ggml-cpu/ @ggerganov
|
||||
/ggml/src/ggml-cpu/spacemit/ @alex-spacemit
|
||||
/ggml/src/ggml-cuda/common.cuh @slaren
|
||||
/ggml/src/ggml-cuda/fattn* @JohannesGaessler
|
||||
/ggml/src/ggml-cuda/ggml-cuda.cu @slaren
|
||||
/ggml/src/ggml-cuda/mmf.* @JohannesGaessler @am17an
|
||||
/ggml/src/ggml-cuda/mmq.* @JohannesGaessler
|
||||
/ggml/src/ggml-cuda/mmvf.* @JohannesGaessler
|
||||
@@ -62,19 +53,19 @@
|
||||
/ggml/src/ggml-cuda/fattn-wmma* @IMbackK
|
||||
/ggml/src/ggml-hip/ @IMbackK
|
||||
/ggml/src/ggml-cuda/vendors/hip.h @IMbackK
|
||||
/ggml/src/ggml-impl.h @ggerganov @slaren
|
||||
/ggml/src/ggml-impl.h @ggerganov
|
||||
/ggml/src/ggml-metal/ @ggerganov
|
||||
/ggml/src/ggml-opencl/ @lhez @max-krasnyansky
|
||||
/ggml/src/ggml-hexagon/ @max-krasnyansky @lhez
|
||||
/ggml/src/ggml-opt.cpp @JohannesGaessler
|
||||
/ggml/src/ggml-quants.* @ggerganov
|
||||
/ggml/src/ggml-rpc/ @rgerganov
|
||||
/ggml/src/ggml-threading.* @ggerganov @slaren
|
||||
/ggml/src/ggml-threading.* @ggerganov
|
||||
/ggml/src/ggml-vulkan/ @0cc4m
|
||||
/ggml/src/ggml-webgpu/ @reeselevine
|
||||
/ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM
|
||||
/ggml/src/ggml.c @ggerganov @slaren
|
||||
/ggml/src/ggml.cpp @ggerganov @slaren
|
||||
/ggml/src/ggml.c @ggerganov
|
||||
/ggml/src/ggml.cpp @ggerganov
|
||||
/ggml/src/gguf.cpp @JohannesGaessler @Green-Sky
|
||||
/gguf-py/ @CISC
|
||||
/media/ @ggerganov
|
||||
@@ -86,15 +77,11 @@
|
||||
/src/llama-arch.* @CISC
|
||||
/src/llama-chat.* @ngxson
|
||||
/src/llama-graph.* @CISC
|
||||
/src/llama-model-loader.* @slaren
|
||||
/src/llama-model.* @CISC
|
||||
/src/llama-vocab.* @CISC
|
||||
/src/models/ @CISC
|
||||
/tests/ @ggerganov
|
||||
/tests/test-backend-ops.cpp @slaren
|
||||
/tests/test-thread-safety.cpp @slaren
|
||||
/tools/batched-bench/ @ggerganov
|
||||
/tools/llama-bench/ @slaren
|
||||
/tools/main/ @ggerganov
|
||||
/tools/mtmd/ @ngxson
|
||||
/tools/perplexity/ @ggerganov
|
||||
@@ -106,8 +93,6 @@
|
||||
/tools/tokenize/ @ggerganov
|
||||
/tools/tts/ @ggerganov
|
||||
/vendor/ @ggerganov
|
||||
/.clang-format @slaren
|
||||
/.clang-tidy @slaren
|
||||
/AUTHORS @ggerganov
|
||||
/CMakeLists.txt @ggerganov
|
||||
/CONTRIBUTING.md @ggerganov
|
||||
|
||||
@@ -19,6 +19,7 @@ The project differentiates between 3 levels of contributors:
|
||||
- If your PR becomes stale, don't hesitate to ping the maintainers in the comments
|
||||
- Maintainers will rely on your insights and approval when making a final decision to approve and merge a PR
|
||||
- Consider adding yourself to [CODEOWNERS](CODEOWNERS) to indicate your availability for reviewing related PRs
|
||||
- Using AI to generate PRs is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before publishing the PR. Note that trivial tab autocompletions do not require disclosure.
|
||||
|
||||
# Pull requests (for maintainers)
|
||||
|
||||
|
||||
@@ -61,6 +61,7 @@ range of hardware - locally and in the cloud.
|
||||
- Plain C/C++ implementation without any dependencies
|
||||
- Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
|
||||
- AVX, AVX2, AVX512 and AMX support for x86 architectures
|
||||
- RVV, ZVFH, ZFH and ZICBOP support for RISC-V architectures
|
||||
- 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use
|
||||
- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads GPUs via MUSA)
|
||||
- Vulkan and SYCL backend support
|
||||
@@ -241,6 +242,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
|
||||
- [gpustack/gguf-parser](https://github.com/gpustack/gguf-parser-go/tree/main/cmd/gguf-parser) - review/check the GGUF file and estimate the memory usage
|
||||
- [Styled Lines](https://marketplace.unity.com/packages/tools/generative-ai/styled-lines-llama-cpp-model-292902) (proprietary licensed, async wrapper of inference part for game development in Unity3d with pre-built Mobile and Web platform wrappers and a model example)
|
||||
- [unslothai/unsloth](https://github.com/unslothai/unsloth) – 🦥 exports/saves fine-tuned and trained models to GGUF (Apache-2.0)
|
||||
|
||||
</details>
|
||||
|
||||
@@ -611,3 +613,4 @@ $ echo "source ~/.llama-completion.bash" >> ~/.bashrc
|
||||
- [linenoise.cpp](./tools/run/linenoise.cpp/linenoise.cpp) - C++ library that provides readline-like line editing capabilities, used by `llama-run` - BSD 2-Clause License
|
||||
- [curl](https://curl.se/) - Client-side URL transfer library, used by various tools/examples - [CURL License](https://curl.se/docs/copyright.html)
|
||||
- [miniaudio.h](https://github.com/mackron/miniaudio) - Single-header audio format decoder, used by multimodal subsystem - Public domain
|
||||
- [subprocess.h](https://github.com/sheredom/subprocess.h) - Single-header process launching solution for C and C++ - Public domain
|
||||
|
||||
@@ -65,4 +65,6 @@ However, If you have discovered a security vulnerability in this project, please
|
||||
|
||||
Please disclose it as a private [security advisory](https://github.com/ggml-org/llama.cpp/security/advisories/new).
|
||||
|
||||
Please note that using AI to identify vulnerabilities and generate reports is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before submitting the report.
|
||||
|
||||
A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure.
|
||||
|
||||
@@ -454,6 +454,8 @@ cmake -B build-visionos -G Xcode \
|
||||
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_HTTPLIB=OFF \
|
||||
-DLLAMA_BUILD_SERVER=OFF \
|
||||
-S .
|
||||
cmake --build build-visionos --config Release -- -quiet
|
||||
|
||||
@@ -468,6 +470,8 @@ cmake -B build-visionos-sim -G Xcode \
|
||||
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_HTTPLIB=OFF \
|
||||
-DLLAMA_BUILD_SERVER=OFF \
|
||||
-S .
|
||||
cmake --build build-visionos-sim --config Release -- -quiet
|
||||
|
||||
|
||||
16
ci/run.sh
16
ci/run.sh
@@ -45,7 +45,7 @@ sd=`dirname $0`
|
||||
cd $sd/../
|
||||
SRC=`pwd`
|
||||
|
||||
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON"
|
||||
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_SCHED_NO_REALLOC=ON"
|
||||
|
||||
if [ ! -z ${GG_BUILD_METAL} ]; then
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON"
|
||||
@@ -428,10 +428,10 @@ function gg_run_qwen3_0_6b {
|
||||
|
||||
(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -ngl 99 -c 1024 -b 512 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
|
||||
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
|
||||
function check_ppl {
|
||||
qnt="$1"
|
||||
@@ -523,8 +523,8 @@ function gg_run_embd_bge_small {
|
||||
|
||||
./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0
|
||||
|
||||
(time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
|
||||
(time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
|
||||
(time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
|
||||
(time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
|
||||
|
||||
set +e
|
||||
}
|
||||
@@ -564,7 +564,7 @@ function gg_run_rerank_tiny {
|
||||
model_f16="${path_models}/ggml-model-f16.gguf"
|
||||
|
||||
# for this model, the SEP token is "</s>"
|
||||
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
|
||||
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --no-op-offload --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
|
||||
|
||||
# sample output
|
||||
# rerank score 0: 0.029
|
||||
|
||||
@@ -50,6 +50,8 @@ add_library(${TARGET} STATIC
|
||||
base64.hpp
|
||||
chat-parser.cpp
|
||||
chat-parser.h
|
||||
chat-parser-xml-toolcall.h
|
||||
chat-parser-xml-toolcall.cpp
|
||||
chat.cpp
|
||||
chat.h
|
||||
common.cpp
|
||||
@@ -91,47 +93,12 @@ if (LLAMA_CURL)
|
||||
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL)
|
||||
include_directories(${CURL_INCLUDE_DIRS})
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARIES})
|
||||
else()
|
||||
elseif (LLAMA_HTTPLIB)
|
||||
# otherwise, use cpp-httplib
|
||||
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_HTTPLIB)
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib)
|
||||
endif()
|
||||
|
||||
if (LLAMA_OPENSSL)
|
||||
find_package(OpenSSL)
|
||||
if (OpenSSL_FOUND)
|
||||
include(CheckCSourceCompiles)
|
||||
set(SAVED_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES})
|
||||
set(CMAKE_REQUIRED_INCLUDES ${OPENSSL_INCLUDE_DIR})
|
||||
check_c_source_compiles("
|
||||
#include <openssl/opensslv.h>
|
||||
#if defined(OPENSSL_IS_BORINGSSL) || defined(LIBRESSL_VERSION_NUMBER)
|
||||
# if OPENSSL_VERSION_NUMBER < 0x1010107f
|
||||
# error bad version
|
||||
# endif
|
||||
#else
|
||||
# if OPENSSL_VERSION_NUMBER < 0x30000000L
|
||||
# error bad version
|
||||
# endif
|
||||
#endif
|
||||
int main() { return 0; }
|
||||
" OPENSSL_VERSION_SUPPORTED)
|
||||
set(CMAKE_REQUIRED_INCLUDES ${SAVED_CMAKE_REQUIRED_INCLUDES})
|
||||
if (OPENSSL_VERSION_SUPPORTED)
|
||||
message(STATUS "OpenSSL found: ${OPENSSL_VERSION}")
|
||||
target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_OPENSSL_SUPPORT)
|
||||
target_link_libraries(${TARGET} PUBLIC OpenSSL::SSL OpenSSL::Crypto)
|
||||
if (APPLE AND CMAKE_SYSTEM_NAME STREQUAL "Darwin")
|
||||
target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
||||
find_library(CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED)
|
||||
find_library(SECURITY_FRAMEWORK Security REQUIRED)
|
||||
target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK})
|
||||
endif()
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "OpenSSL not found, SSL support disabled")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (LLAMA_LLGUIDANCE)
|
||||
include(ExternalProject)
|
||||
set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
|
||||
|
||||
@@ -212,13 +212,13 @@ struct handle_model_result {
|
||||
static handle_model_result common_params_handle_model(
|
||||
struct common_params_model & model,
|
||||
const std::string & bearer_token,
|
||||
const std::string & model_path_default,
|
||||
bool offline) {
|
||||
handle_model_result result;
|
||||
// handle pre-fill default model path and url based on hf_repo and hf_file
|
||||
{
|
||||
if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths
|
||||
model.path = common_docker_resolve_model(model.docker_repo);
|
||||
model.name = model.docker_repo; // set name for consistency
|
||||
} else if (!model.hf_repo.empty()) {
|
||||
// short-hand to avoid specifying --hf-file -> default it to --model
|
||||
if (model.hf_file.empty()) {
|
||||
@@ -227,7 +227,8 @@ static handle_model_result common_params_handle_model(
|
||||
if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) {
|
||||
exit(1); // built without CURL, error message already printed
|
||||
}
|
||||
model.hf_repo = auto_detected.repo;
|
||||
model.name = model.hf_repo; // repo name with tag
|
||||
model.hf_repo = auto_detected.repo; // repo name without tag
|
||||
model.hf_file = auto_detected.ggufFile;
|
||||
if (!auto_detected.mmprojFile.empty()) {
|
||||
result.found_mmproj = true;
|
||||
@@ -257,8 +258,6 @@ static handle_model_result common_params_handle_model(
|
||||
model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
|
||||
}
|
||||
|
||||
} else if (model.path.empty()) {
|
||||
model.path = model_path_default;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -405,7 +404,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
|
||||
// handle model and download
|
||||
{
|
||||
auto res = common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH, params.offline);
|
||||
auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
|
||||
if (params.no_mmproj) {
|
||||
params.mmproj = {};
|
||||
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
|
||||
@@ -415,12 +414,18 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
// only download mmproj if the current example is using it
|
||||
for (auto & ex : mmproj_examples) {
|
||||
if (ctx_arg.ex == ex) {
|
||||
common_params_handle_model(params.mmproj, params.hf_token, "", params.offline);
|
||||
common_params_handle_model(params.mmproj, params.hf_token, params.offline);
|
||||
break;
|
||||
}
|
||||
}
|
||||
common_params_handle_model(params.speculative.model, params.hf_token, "", params.offline);
|
||||
common_params_handle_model(params.vocoder.model, params.hf_token, "", params.offline);
|
||||
common_params_handle_model(params.speculative.model, params.hf_token, params.offline);
|
||||
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
|
||||
}
|
||||
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER) {
|
||||
throw std::invalid_argument("error: --model is required\n");
|
||||
}
|
||||
|
||||
if (params.escape) {
|
||||
@@ -694,6 +699,12 @@ static bool is_autoy(const std::string & value) {
|
||||
}
|
||||
|
||||
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
|
||||
// default values specific to example
|
||||
// note: we place it here instead of inside server.cpp to allow llama-gen-docs to pick it up
|
||||
if (ex == LLAMA_EXAMPLE_SERVER) {
|
||||
params.use_jinja = true;
|
||||
}
|
||||
|
||||
// load dynamic backends
|
||||
ggml_backend_load_all();
|
||||
|
||||
@@ -974,7 +985,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.kv_unified = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_KV_SPLIT"));
|
||||
).set_env("LLAMA_ARG_KV_UNIFIED"));
|
||||
add_opt(common_arg(
|
||||
{"--no-context-shift"},
|
||||
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
|
||||
@@ -1232,6 +1243,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, const std::string & value) {
|
||||
const auto sampler_names = string_split<std::string>(value, ';');
|
||||
params.sampling.samplers = common_sampler_types_from_names(sampler_names, true);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS;
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
@@ -1261,6 +1273,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.temp = std::stof(value);
|
||||
params.sampling.temp = std::max(params.sampling.temp, 0.0f);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP;
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
@@ -1268,6 +1281,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k),
|
||||
[](common_params & params, int value) {
|
||||
params.sampling.top_k = value;
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K;
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
@@ -1275,6 +1289,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.top_p = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P;
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
@@ -1282,6 +1297,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.min_p = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P;
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
@@ -1296,6 +1312,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.xtc_probability = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY;
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
@@ -1303,6 +1320,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.xtc_threshold = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD;
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
@@ -1321,6 +1339,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
params.sampling.penalty_last_n = value;
|
||||
params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N;
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
@@ -1328,6 +1347,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.penalty_repeat = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT;
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
@@ -1425,6 +1445,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
"(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat),
|
||||
[](common_params & params, int value) {
|
||||
params.sampling.mirostat = value;
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT;
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
@@ -1432,6 +1453,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.mirostat_eta = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA;
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
@@ -1439,6 +1461,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.mirostat_tau = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU;
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
@@ -2072,11 +2095,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
add_opt(common_arg(
|
||||
{"-m", "--model"}, "FNAME",
|
||||
ex == LLAMA_EXAMPLE_EXPORT_LORA
|
||||
? std::string("model path from which to load base model")
|
||||
: string_format(
|
||||
"model path (default: `models/$filename` with filename from `--hf-file` "
|
||||
"or `--model-url` if set, otherwise %s)", DEFAULT_MODEL_PATH
|
||||
),
|
||||
? "model path from which to load base model"
|
||||
: "model path to load",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.model.path = value;
|
||||
}
|
||||
@@ -2474,13 +2494,41 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--models-dir"}, "PATH",
|
||||
"directory containing models for the router server (default: disabled)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.models_dir = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_DIR"));
|
||||
add_opt(common_arg(
|
||||
{"--models-max"}, "N",
|
||||
string_format("for router server, maximum number of models to load simultaneously (default: %d, 0 = unlimited)", params.models_max),
|
||||
[](common_params & params, int value) {
|
||||
params.models_max = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_MAX"));
|
||||
add_opt(common_arg(
|
||||
{"--no-models-autoload"},
|
||||
"disables automatic loading of models (default: enabled)",
|
||||
[](common_params & params) {
|
||||
params.models_autoload = false;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_MODELS_AUTOLOAD"));
|
||||
add_opt(common_arg(
|
||||
{"--jinja"},
|
||||
"use jinja template for chat (default: disabled)",
|
||||
string_format("use jinja template for chat (default: %s)\n", params.use_jinja ? "enabled" : "disabled"),
|
||||
[](common_params & params) {
|
||||
params.use_jinja = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA"));
|
||||
add_opt(common_arg(
|
||||
{"--no-jinja"},
|
||||
string_format("disable jinja template for chat (default: %s)\n", params.use_jinja ? "enabled" : "disabled"),
|
||||
[](common_params & params) {
|
||||
params.use_jinja = false;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_NO_JINJA"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-format"}, "FORMAT",
|
||||
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
|
||||
@@ -2614,7 +2662,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params &, const std::string & value) {
|
||||
common_log_set_file(common_log_main(), value.c_str());
|
||||
}
|
||||
));
|
||||
).set_env("LLAMA_LOG_FILE"));
|
||||
add_opt(common_arg(
|
||||
{"--log-colors"}, "[on|off|auto]",
|
||||
"Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
|
||||
@@ -2649,7 +2697,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_env("LLAMA_OFFLINE"));
|
||||
add_opt(common_arg(
|
||||
{"-lv", "--verbosity", "--log-verbosity"}, "N",
|
||||
"Set the verbosity threshold. Messages with a higher verbosity will be ignored.",
|
||||
string_format("Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:\n"
|
||||
" - 0: generic output\n"
|
||||
" - 1: error\n"
|
||||
" - 2: warning\n"
|
||||
" - 3: info\n"
|
||||
" - 4: debug\n"
|
||||
"(default: %d)\n", params.verbosity),
|
||||
[](common_params & params, int value) {
|
||||
params.verbosity = value;
|
||||
common_log_set_verbosity_thold(value);
|
||||
|
||||
861
common/chat-parser-xml-toolcall.cpp
Normal file
861
common/chat-parser-xml-toolcall.cpp
Normal file
@@ -0,0 +1,861 @@
|
||||
#include "chat.h"
|
||||
#include "chat-parser.h"
|
||||
#include "common.h"
|
||||
#include "json-partial.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
class xml_toolcall_syntax_exception : public std::runtime_error {
|
||||
public:
|
||||
xml_toolcall_syntax_exception(const std::string & message) : std::runtime_error(message) {}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
inline void sort_uniq(std::vector<T> &vec) {
|
||||
std::sort(vec.begin(), vec.end());
|
||||
vec.erase(std::unique(vec.begin(), vec.end()), vec.end());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline bool all_space(const T &str) {
|
||||
return std::all_of(str.begin(), str.end(), [](unsigned char ch) { return std::isspace(ch); });
|
||||
}
|
||||
|
||||
static size_t utf8_truncate_safe(const std::string_view s) {
|
||||
size_t len = s.size();
|
||||
if (len == 0) return 0;
|
||||
size_t i = len;
|
||||
for (size_t back = 0; back < 4 && i > 0; ++back) {
|
||||
--i;
|
||||
unsigned char c = s[i];
|
||||
if ((c & 0x80) == 0) {
|
||||
return len;
|
||||
} else if ((c & 0xC0) == 0xC0) {
|
||||
size_t expected_len = 0;
|
||||
if ((c & 0xE0) == 0xC0) expected_len = 2;
|
||||
else if ((c & 0xF0) == 0xE0) expected_len = 3;
|
||||
else if ((c & 0xF8) == 0xF0) expected_len = 4;
|
||||
else return i;
|
||||
if (len - i >= expected_len) {
|
||||
return len;
|
||||
} else {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
return len - std::min(len, size_t(3));
|
||||
}
|
||||
|
||||
inline void utf8_truncate_safe_resize(std::string &s) {
|
||||
s.resize(utf8_truncate_safe(s));
|
||||
}
|
||||
|
||||
inline std::string_view utf8_truncate_safe_view(const std::string_view s) {
|
||||
return s.substr(0, utf8_truncate_safe(s));
|
||||
}
|
||||
|
||||
static std::optional<common_chat_msg_parser::find_regex_result> try_find_2_literal_splited_by_spaces(common_chat_msg_parser & builder, const std::string & literal1, const std::string & literal2) {
|
||||
if (literal1.size() == 0) return builder.try_find_literal(literal2);
|
||||
const auto saved_pos = builder.pos();
|
||||
while (auto res = builder.try_find_literal(literal1)) {
|
||||
builder.consume_spaces();
|
||||
const auto match_len = std::min(literal2.size(), builder.input().size() - builder.pos());
|
||||
if (builder.input().compare(builder.pos(), match_len, literal2, 0, match_len) == 0) {
|
||||
if (res->prelude.size() != res->groups[0].begin - saved_pos) {
|
||||
res->prelude = builder.str({saved_pos, res->groups[0].begin});
|
||||
}
|
||||
builder.move_to(builder.pos() + match_len);
|
||||
res->groups[0].end = builder.pos();
|
||||
GGML_ASSERT(res->groups[0].begin != res->groups[0].end);
|
||||
return res;
|
||||
}
|
||||
builder.move_to(res->groups[0].begin + 1);
|
||||
}
|
||||
builder.move_to(saved_pos);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
/**
|
||||
* make a GBNF that accept any strings except those containing any of the forbidden strings.
|
||||
*/
|
||||
std::string make_gbnf_excluding(std::vector<std::string> forbids) {
|
||||
constexpr auto charclass_escape = [](unsigned char c) -> std::string {
|
||||
if (c == '\\' || c == ']' || c == '^' || c == '-') {
|
||||
std::string s = "\\";
|
||||
s.push_back((char)c);
|
||||
return s;
|
||||
}
|
||||
if (isprint(c)) {
|
||||
return std::string(1, (char)c);
|
||||
}
|
||||
char buf[16];
|
||||
snprintf(buf, 15, "\\x%02X", c);
|
||||
return std::string(buf);
|
||||
};
|
||||
constexpr auto build_expr = [charclass_escape](auto self, const std::vector<std::string>& forbids, int l, int r, int depth) -> std::string {
|
||||
std::vector<std::pair<unsigned char, std::pair<int,int>>> children;
|
||||
int i = l;
|
||||
while (i < r) {
|
||||
const std::string &s = forbids[i];
|
||||
if ((int)s.size() == depth) {
|
||||
++i;
|
||||
continue;
|
||||
}
|
||||
unsigned char c = (unsigned char)s[depth];
|
||||
int j = i;
|
||||
while (j < r && (int)forbids[j].size() > depth &&
|
||||
(unsigned char)forbids[j][depth] == c) {
|
||||
++j;
|
||||
}
|
||||
children.push_back({c, {i, j}});
|
||||
i = j;
|
||||
}
|
||||
std::vector<std::string> alts;
|
||||
if (!children.empty()) {
|
||||
std::string cls;
|
||||
for (auto &ch : children) cls += charclass_escape(ch.first);
|
||||
alts.push_back(std::string("[^") + cls + "]");
|
||||
}
|
||||
for (auto &ch : children) {
|
||||
std::string childExpr = self(self, forbids, ch.second.first, ch.second.second, depth+1);
|
||||
if (!childExpr.empty()) {
|
||||
std::string quoted_ch = "\"";
|
||||
if (ch.first == '\\') quoted_ch += "\\\\";
|
||||
else if (ch.first == '"') quoted_ch += "\\\"";
|
||||
else if (isprint(ch.first)) quoted_ch.push_back(ch.first);
|
||||
else {
|
||||
char buf[16];
|
||||
snprintf(buf, 15, "\\x%02X", ch.first);
|
||||
quoted_ch += buf;
|
||||
}
|
||||
quoted_ch += "\"";
|
||||
std::string branch = quoted_ch + std::string(" ") + childExpr;
|
||||
alts.push_back(branch);
|
||||
}
|
||||
}
|
||||
if (alts.empty()) return "";
|
||||
std::ostringstream oss;
|
||||
oss << "( ";
|
||||
for (size_t k = 0; k < alts.size(); ++k) {
|
||||
if (k) oss << " | ";
|
||||
oss << alts[k];
|
||||
}
|
||||
oss << " )";
|
||||
return oss.str();
|
||||
};
|
||||
if (forbids.empty()) return "( . )*";
|
||||
sort(forbids.begin(), forbids.end());
|
||||
std::string expr = build_expr(build_expr, forbids, 0, forbids.size(), 0);
|
||||
if (expr.empty()) {
|
||||
std::string cls;
|
||||
for (auto &s : forbids) if (!s.empty()) cls += charclass_escape((unsigned char)s[0]);
|
||||
expr = std::string("( [^") + cls + "] )";
|
||||
}
|
||||
if (forbids.size() == 1)
|
||||
return expr + "*";
|
||||
else
|
||||
return std::string("( ") + expr + " )*";
|
||||
}
|
||||
|
||||
/**
|
||||
* Build grammar for xml-style tool call
|
||||
* form.scope_start and form.scope_end can be empty.
|
||||
* Requires data.format for model-specific hacks.
|
||||
*/
|
||||
void build_grammar_xml_tool_call(common_chat_params & data, const json & tools, const struct xml_tool_call_format & form) {
|
||||
GGML_ASSERT(!form.tool_start.empty());
|
||||
GGML_ASSERT(!form.tool_sep.empty());
|
||||
GGML_ASSERT(!form.key_start.empty());
|
||||
GGML_ASSERT(!form.val_end.empty());
|
||||
GGML_ASSERT(!form.tool_end.empty());
|
||||
|
||||
std::string key_val_sep = form.key_val_sep;
|
||||
if (form.key_val_sep2) {
|
||||
key_val_sep += "\n";
|
||||
key_val_sep += *form.key_val_sep2;
|
||||
}
|
||||
GGML_ASSERT(!key_val_sep.empty());
|
||||
|
||||
if (tools.is_array() && !tools.empty()) {
|
||||
data.grammar = build_grammar([&](const common_grammar_builder &builder) {
|
||||
auto string_arg_val = form.last_val_end ?
|
||||
builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end, *form.last_val_end})) :
|
||||
builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end}));
|
||||
|
||||
std::vector<std::string> tool_rules;
|
||||
for (const auto & tool : tools) {
|
||||
if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
|
||||
LOG_WRN("Skipping tool without function: %s", tool.dump(2).c_str());
|
||||
continue;
|
||||
}
|
||||
const auto & function = tool.at("function");
|
||||
if (!function.contains("name") || !function.at("name").is_string()) {
|
||||
LOG_WRN("Skipping invalid function (invalid name): %s", function.dump(2).c_str());
|
||||
continue;
|
||||
}
|
||||
if (!function.contains("parameters") || !function.at("parameters").is_object()) {
|
||||
LOG_WRN("Skipping invalid function (invalid parameters): %s", function.dump(2).c_str());
|
||||
continue;
|
||||
}
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
|
||||
struct parameter_rule {
|
||||
std::string symbol_name;
|
||||
bool is_required;
|
||||
};
|
||||
std::vector<parameter_rule> arg_rules;
|
||||
if (!parameters.contains("properties") || !parameters.at("properties").is_object()) {
|
||||
LOG_WRN("Skipping invalid function (invalid properties): %s", function.dump(2).c_str());
|
||||
continue;
|
||||
} else {
|
||||
std::vector<std::string> requiredParameters;
|
||||
if (parameters.contains("required")) {
|
||||
try { parameters.at("required").get_to(requiredParameters); }
|
||||
catch (const std::runtime_error&) {
|
||||
LOG_WRN("Invalid function required parameters, ignoring: %s", function.at("required").dump(2).c_str());
|
||||
}
|
||||
}
|
||||
sort_uniq(requiredParameters);
|
||||
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||
std::string quoted_key = key;
|
||||
bool required = std::binary_search(requiredParameters.begin(), requiredParameters.end(), key);
|
||||
if (form.key_start.back() == '"' && key_val_sep[0] == '"') {
|
||||
quoted_key = gbnf_format_literal(key);
|
||||
quoted_key = quoted_key.substr(1, quoted_key.size() - 2);
|
||||
}
|
||||
arg_rules.push_back(parameter_rule {builder.add_rule("func-" + name + "-kv-" + key,
|
||||
gbnf_format_literal(form.key_start) + " " +
|
||||
gbnf_format_literal(quoted_key) + " " +
|
||||
gbnf_format_literal(key_val_sep) + " " +
|
||||
((value.contains("type") && value["type"].is_string() && value["type"] == "string" && (!form.raw_argval || *form.raw_argval)) ?
|
||||
(form.raw_argval ?
|
||||
string_arg_val :
|
||||
"( " + string_arg_val + " | " + builder.add_schema(name + "-arg-" + key, value) + " )"
|
||||
) :
|
||||
builder.add_schema(name + "-arg-" + key, value)
|
||||
)
|
||||
), required});
|
||||
}
|
||||
}
|
||||
|
||||
auto next_arg_with_sep = builder.add_rule(name + "-last-arg-end", form.last_val_end ? gbnf_format_literal(*form.last_val_end) : gbnf_format_literal(form.val_end));
|
||||
decltype(next_arg_with_sep) next_arg = "\"\"";
|
||||
for (auto i = arg_rules.size() - 1; /* i >= 0 && */ i < arg_rules.size(); --i) {
|
||||
std::string include_this_arg = arg_rules[i].symbol_name + " " + next_arg_with_sep;
|
||||
next_arg = builder.add_rule(name + "-arg-after-" + std::to_string(i), arg_rules[i].is_required ?
|
||||
include_this_arg : "( " + include_this_arg + " ) | " + next_arg
|
||||
);
|
||||
include_this_arg = gbnf_format_literal(form.val_end) + " " + include_this_arg;
|
||||
next_arg_with_sep = builder.add_rule(name + "-arg-after-" + std::to_string(i) + "-with-sep", arg_rules[i].is_required ?
|
||||
include_this_arg : "( " + include_this_arg + " ) | " + next_arg_with_sep
|
||||
);
|
||||
}
|
||||
|
||||
std::string quoted_name = name;
|
||||
if (form.tool_start.back() == '"' && form.tool_sep[0] == '"') {
|
||||
quoted_name = gbnf_format_literal(name);
|
||||
quoted_name = quoted_name.substr(1, quoted_name.size() - 2);
|
||||
}
|
||||
quoted_name = gbnf_format_literal(quoted_name);
|
||||
// Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
|
||||
if (data.format == COMMON_CHAT_FORMAT_KIMI_K2) {
|
||||
quoted_name = "\"functions.\" " + quoted_name + " \":\" [0-9]+";
|
||||
}
|
||||
tool_rules.push_back(builder.add_rule(name + "-call",
|
||||
gbnf_format_literal(form.tool_start) + " " +
|
||||
quoted_name + " " +
|
||||
gbnf_format_literal(form.tool_sep) + " " +
|
||||
next_arg
|
||||
));
|
||||
}
|
||||
|
||||
auto tool_call_once = builder.add_rule("root-tool-call-once", string_join(tool_rules, " | "));
|
||||
auto tool_call_more = builder.add_rule("root-tool-call-more", gbnf_format_literal(form.tool_end) + " " + tool_call_once);
|
||||
auto call_end = builder.add_rule("root-call-end", form.last_tool_end ? gbnf_format_literal(*form.last_tool_end) : gbnf_format_literal(form.tool_end));
|
||||
auto tool_call_multiple_with_end = builder.add_rule("root-tool-call-multiple-with-end", tool_call_once + " " + tool_call_more + "* " + call_end);
|
||||
builder.add_rule("root",
|
||||
(form.scope_start.empty() ? "" : gbnf_format_literal(form.scope_start) + " ") +
|
||||
tool_call_multiple_with_end + "?" +
|
||||
(form.scope_end.empty() ? "" : " " + gbnf_format_literal(form.scope_end))
|
||||
);
|
||||
});
|
||||
|
||||
// grammar trigger for tool call
|
||||
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, form.scope_start + form.tool_start });
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
||||
* Throws xml_toolcall_syntax_exception if there is invalid syntax and cannot recover the original status for common_chat_msg_parser.
|
||||
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||
*/
|
||||
inline bool parse_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form) {
|
||||
GGML_ASSERT(!form.tool_start.empty());
|
||||
GGML_ASSERT(!form.key_start.empty());
|
||||
GGML_ASSERT(!form.key_val_sep.empty());
|
||||
GGML_ASSERT(!form.val_end.empty());
|
||||
GGML_ASSERT(!form.tool_end.empty());
|
||||
|
||||
// Helper to choose return false or throw error
|
||||
constexpr auto return_error = [](common_chat_msg_parser & builder, auto &start_pos, const bool &recovery) {
|
||||
LOG_DBG("Failed to parse XML-Style tool call at position: %s\n", gbnf_format_literal(builder.consume_rest().substr(0, 20)).c_str());
|
||||
if (recovery) {
|
||||
builder.move_to(start_pos);
|
||||
return false;
|
||||
} else throw xml_toolcall_syntax_exception("Tool call parsing failed with unrecoverable errors. Try using a grammar to constrain the model’s output.");
|
||||
};
|
||||
// Drop substring from needle to end from a JSON
|
||||
constexpr auto partial_json = [](std::string &json_str, std::string_view needle = "XML_TOOL_CALL_PARTIAL_FLAG") {
|
||||
auto pos = json_str.rfind(needle);
|
||||
if (pos == std::string::npos) {
|
||||
return false;
|
||||
}
|
||||
for (auto i = pos + needle.size(); i < json_str.size(); ++i) {
|
||||
unsigned char ch = static_cast<unsigned char>(json_str[i]);
|
||||
if (ch != '\'' && ch != '"' && ch != '}' && ch != ':' && !std::isspace(ch)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (pos != 0 && json_str[pos - 1] == '"') {
|
||||
--pos;
|
||||
}
|
||||
json_str.resize(pos);
|
||||
return true;
|
||||
};
|
||||
// Helper to generate a partial argument JSON
|
||||
constexpr auto gen_partial_json = [partial_json](auto set_partial_arg, auto &arguments, auto &builder, auto &function_name) {
|
||||
auto rest = builder.consume_rest();
|
||||
utf8_truncate_safe_resize(rest);
|
||||
set_partial_arg(rest, "XML_TOOL_CALL_PARTIAL_FLAG");
|
||||
auto tool_str = arguments.dump();
|
||||
if (partial_json(tool_str)) {
|
||||
if (builder.add_tool_call(function_name, "", tool_str)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
LOG_DBG("Failed to parse partial XML-Style tool call, fallback to non-partial: %s\n", tool_str.c_str());
|
||||
};
|
||||
// Helper to find a close (because there may be form.last_val_end or form.last_tool_end)
|
||||
constexpr auto try_find_close = [](
|
||||
common_chat_msg_parser & builder,
|
||||
const std::string & end,
|
||||
const std::optional<std::string> & alt_end,
|
||||
const std::string & end_next,
|
||||
const std::optional<std::string> & alt_end_next
|
||||
) {
|
||||
auto saved_pos = builder.pos();
|
||||
auto tc = builder.try_find_literal(end);
|
||||
auto val_end_size = end.size();
|
||||
if (alt_end) {
|
||||
auto pos_1 = builder.pos();
|
||||
builder.move_to(saved_pos);
|
||||
auto tc2 = try_find_2_literal_splited_by_spaces(builder, *alt_end, end_next);
|
||||
if (alt_end_next) {
|
||||
builder.move_to(saved_pos);
|
||||
auto tc3 = try_find_2_literal_splited_by_spaces(builder, *alt_end, *alt_end_next);
|
||||
if (tc3 && (!tc2 || tc2->prelude.size() > tc3->prelude.size())) {
|
||||
tc2 = tc3;
|
||||
}
|
||||
}
|
||||
if (tc2 && (!tc || tc->prelude.size() > tc2->prelude.size())) {
|
||||
tc = tc2;
|
||||
tc->groups[0].end = std::min(builder.input().size(), tc->groups[0].begin + alt_end->size());
|
||||
builder.move_to(tc->groups[0].end);
|
||||
val_end_size = alt_end->size();
|
||||
} else {
|
||||
builder.move_to(pos_1);
|
||||
}
|
||||
}
|
||||
return std::make_pair(val_end_size, tc);
|
||||
};
|
||||
// Helper to find a val_end or last_val_end, returns matched pattern size
|
||||
const auto try_find_val_end = [try_find_close, &builder, &form]() {
|
||||
return try_find_close(builder, form.val_end, form.last_val_end, form.tool_end, form.last_tool_end);
|
||||
};
|
||||
// Helper to find a tool_end or last_tool_end, returns matched pattern size
|
||||
const auto try_find_tool_end = [try_find_close, &builder, &form]() {
|
||||
return try_find_close(builder, form.tool_end, form.last_tool_end, form.scope_end, std::nullopt);
|
||||
};
|
||||
|
||||
bool recovery = true;
|
||||
const auto start_pos = builder.pos();
|
||||
if (!all_space(form.scope_start)) {
|
||||
if (auto tc = builder.try_find_literal(form.scope_start)) {
|
||||
if (all_space(tc->prelude)) {
|
||||
if (form.scope_start.size() != tc->groups[0].end - tc->groups[0].begin)
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.scope_start));
|
||||
} else {
|
||||
builder.move_to(start_pos);
|
||||
return false;
|
||||
}
|
||||
} else return false;
|
||||
}
|
||||
while (auto tc = builder.try_find_literal(form.tool_start)) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
|
||||
gbnf_format_literal(form.tool_start).c_str(),
|
||||
gbnf_format_literal(tc->prelude).c_str()
|
||||
);
|
||||
builder.move_to(tc->groups[0].begin - tc->prelude.size());
|
||||
break;
|
||||
}
|
||||
|
||||
// Find tool name
|
||||
auto func_name = builder.try_find_literal(all_space(form.tool_sep) ? form.key_start : form.tool_sep);
|
||||
if (!func_name) {
|
||||
auto [sz, tc] = try_find_tool_end();
|
||||
func_name = tc;
|
||||
}
|
||||
if (!func_name) {
|
||||
// Partial tool name not supported
|
||||
throw common_chat_msg_partial_exception("incomplete tool_call");
|
||||
}
|
||||
// If the model generate multiple tool call and the first tool call has no argument
|
||||
if (func_name->prelude.find(form.tool_end) != std::string::npos || (form.last_tool_end ? func_name->prelude.find(*form.last_tool_end) != std::string::npos : false)) {
|
||||
builder.move_to(func_name->groups[0].begin - func_name->prelude.size());
|
||||
auto [sz, tc] = try_find_tool_end();
|
||||
func_name = tc;
|
||||
}
|
||||
|
||||
// Parse tool name
|
||||
builder.move_to(all_space(form.tool_sep) ? func_name->groups[0].begin : func_name->groups[0].end);
|
||||
std::string function_name = string_strip(func_name->prelude);
|
||||
// Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
|
||||
if (builder.syntax().format == COMMON_CHAT_FORMAT_KIMI_K2) {
|
||||
if (string_starts_with(function_name, "functions.")) {
|
||||
static const std::regex re(":\\d+$");
|
||||
if (std::regex_search(function_name, re)) {
|
||||
function_name = function_name.substr(10, function_name.rfind(":") - 10);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Argument JSON
|
||||
json arguments = json::object();
|
||||
|
||||
// Helper to generate a partial argument JSON
|
||||
const auto gen_partial_args = [&](auto set_partial_arg) {
|
||||
gen_partial_json(set_partial_arg, arguments, builder, function_name);
|
||||
};
|
||||
|
||||
// Parse all arg_key/arg_value pairs
|
||||
while (auto tc = builder.try_find_literal(form.key_start)) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
|
||||
gbnf_format_literal(form.key_start).c_str(),
|
||||
gbnf_format_literal(tc->prelude).c_str()
|
||||
);
|
||||
builder.move_to(tc->groups[0].begin - tc->prelude.size());
|
||||
break;
|
||||
}
|
||||
if (tc->groups[0].end - tc->groups[0].begin != form.key_start.size()) {
|
||||
auto tool_call_arg = arguments.dump();
|
||||
if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
|
||||
tool_call_arg.resize(tool_call_arg.size() - 1);
|
||||
}
|
||||
builder.add_tool_call(function_name, "", tool_call_arg);
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_start));
|
||||
}
|
||||
|
||||
// Parse arg_key
|
||||
auto key_res = builder.try_find_literal(form.key_val_sep);
|
||||
if (!key_res) {
|
||||
gen_partial_args([&](auto &rest, auto &needle) {arguments[rest + needle] = "";});
|
||||
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.key_val_sep) + " after " + gbnf_format_literal(form.key_start));
|
||||
}
|
||||
if (key_res->groups[0].end - key_res->groups[0].begin != form.key_val_sep.size()) {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key_res->prelude + needle] = "";});
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_val_sep));
|
||||
}
|
||||
auto &key = key_res->prelude;
|
||||
recovery = false;
|
||||
|
||||
// Parse arg_value
|
||||
if (form.key_val_sep2) {
|
||||
if (auto tc = builder.try_find_literal(*form.key_val_sep2)) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("Failed to parse XML-Style tool call: Unexcepted %s between %s and %s\n",
|
||||
gbnf_format_literal(tc->prelude).c_str(),
|
||||
gbnf_format_literal(form.key_val_sep).c_str(),
|
||||
gbnf_format_literal(*form.key_val_sep2).c_str()
|
||||
);
|
||||
return return_error(builder, start_pos, false);
|
||||
}
|
||||
if (tc->groups[0].end - tc->groups[0].begin != form.key_val_sep2->size()) {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(*form.key_val_sep2));
|
||||
}
|
||||
} else {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(*form.key_val_sep2) + " after " + gbnf_format_literal(form.key_val_sep));
|
||||
}
|
||||
}
|
||||
auto val_start = builder.pos();
|
||||
|
||||
// Test if arg_val is a partial JSON
|
||||
std::optional<common_json> value_json = std::nullopt;
|
||||
if (!form.raw_argval || !*form.raw_argval) {
|
||||
try { value_json = builder.try_consume_json(); }
|
||||
catch (const std::runtime_error&) { builder.move_to(val_start); }
|
||||
// TODO: Delete this when json_partial adds top-level support for null/true/false
|
||||
if (builder.pos() == val_start) {
|
||||
const static std::regex number_regex(R"([0-9-][0-9]*(\.\d*)?([eE][+-]?\d*)?)");
|
||||
builder.consume_spaces();
|
||||
std::string_view sv = utf8_truncate_safe_view(builder.input());
|
||||
sv.remove_prefix(builder.pos());
|
||||
std::string rest = "a";
|
||||
if (sv.size() < 6) rest = sv;
|
||||
if (string_starts_with("null", rest) || string_starts_with("true", rest) || string_starts_with("false", rest) || std::regex_match(sv.begin(), sv.end(), number_regex)) {
|
||||
value_json = {123, {"123", "123"}};
|
||||
builder.consume_rest();
|
||||
} else {
|
||||
builder.move_to(val_start);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If it is a JSON and followed by </arg_value>, parse as json
|
||||
// cannot support streaming because it may be a plain text starting with JSON
|
||||
if (value_json) {
|
||||
auto json_end = builder.pos();
|
||||
builder.consume_spaces();
|
||||
if (builder.pos() == builder.input().size()) {
|
||||
if (form.raw_argval && !*form.raw_argval && (value_json->json.is_string() || value_json->json.is_object() || value_json->json.is_array())) {
|
||||
arguments[key] = value_json->json;
|
||||
auto json_str = arguments.dump();
|
||||
if (!value_json->healing_marker.json_dump_marker.empty()) {
|
||||
GGML_ASSERT(std::string::npos != json_str.rfind(value_json->healing_marker.json_dump_marker));
|
||||
json_str.resize(json_str.rfind(value_json->healing_marker.json_dump_marker));
|
||||
} else {
|
||||
GGML_ASSERT(json_str.back() == '}');
|
||||
json_str.resize(json_str.size() - 1);
|
||||
}
|
||||
builder.add_tool_call(function_name, "", json_str);
|
||||
} else {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||
}
|
||||
LOG_DBG("Possible JSON arg_value: %s\n", value_json->json.dump().c_str());
|
||||
throw common_chat_msg_partial_exception("JSON arg_value detected. Waiting for more tokens for validations.");
|
||||
}
|
||||
builder.move_to(json_end);
|
||||
auto [val_end_size, tc] = try_find_val_end();
|
||||
if (tc && all_space(tc->prelude) && value_json->healing_marker.marker.empty()) {
|
||||
if (tc->groups[0].end - tc->groups[0].begin != val_end_size) {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||
LOG_DBG("Possible terminated JSON arg_value: %s\n", value_json->json.dump().c_str());
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.val_end) + (form.last_val_end ? gbnf_format_literal(*form.last_val_end) : ""));
|
||||
} else arguments[key] = value_json->json;
|
||||
} else builder.move_to(val_start);
|
||||
}
|
||||
|
||||
// If not, parse as plain text
|
||||
if (val_start == builder.pos()) {
|
||||
if (auto [val_end_size, value_plain] = try_find_val_end(); value_plain) {
|
||||
auto &value_str = value_plain->prelude;
|
||||
if (form.trim_raw_argval) value_str = string_strip(value_str);
|
||||
if (value_plain->groups[0].end - value_plain->groups[0].begin != val_end_size) {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = value_str + needle;});
|
||||
throw common_chat_msg_partial_exception(
|
||||
"Expected " + gbnf_format_literal(form.val_end) +
|
||||
" after " + gbnf_format_literal(form.key_val_sep) +
|
||||
(form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
|
||||
);
|
||||
}
|
||||
arguments[key] = value_str;
|
||||
} else {
|
||||
if (form.trim_raw_argval) {
|
||||
gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = string_strip(rest) + needle;});
|
||||
} else {
|
||||
gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = rest + needle;});
|
||||
}
|
||||
throw common_chat_msg_partial_exception(
|
||||
"Expected " + gbnf_format_literal(form.val_end) +
|
||||
" after " + gbnf_format_literal(form.key_val_sep) +
|
||||
(form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Consume closing tag
|
||||
if (auto [tool_end_size, tc] = try_find_tool_end(); tc) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||
gbnf_format_literal(form.tool_end).c_str(),
|
||||
gbnf_format_literal(tc->prelude).c_str()
|
||||
);
|
||||
return return_error(builder, start_pos, recovery);
|
||||
}
|
||||
if (tc->groups[0].end - tc->groups[0].begin == tool_end_size) {
|
||||
// Add the parsed tool call
|
||||
if (!builder.add_tool_call(function_name, "", arguments.dump())) {
|
||||
throw common_chat_msg_partial_exception("Failed to add XML-Style tool call");
|
||||
}
|
||||
recovery = false;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
auto tool_call_arg = arguments.dump();
|
||||
if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
|
||||
tool_call_arg.resize(tool_call_arg.size() - 1);
|
||||
}
|
||||
builder.add_tool_call(function_name, "", tool_call_arg);
|
||||
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.tool_end) + " after " + gbnf_format_literal(form.val_end));
|
||||
}
|
||||
if (auto tc = builder.try_find_literal(form.scope_end)) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||
gbnf_format_literal(form.scope_end).c_str(),
|
||||
gbnf_format_literal(tc->prelude).c_str()
|
||||
);
|
||||
return return_error(builder, start_pos, recovery);
|
||||
}
|
||||
} else {
|
||||
if (all_space(form.scope_end)) return true;
|
||||
builder.consume_spaces();
|
||||
if (builder.pos() == builder.input().size())
|
||||
throw common_chat_msg_partial_exception("incomplete tool calls");
|
||||
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||
gbnf_format_literal(form.scope_end).c_str(),
|
||||
gbnf_format_literal(builder.consume_rest()).c_str()
|
||||
);
|
||||
return return_error(builder, start_pos, recovery);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
||||
* May cause std::runtime_error if there is invalid syntax because partial valid tool call is already sent out to client.
|
||||
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||
*/
|
||||
bool common_chat_msg_parser::try_consume_xml_tool_calls(const struct xml_tool_call_format & form) {
|
||||
auto pos = pos_;
|
||||
auto tsize = result_.tool_calls.size();
|
||||
try { return parse_xml_tool_calls(*this, form); }
|
||||
catch (const xml_toolcall_syntax_exception&) {}
|
||||
move_to(pos);
|
||||
result_.tool_calls.resize(tsize);
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse content uses reasoning and XML-Style tool call
|
||||
* TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
|
||||
*/
|
||||
inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>") {
|
||||
constexpr auto rstrip = [](std::string &s) {
|
||||
s.resize(std::distance(s.begin(), std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base()));
|
||||
};
|
||||
// Erase substring from l to r, along with additional spaces nearby
|
||||
constexpr auto erase_spaces = [](auto &str, size_t l, size_t r) {
|
||||
while (/* l > -1 && */ --l < str.size() && std::isspace(static_cast<unsigned char>(str[l])));
|
||||
++l;
|
||||
while (++r < str.size() && std::isspace(static_cast<unsigned char>(str[r])));
|
||||
if (l < r) str[l] = '\n';
|
||||
if (l + 1 < r) str[l + 1] = '\n';
|
||||
if (l != 0) l += 2;
|
||||
str.erase(l, r - l);
|
||||
return l;
|
||||
};
|
||||
constexpr auto trim_suffix = [](std::string &content, std::initializer_list<std::string_view> list) {
|
||||
auto best_match = content.size();
|
||||
for (auto pattern: list) {
|
||||
if (pattern.size() == 0) continue;
|
||||
for (auto match_idx = content.size() - std::min(pattern.size(), content.size()); content.size() > match_idx; match_idx++) {
|
||||
auto match_len = content.size() - match_idx;
|
||||
if (content.compare(match_idx, match_len, pattern.data(), match_len) == 0 && best_match > match_idx) {
|
||||
best_match = match_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (content.size() > best_match) {
|
||||
content.erase(best_match);
|
||||
}
|
||||
};
|
||||
const auto trim_potential_partial_word = [&start_think, &end_think, &form, trim_suffix](std::string &content) {
|
||||
return trim_suffix(content, {
|
||||
start_think, end_think, form.scope_start, form.tool_start, form.tool_sep, form.key_start,
|
||||
form.key_val_sep, form.key_val_sep2 ? form.key_val_sep2->c_str() : "",
|
||||
form.val_end, form.last_val_end ? form.last_val_end->c_str() : "",
|
||||
form.tool_end, form.last_tool_end ? form.last_tool_end->c_str() : "",
|
||||
form.scope_end
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
// Trim leading spaces without affecting keyword matching
|
||||
static const common_regex spaces_regex("\\s*");
|
||||
{
|
||||
auto tc = builder.consume_regex(spaces_regex);
|
||||
auto spaces = builder.str(tc.groups[0]);
|
||||
auto s1 = spaces.size();
|
||||
trim_potential_partial_word(spaces);
|
||||
auto s2 = spaces.size();
|
||||
builder.move_to(builder.pos() - (s1 - s2));
|
||||
}
|
||||
|
||||
// Parse content
|
||||
bool reasoning_unclosed = builder.syntax().thinking_forced_open;
|
||||
std::string unclosed_reasoning_content("");
|
||||
for (;;) {
|
||||
auto tc = try_find_2_literal_splited_by_spaces(builder, form.scope_start, form.tool_start);
|
||||
std::string content;
|
||||
std::string tool_call_start;
|
||||
|
||||
if (tc) {
|
||||
content = std::move(tc->prelude);
|
||||
tool_call_start = builder.str(tc->groups[0]);
|
||||
LOG_DBG("Matched tool start: %s\n", gbnf_format_literal(tool_call_start).c_str());
|
||||
} else {
|
||||
content = builder.consume_rest();
|
||||
utf8_truncate_safe_resize(content);
|
||||
}
|
||||
|
||||
// Handle unclosed think block
|
||||
if (reasoning_unclosed) {
|
||||
if (auto pos = content.find(end_think); pos == std::string::npos && builder.pos() != builder.input().size()) {
|
||||
unclosed_reasoning_content += content;
|
||||
if (form.allow_toolcall_in_think) {
|
||||
builder.move_to(tc->groups[0].begin);
|
||||
if (!builder.try_consume_xml_tool_calls(form)) {
|
||||
unclosed_reasoning_content += tool_call_start;
|
||||
builder.move_to(tc->groups[0].end);
|
||||
}
|
||||
} else {
|
||||
unclosed_reasoning_content += tool_call_start;
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
reasoning_unclosed = false;
|
||||
std::string reasoning_content;
|
||||
if (pos == std::string::npos) {
|
||||
reasoning_content = std::move(content);
|
||||
} else {
|
||||
reasoning_content = content.substr(0, pos);
|
||||
content.erase(0, pos + end_think.size());
|
||||
}
|
||||
if (builder.pos() == builder.input().size() && all_space(content)) {
|
||||
rstrip(reasoning_content);
|
||||
trim_potential_partial_word(reasoning_content);
|
||||
rstrip(reasoning_content);
|
||||
if (reasoning_content.empty()) {
|
||||
rstrip(unclosed_reasoning_content);
|
||||
trim_potential_partial_word(unclosed_reasoning_content);
|
||||
rstrip(unclosed_reasoning_content);
|
||||
if (unclosed_reasoning_content.empty()) continue;
|
||||
}
|
||||
}
|
||||
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
|
||||
builder.add_content(start_think);
|
||||
builder.add_content(unclosed_reasoning_content);
|
||||
builder.add_content(reasoning_content);
|
||||
if (builder.pos() != builder.input().size() || !all_space(content))
|
||||
builder.add_content(end_think);
|
||||
} else {
|
||||
builder.add_reasoning_content(unclosed_reasoning_content);
|
||||
builder.add_reasoning_content(reasoning_content);
|
||||
}
|
||||
unclosed_reasoning_content.clear();
|
||||
}
|
||||
}
|
||||
|
||||
// Handle multiple think block
|
||||
bool toolcall_in_think = false;
|
||||
for (auto think_start = content.find(start_think); think_start != std::string::npos; think_start = content.find(start_think, think_start)) {
|
||||
if (auto think_end = content.find(end_think, think_start + start_think.size()); think_end != std::string::npos) {
|
||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
||||
auto reasoning_content = content.substr(think_start + start_think.size(), think_end - think_start - start_think.size());
|
||||
builder.add_reasoning_content(reasoning_content);
|
||||
think_start = erase_spaces(content, think_start, think_end + end_think.size() - 1);
|
||||
} else {
|
||||
think_start = think_end + end_think.size() - 1;
|
||||
}
|
||||
} else {
|
||||
// This <tool_call> start is in thinking block, skip this tool call
|
||||
auto pos = think_start + start_think.size();
|
||||
unclosed_reasoning_content = content.substr(pos) + tool_call_start;
|
||||
reasoning_unclosed = true;
|
||||
content.resize(think_start);
|
||||
toolcall_in_think = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
||||
rstrip(content);
|
||||
// Handle unclosed </think> token from content: delete all </think> token
|
||||
if (auto pos = content.rfind(end_think); pos != std::string::npos) {
|
||||
while (pos != std::string::npos) {
|
||||
pos = erase_spaces(content, pos, pos + end_think.size() - 1);
|
||||
pos = content.rfind(end_think, pos);
|
||||
}
|
||||
}
|
||||
// Strip if needed
|
||||
if (content.size() > 0 && std::isspace(static_cast<unsigned char>(content[0]))) {
|
||||
content = string_strip(content);
|
||||
}
|
||||
}
|
||||
|
||||
// remove potential partial suffix
|
||||
if (content.size() > 0 && builder.pos() == builder.input().size() && unclosed_reasoning_content.empty()) {
|
||||
rstrip(content);
|
||||
trim_potential_partial_word(content);
|
||||
rstrip(content);
|
||||
}
|
||||
|
||||
// Add content
|
||||
if (content.size() != 0) {
|
||||
// If there are multiple content blocks
|
||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content && builder.result().content.size() != 0) {
|
||||
builder.add_content("\n\n");
|
||||
}
|
||||
builder.add_content(content);
|
||||
}
|
||||
|
||||
// This <tool_call> start is in thinking block, skip this tool call
|
||||
if (toolcall_in_think && !form.allow_toolcall_in_think) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// There is no tool call and all content is parsed
|
||||
if (!tc) {
|
||||
GGML_ASSERT(builder.pos() == builder.input().size());
|
||||
GGML_ASSERT(unclosed_reasoning_content.empty());
|
||||
GGML_ASSERT(!reasoning_unclosed);
|
||||
break;
|
||||
}
|
||||
|
||||
builder.move_to(tc->groups[0].begin);
|
||||
if (builder.try_consume_xml_tool_calls(form)) {
|
||||
auto end_of_tool = builder.pos();
|
||||
builder.consume_spaces();
|
||||
if (builder.pos() != builder.input().size()) {
|
||||
builder.move_to(end_of_tool);
|
||||
if (!builder.result().content.empty()) {
|
||||
builder.add_content("\n\n");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static const common_regex next_char_regex(".");
|
||||
auto c = builder.str(builder.consume_regex(next_char_regex).groups[0]);
|
||||
rstrip(c);
|
||||
builder.add_content(c);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse content uses reasoning and XML-Style tool call
|
||||
* TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
|
||||
*/
|
||||
void common_chat_msg_parser::consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think, const std::string & end_think) {
|
||||
parse_msg_with_xml_tool_calls(*this, form, start_think, end_think);
|
||||
}
|
||||
45
common/chat-parser-xml-toolcall.h
Normal file
45
common/chat-parser-xml-toolcall.h
Normal file
@@ -0,0 +1,45 @@
|
||||
#pragma once
|
||||
|
||||
#include "chat.h"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
||||
// Sample config:
|
||||
// MiniMax-M2 (left): <minimax:tool_call>\n<invoke name="tool-name">\n<parameter name="key">value</parameter>\n...</invoke>\n...</minimax:tool_call>
|
||||
// GLM 4.5 (right): <tool_call>function_name\n<arg_key>key</arg_key>\n<arg_value>value</arg_value>\n</tool_call>
|
||||
struct xml_tool_call_format {
|
||||
std::string scope_start; // <minimax:tool_call>\n // \n // can be empty
|
||||
std::string tool_start; // <invoke name=\" // <tool_call>
|
||||
std::string tool_sep; // \">\n // \n // can be empty only for parse_xml_tool_calls
|
||||
std::string key_start; // <parameter name=\" // <arg_key>
|
||||
std::string key_val_sep; // \"> // </arg_key>\n<arg_value>
|
||||
std::string val_end; // </parameter>\n // </arg_value>\n
|
||||
std::string tool_end; // </invoke>\n // </tool_call>\n
|
||||
std::string scope_end; // </minimax:tool_call> // // can be empty
|
||||
// Set this if there can be dynamic spaces inside key_val_sep.
|
||||
// e.g. key_val_sep=</arg_key> key_val_sep2=<arg_value> for GLM4.5
|
||||
std::optional<std::string> key_val_sep2 = std::nullopt;
|
||||
// Set true if argval should only be raw string. e.g. Hello "world" hi
|
||||
// Set false if argval should only be json string. e.g. "Hello \"world\" hi"
|
||||
// Defaults to std::nullopt, both will be allowed.
|
||||
std::optional<bool> raw_argval = std::nullopt;
|
||||
std::optional<std::string> last_val_end = std::nullopt;
|
||||
std::optional<std::string> last_tool_end = std::nullopt;
|
||||
bool trim_raw_argval = false;
|
||||
bool allow_toolcall_in_think = false; // TODO: UNTESTED!!!
|
||||
};
|
||||
|
||||
// make a GBNF that accept any strings except those containing any of the forbidden strings.
|
||||
std::string make_gbnf_excluding(std::vector<std::string> forbids);
|
||||
|
||||
/**
|
||||
* Build grammar for xml-style tool call
|
||||
* form.scope_start and form.scope_end can be empty.
|
||||
* Requires data.format for model-specific hacks.
|
||||
*/
|
||||
void build_grammar_xml_tool_call(common_chat_params & data, const nlohmann::ordered_json & tools, const struct xml_tool_call_format & form);
|
||||
@@ -13,6 +13,120 @@
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder,
|
||||
const common_regex & prefix,
|
||||
size_t rstrip_prefix = 0) {
|
||||
static const std::vector<std::vector<std::string>> args_paths = { { "arguments" } };
|
||||
if (auto res = builder.try_find_regex(prefix)) {
|
||||
builder.move_back(rstrip_prefix);
|
||||
auto tool_calls = builder.consume_json_with_dumped_args(args_paths);
|
||||
if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call array");
|
||||
}
|
||||
} else {
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
}
|
||||
|
||||
static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) {
|
||||
std::string arguments;
|
||||
if (builder.is_partial()) {
|
||||
arguments = (json{
|
||||
{ "code", code + builder.healing_marker() }
|
||||
})
|
||||
.dump();
|
||||
auto idx = arguments.find(builder.healing_marker());
|
||||
if (idx != std::string::npos) {
|
||||
arguments.resize(idx);
|
||||
}
|
||||
} else {
|
||||
arguments = (json{
|
||||
{ "code", code }
|
||||
})
|
||||
.dump();
|
||||
}
|
||||
return arguments;
|
||||
}
|
||||
|
||||
/**
|
||||
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
|
||||
* Aggregates the prefix, suffix and in-between text into the content.
|
||||
*/
|
||||
static void parse_json_tool_calls(
|
||||
common_chat_msg_parser & builder,
|
||||
const std::optional<common_regex> & block_open,
|
||||
const std::optional<common_regex> & function_regex_start_only,
|
||||
const std::optional<common_regex> & function_regex,
|
||||
const common_regex & close_regex,
|
||||
const std::optional<common_regex> & block_close,
|
||||
bool allow_raw_python = false,
|
||||
const std::function<std::string(const common_chat_msg_parser::find_regex_result & fres)> & get_function_name =
|
||||
nullptr) {
|
||||
auto parse_tool_calls = [&]() {
|
||||
size_t from = std::string::npos;
|
||||
auto first = true;
|
||||
while (true) {
|
||||
auto start_pos = builder.pos();
|
||||
auto res = function_regex_start_only && first ? builder.try_consume_regex(*function_regex_start_only) :
|
||||
function_regex ? builder.try_find_regex(*function_regex, from) :
|
||||
std::nullopt;
|
||||
|
||||
if (res) {
|
||||
std::string name;
|
||||
if (get_function_name) {
|
||||
name = get_function_name(*res);
|
||||
} else {
|
||||
GGML_ASSERT(res->groups.size() == 2);
|
||||
name = builder.str(res->groups[1]);
|
||||
}
|
||||
first = false;
|
||||
if (name.empty()) {
|
||||
// get_function_name signalled us that we should skip this match and treat it as content.
|
||||
from = res->groups[0].begin + 1;
|
||||
continue;
|
||||
}
|
||||
from = std::string::npos;
|
||||
|
||||
auto maybe_raw_python = name == "python" && allow_raw_python;
|
||||
if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) {
|
||||
if (auto arguments = builder.try_consume_json_with_dumped_args({ {} })) {
|
||||
if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
builder.consume_regex(close_regex);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (maybe_raw_python) {
|
||||
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
|
||||
if (!builder.add_tool_call(name, "", arguments)) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
return;
|
||||
}
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
} else {
|
||||
builder.move_to(start_pos);
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (block_close) {
|
||||
builder.consume_regex(*block_close);
|
||||
}
|
||||
builder.consume_spaces();
|
||||
builder.add_content(builder.consume_rest());
|
||||
};
|
||||
if (block_open) {
|
||||
if (auto res = builder.try_find_regex(*block_open)) {
|
||||
parse_tool_calls();
|
||||
} else {
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
} else {
|
||||
parse_tool_calls();
|
||||
}
|
||||
}
|
||||
|
||||
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
|
||||
: input_(input), is_partial_(is_partial), syntax_(syntax)
|
||||
{
|
||||
@@ -532,3 +646,857 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
|
||||
void common_chat_msg_parser::clear_tools() {
|
||||
result_.tool_calls.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* All common_chat_parse_* moved from chat.cpp to chat-parser.cpp below
|
||||
* to reduce incremental compile time for parser changes.
|
||||
*/
|
||||
static void common_chat_parse_generic(common_chat_msg_parser & builder) {
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
static const std::vector<std::vector<std::string>> content_paths = {
|
||||
{"response"},
|
||||
};
|
||||
static const std::vector<std::vector<std::string>> args_paths = {
|
||||
{"tool_call", "arguments"},
|
||||
{"tool_calls", "arguments"},
|
||||
};
|
||||
auto data = builder.consume_json_with_dumped_args(args_paths, content_paths);
|
||||
if (data.value.contains("tool_calls")) {
|
||||
if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool calls");
|
||||
}
|
||||
} else if (data.value.contains("tool_call")) {
|
||||
if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
} else if (data.value.contains("response")) {
|
||||
const auto & response = data.value.at("response");
|
||||
builder.add_content(response.is_string() ? response.template get<std::string>() : response.dump(2));
|
||||
if (data.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete response");
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON");
|
||||
}
|
||||
}
|
||||
|
||||
static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
|
||||
parse_prefixed_json_tool_call_array(builder, prefix);
|
||||
}
|
||||
|
||||
static void common_chat_parse_magistral(common_chat_msg_parser & builder) {
|
||||
builder.try_parse_reasoning("[THINK]", "[/THINK]");
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
|
||||
parse_prefixed_json_tool_call_array(builder, prefix);
|
||||
}
|
||||
|
||||
static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
|
||||
builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>");
|
||||
|
||||
static const common_regex start_action_regex("<\\|START_ACTION\\|>");
|
||||
static const common_regex end_action_regex("<\\|END_ACTION\\|>");
|
||||
static const common_regex start_response_regex("<\\|START_RESPONSE\\|>");
|
||||
static const common_regex end_response_regex("<\\|END_RESPONSE\\|>");
|
||||
|
||||
if (auto res = builder.try_find_regex(start_action_regex)) {
|
||||
// If we didn't extract thoughts, prelude includes them.
|
||||
auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}});
|
||||
for (const auto & tool_call : tool_calls.value) {
|
||||
std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
|
||||
std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : "";
|
||||
std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : "";
|
||||
if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
}
|
||||
if (tool_calls.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
builder.consume_regex(end_action_regex);
|
||||
} else if (auto res = builder.try_find_regex(start_response_regex)) {
|
||||
if (!builder.try_find_regex(end_response_regex)) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
throw common_chat_msg_partial_exception(end_response_regex.str());
|
||||
}
|
||||
} else {
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
}
|
||||
|
||||
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
static const common_regex function_regex(
|
||||
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
|
||||
static const common_regex close_regex("\\}\\s*");
|
||||
|
||||
static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\(");
|
||||
static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*");
|
||||
|
||||
if (with_builtin_tools) {
|
||||
static const common_regex builtin_call_regex("<\\|python_tag\\|>");
|
||||
if (auto res = builder.try_find_regex(builtin_call_regex)) {
|
||||
auto fun_res = builder.consume_regex(function_name_regex);
|
||||
auto function_name = builder.str(fun_res.groups[1]);
|
||||
|
||||
common_healing_marker healing_marker;
|
||||
json args = json::object();
|
||||
while (true) {
|
||||
if (auto arg_res = builder.try_consume_regex(arg_name_regex)) {
|
||||
auto arg_name = builder.str(arg_res->groups[1]);
|
||||
auto partial = builder.consume_json();
|
||||
args[arg_name] = partial.json;
|
||||
healing_marker.marker = partial.healing_marker.marker;
|
||||
healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker;
|
||||
builder.consume_spaces();
|
||||
if (!builder.try_consume_literal(",")) {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
builder.consume_literal(")");
|
||||
builder.consume_spaces();
|
||||
|
||||
auto arguments = args.dump();
|
||||
if (!builder.add_tool_call(function_name, "", arguments)) {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
parse_json_tool_calls(
|
||||
builder,
|
||||
/* block_open= */ std::nullopt,
|
||||
/* function_regex_start_only= */ function_regex,
|
||||
/* function_regex= */ std::nullopt,
|
||||
close_regex,
|
||||
std::nullopt);
|
||||
|
||||
}
|
||||
|
||||
static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)");
|
||||
static const common_regex tool_calls_end("<|tool▁calls▁end|>");
|
||||
static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n");
|
||||
static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
|
||||
|
||||
parse_json_tool_calls(
|
||||
builder,
|
||||
/* block_open= */ tool_calls_begin,
|
||||
/* function_regex_start_only= */ std::nullopt,
|
||||
function_regex,
|
||||
close_regex,
|
||||
tool_calls_end);
|
||||
}
|
||||
|
||||
static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) {
|
||||
static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)");
|
||||
|
||||
static const common_regex close_regex("(?:[\\s]*)?<|tool▁call▁end|>");
|
||||
static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)");
|
||||
static const common_regex tool_calls_end("<|tool▁calls▁end|>");
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
LOG_DBG("%s: not parse_tool_calls\n", __func__);
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
LOG_DBG("%s: parse_tool_calls\n", __func__);
|
||||
|
||||
parse_json_tool_calls(
|
||||
builder,
|
||||
/* block_open= */ tool_calls_begin,
|
||||
/* function_regex_start_only= */ std::nullopt,
|
||||
function_regex,
|
||||
close_regex,
|
||||
tool_calls_end);
|
||||
}
|
||||
|
||||
static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
|
||||
// DeepSeek V3.1 outputs reasoning content between "<think>" and "</think>" tags, followed by regular content
|
||||
// First try to parse using the standard reasoning parsing method
|
||||
LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str());
|
||||
|
||||
auto start_pos = builder.pos();
|
||||
auto found_end_think = builder.try_find_literal("</think>");
|
||||
builder.move_to(start_pos);
|
||||
|
||||
if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) {
|
||||
LOG_DBG("%s: no end_think, not partial, adding content\n", __func__);
|
||||
common_chat_parse_deepseek_v3_1_content(builder);
|
||||
} else if (builder.try_parse_reasoning("<think>", "</think>")) {
|
||||
// If reasoning was parsed successfully, the remaining content is regular content
|
||||
LOG_DBG("%s: parsed reasoning, adding content\n", __func__);
|
||||
// </think><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|>
|
||||
common_chat_parse_deepseek_v3_1_content(builder);
|
||||
} else {
|
||||
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) {
|
||||
LOG_DBG("%s: reasoning_format none, adding content\n", __func__);
|
||||
common_chat_parse_deepseek_v3_1_content(builder);
|
||||
return;
|
||||
}
|
||||
// If no reasoning tags found, check if we should treat everything as reasoning
|
||||
if (builder.syntax().thinking_forced_open) {
|
||||
// If thinking is forced open but no tags found, treat everything as reasoning
|
||||
LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__);
|
||||
builder.add_reasoning_content(builder.consume_rest());
|
||||
} else {
|
||||
LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__);
|
||||
// <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|>
|
||||
common_chat_parse_deepseek_v3_1_content(builder);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "<minimax:tool_call>",
|
||||
/* form.tool_start = */ "<invoke name=\"",
|
||||
/* form.tool_sep = */ "\">",
|
||||
/* form.key_start = */ "<parameter name=\"",
|
||||
/* form.key_val_sep = */ "\">",
|
||||
/* form.val_end = */ "</parameter>",
|
||||
/* form.tool_end = */ "</invoke>",
|
||||
/* form.scope_end = */ "</minimax:tool_call>",
|
||||
};
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
}
|
||||
|
||||
static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<tool_call>";
|
||||
form.tool_start = "<function=";
|
||||
form.tool_sep = ">";
|
||||
form.key_start = "<parameter=";
|
||||
form.key_val_sep = ">";
|
||||
form.val_end = "</parameter>";
|
||||
form.tool_end = "</function>";
|
||||
form.scope_end = "</tool_call>";
|
||||
form.trim_raw_argval = true;
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form);
|
||||
}
|
||||
|
||||
static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<|tool_calls_section_begin|>";
|
||||
form.tool_start = "<|tool_call_begin|>";
|
||||
form.tool_sep = "<|tool_call_argument_begin|>{";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}<|tool_call_end|>";
|
||||
form.scope_end = "<|tool_calls_section_end|>";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
}
|
||||
|
||||
static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<tool_calls>[";
|
||||
form.tool_start = "{\"name\": \"";
|
||||
form.tool_sep = "\", \"arguments\": {";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}, ";
|
||||
form.scope_end = "]</tool_calls>";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
form.last_tool_end = "}";
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<thinking>", "</thinking>");
|
||||
}
|
||||
|
||||
static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "";
|
||||
form.tool_start = "<tool_call>\n{\"name\": \"";
|
||||
form.tool_sep = "\", \"arguments\": {";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}\n</tool_call>";
|
||||
form.scope_end = "";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form);
|
||||
}
|
||||
|
||||
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
|
||||
static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))";
|
||||
static const std::string recipient("(?: to=functions\\.([^<\\s]+))");
|
||||
|
||||
static const common_regex start_regex("<\\|start\\|>assistant");
|
||||
static const common_regex analysis_regex("<\\|channel\\|>analysis");
|
||||
static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?");
|
||||
static const common_regex preamble_regex("<\\|channel\\|>commentary");
|
||||
static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?");
|
||||
static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?");
|
||||
|
||||
auto consume_end = [&](bool include_end = false) {
|
||||
if (auto res = builder.try_find_literal("<|end|>")) {
|
||||
return res->prelude + (include_end ? builder.str(res->groups[0]) : "");
|
||||
}
|
||||
return builder.consume_rest();
|
||||
};
|
||||
|
||||
auto handle_tool_call = [&](const std::string & name) {
|
||||
if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
|
||||
if (builder.syntax().parse_tool_calls) {
|
||||
if (!builder.add_tool_call(name, "", args->value) || args->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
} else if (args->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional<common_regex_match> {
|
||||
auto match = regex.search(input, 0, true);
|
||||
if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) {
|
||||
return match;
|
||||
}
|
||||
return std::nullopt;
|
||||
};
|
||||
|
||||
do {
|
||||
auto header_start_pos = builder.pos();
|
||||
auto content_start = builder.try_find_literal("<|message|>");
|
||||
if (!content_start) {
|
||||
throw common_chat_msg_partial_exception("incomplete header");
|
||||
}
|
||||
|
||||
auto header = content_start->prelude;
|
||||
|
||||
if (auto match = regex_match(tool_call1_regex, header)) {
|
||||
auto group = match->groups[1];
|
||||
auto name = header.substr(group.begin, group.end - group.begin);
|
||||
handle_tool_call(name);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto match = regex_match(tool_call2_regex, header)) {
|
||||
auto group = match->groups[2];
|
||||
auto name = header.substr(group.begin, group.end - group.begin);
|
||||
handle_tool_call(name);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (regex_match(analysis_regex, header)) {
|
||||
builder.move_to(header_start_pos);
|
||||
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
|
||||
builder.add_content(consume_end(true));
|
||||
} else {
|
||||
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) {
|
||||
builder.add_content(consume_end());
|
||||
continue;
|
||||
}
|
||||
|
||||
// Possibly a malformed message, attempt to recover by rolling
|
||||
// back to pick up the next <|start|>
|
||||
LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str());
|
||||
builder.move_to(header_start_pos);
|
||||
} while (builder.try_find_regex(start_regex, std::string::npos, false));
|
||||
|
||||
auto remaining = builder.consume_rest();
|
||||
if (!remaining.empty()) {
|
||||
LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "",
|
||||
/* form.tool_start = */ "<tool_call>",
|
||||
/* form.tool_sep = */ "",
|
||||
/* form.key_start = */ "<arg_key>",
|
||||
/* form.key_val_sep = */ "</arg_key>",
|
||||
/* form.val_end = */ "</arg_value>",
|
||||
/* form.tool_end = */ "</tool_call>",
|
||||
/* form.scope_end = */ "",
|
||||
/* form.key_val_sep2 = */ "<arg_value>",
|
||||
};
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
}
|
||||
|
||||
static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) {
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
static const common_regex prefix(regex_escape(" functools["));
|
||||
parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1);
|
||||
}
|
||||
|
||||
static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) {
|
||||
static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))");
|
||||
static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))");
|
||||
static const common_regex close_regex(R"(\s*)");
|
||||
|
||||
parse_json_tool_calls(
|
||||
builder,
|
||||
std::nullopt,
|
||||
function_regex_start_only,
|
||||
function_regex,
|
||||
close_regex,
|
||||
std::nullopt,
|
||||
/* allow_raw_python= */ true,
|
||||
/* get_function_name= */ [&](const auto & res) -> std::string {
|
||||
auto at_start = res.groups[0].begin == 0;
|
||||
auto name = builder.str(res.groups[1]);
|
||||
if (!name.empty() && name.back() == '{') {
|
||||
// Unconsume the opening brace '{' to ensure the JSON parsing goes well.
|
||||
builder.move_back(1);
|
||||
}
|
||||
auto idx = name.find_last_not_of("\n{");
|
||||
name = name.substr(0, idx + 1);
|
||||
if (at_start && name == "all") {
|
||||
return "";
|
||||
}
|
||||
return name;
|
||||
});
|
||||
}
|
||||
|
||||
static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) {
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
||||
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));
|
||||
|
||||
static const common_regex function_regex(R"(<function=(\w+)>)");
|
||||
static const common_regex close_regex(R"(</function>)");
|
||||
|
||||
parse_json_tool_calls(
|
||||
builder,
|
||||
/* block_open= */ std::nullopt,
|
||||
/* function_regex_start_only= */ std::nullopt,
|
||||
function_regex,
|
||||
close_regex,
|
||||
std::nullopt);
|
||||
|
||||
if (auto res = builder.try_find_regex(python_tag_regex)) {
|
||||
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
|
||||
builder.add_tool_call("python", "", arguments);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
static const common_regex open_regex(
|
||||
"(?:"
|
||||
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
|
||||
"(" // match 2 (open_tag)
|
||||
"<tool_call>"
|
||||
"|<function_call>"
|
||||
"|<tool>"
|
||||
"|<tools>"
|
||||
"|<response>"
|
||||
"|<json>"
|
||||
"|<xml>"
|
||||
"|<JSON>"
|
||||
")?"
|
||||
"(\\s*\\{\\s*\"name\")" // match 3 (named tool call)
|
||||
")"
|
||||
"|<function=([^>]+)>" // match 4 (function name)
|
||||
"|<function name=\"([^\"]+)\">" // match 5 (function name again)
|
||||
);
|
||||
|
||||
while (auto res = builder.try_find_regex(open_regex)) {
|
||||
const auto & block_start = res->groups[1];
|
||||
std::string block_end = block_start.empty() ? "" : "```";
|
||||
|
||||
const auto & open_tag = res->groups[2];
|
||||
std::string close_tag;
|
||||
|
||||
if (!res->groups[3].empty()) {
|
||||
builder.move_to(res->groups[3].begin);
|
||||
close_tag = open_tag.empty() ? "" : "</" + builder.str(open_tag).substr(1);
|
||||
|
||||
if (auto tool_call = builder.try_consume_json_with_dumped_args({{"arguments"}})) {
|
||||
if (!builder.add_tool_call(tool_call->value) || tool_call->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
builder.consume_spaces();
|
||||
builder.consume_literal(close_tag);
|
||||
builder.consume_spaces();
|
||||
if (!block_end.empty()) {
|
||||
builder.consume_literal(block_end);
|
||||
builder.consume_spaces();
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("failed to parse tool call");
|
||||
}
|
||||
} else {
|
||||
auto function_name = builder.str(res->groups[4]);
|
||||
if (function_name.empty()) {
|
||||
function_name = builder.str(res->groups[5]);
|
||||
}
|
||||
GGML_ASSERT(!function_name.empty());
|
||||
|
||||
close_tag = "</function>";
|
||||
|
||||
if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
|
||||
if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
builder.consume_spaces();
|
||||
builder.consume_literal(close_tag);
|
||||
builder.consume_spaces();
|
||||
if (!block_end.empty()) {
|
||||
builder.consume_literal(block_end);
|
||||
builder.consume_spaces();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
|
||||
static void common_chat_parse_granite(common_chat_msg_parser & builder) {
|
||||
// Parse thinking tags
|
||||
static const common_regex start_think_regex(regex_escape("<think>"));
|
||||
static const common_regex end_think_regex(regex_escape("</think>"));
|
||||
// Granite models output partial tokens such as "<" and "<think".
|
||||
// By leveraging try_consume_regex()/try_find_regex() throwing
|
||||
// common_chat_msg_partial_exception for these partial tokens,
|
||||
// processing is interrupted and the tokens are not passed to add_content().
|
||||
if (auto res = builder.try_consume_regex(start_think_regex)) {
|
||||
// Restore position for try_parse_reasoning()
|
||||
builder.move_to(res->groups[0].begin);
|
||||
builder.try_find_regex(end_think_regex, std::string::npos, false);
|
||||
// Restore position for try_parse_reasoning()
|
||||
builder.move_to(res->groups[0].begin);
|
||||
}
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
|
||||
// Parse response tags
|
||||
static const common_regex start_response_regex(regex_escape("<response>"));
|
||||
static const common_regex end_response_regex(regex_escape("</response>"));
|
||||
// Granite models output partial tokens such as "<" and "<response".
|
||||
// Same hack as reasoning parsing.
|
||||
if (builder.try_consume_regex(start_response_regex)) {
|
||||
builder.try_find_regex(end_response_regex);
|
||||
}
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
// Look for tool calls
|
||||
static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
|
||||
if (auto res = builder.try_find_regex(tool_call_regex)) {
|
||||
builder.move_to(res->groups[0].end);
|
||||
|
||||
// Expect JSON array of tool calls
|
||||
if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) {
|
||||
if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
}
|
||||
|
||||
static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) {
|
||||
// Parse thinking tags
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
// Look for tool calls
|
||||
static const common_regex tool_call_regex(regex_escape("<TOOLCALL>"));
|
||||
if (auto res = builder.try_find_regex(tool_call_regex)) {
|
||||
builder.move_to(res->groups[0].end);
|
||||
|
||||
// Expect JSON array of tool calls
|
||||
auto tool_calls_data = builder.consume_json();
|
||||
if (tool_calls_data.json.is_array()) {
|
||||
if (!builder.try_consume_literal("</TOOLCALL>")) {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
builder.add_tool_calls(tool_calls_data.json);
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
}
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
|
||||
static void common_chat_parse_apertus(common_chat_msg_parser & builder) {
|
||||
// Parse thinking tags
|
||||
builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>");
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
// Look for tool calls
|
||||
static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>"));
|
||||
if (auto res = builder.try_find_regex(tool_call_regex)) {
|
||||
builder.move_to(res->groups[0].end);
|
||||
|
||||
auto tool_calls_data = builder.consume_json();
|
||||
if (tool_calls_data.json.is_array()) {
|
||||
builder.consume_spaces();
|
||||
if (!builder.try_consume_literal("<|tools_suffix|>")) {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
for (const auto & value : tool_calls_data.json) {
|
||||
if (value.is_object()) {
|
||||
builder.add_tool_call_short_form(value);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
}
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
|
||||
|
||||
static void common_chat_parse_lfm2(common_chat_msg_parser & builder) {
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
// LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|>
|
||||
static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>"));
|
||||
static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>"));
|
||||
|
||||
// Loop through all tool calls
|
||||
while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) {
|
||||
builder.move_to(res->groups[0].end);
|
||||
|
||||
// Parse JSON array format: [{"name": "...", "arguments": {...}}]
|
||||
auto tool_calls_data = builder.consume_json();
|
||||
|
||||
// Consume end marker
|
||||
builder.consume_spaces();
|
||||
if (!builder.try_consume_regex(tool_call_end_regex)) {
|
||||
throw common_chat_msg_partial_exception("Expected <|tool_call_end|>");
|
||||
}
|
||||
|
||||
// Process each tool call in the array
|
||||
if (tool_calls_data.json.is_array()) {
|
||||
for (const auto & tool_call : tool_calls_data.json) {
|
||||
if (!tool_call.is_object()) {
|
||||
throw common_chat_msg_partial_exception("Tool call must be an object");
|
||||
}
|
||||
|
||||
if (!tool_call.contains("name")) {
|
||||
throw common_chat_msg_partial_exception("Tool call missing 'name' field");
|
||||
}
|
||||
|
||||
std::string function_name = tool_call.at("name");
|
||||
std::string arguments = "{}";
|
||||
|
||||
if (tool_call.contains("arguments")) {
|
||||
if (tool_call.at("arguments").is_object()) {
|
||||
arguments = tool_call.at("arguments").dump();
|
||||
} else if (tool_call.at("arguments").is_string()) {
|
||||
arguments = tool_call.at("arguments");
|
||||
}
|
||||
}
|
||||
|
||||
if (!builder.add_tool_call(function_name, "", arguments)) {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Expected JSON array for tool calls");
|
||||
}
|
||||
|
||||
// Consume any trailing whitespace after this tool call
|
||||
builder.consume_spaces();
|
||||
}
|
||||
|
||||
// Consume any remaining content after all tool calls
|
||||
auto remaining = builder.consume_rest();
|
||||
if (!string_strip(remaining).empty()) {
|
||||
builder.add_content(remaining);
|
||||
}
|
||||
}
|
||||
|
||||
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "<seed:tool_call>",
|
||||
/* form.tool_start = */ "<function=",
|
||||
/* form.tool_sep = */ ">",
|
||||
/* form.key_start = */ "<parameter=",
|
||||
/* form.key_val_sep = */ ">",
|
||||
/* form.val_end = */ "</parameter>",
|
||||
/* form.tool_end = */ "</function>",
|
||||
/* form.scope_end = */ "</seed:tool_call>",
|
||||
};
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<seed:think>", "</seed:think>");
|
||||
}
|
||||
|
||||
static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
|
||||
static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||
LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str());
|
||||
|
||||
switch (builder.syntax().format) {
|
||||
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
|
||||
common_chat_parse_content_only(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_GENERIC:
|
||||
common_chat_parse_generic(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
|
||||
common_chat_parse_mistral_nemo(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_MAGISTRAL:
|
||||
common_chat_parse_magistral(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X:
|
||||
common_chat_parse_llama_3_1(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
|
||||
common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
|
||||
common_chat_parse_deepseek_r1(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1:
|
||||
common_chat_parse_deepseek_v3_1(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
|
||||
common_chat_parse_functionary_v3_2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
|
||||
common_chat_parse_functionary_v3_1_llama_3_1(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_HERMES_2_PRO:
|
||||
common_chat_parse_hermes_2_pro(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
|
||||
common_chat_parse_firefunction_v2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B:
|
||||
common_chat_parse_command_r7b(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_GRANITE:
|
||||
common_chat_parse_granite(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_GPT_OSS:
|
||||
common_chat_parse_gpt_oss(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_SEED_OSS:
|
||||
common_chat_parse_seed_oss(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_NEMOTRON_V2:
|
||||
common_chat_parse_nemotron_v2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_APERTUS:
|
||||
common_chat_parse_apertus(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
|
||||
common_chat_parse_lfm2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_MINIMAX_M2:
|
||||
common_chat_parse_minimax_m2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_GLM_4_5:
|
||||
common_chat_parse_glm_4_5(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_KIMI_K2:
|
||||
common_chat_parse_kimi_k2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML:
|
||||
common_chat_parse_qwen3_coder_xml(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_APRIEL_1_5:
|
||||
common_chat_parse_apriel_1_5(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_XIAOMI_MIMO:
|
||||
common_chat_parse_xiaomi_mimo(builder);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
|
||||
}
|
||||
builder.finish();
|
||||
}
|
||||
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
|
||||
common_chat_msg_parser builder(input, is_partial, syntax);
|
||||
try {
|
||||
common_chat_parse(builder);
|
||||
} catch (const common_chat_msg_partial_exception & ex) {
|
||||
LOG_DBG("Partial parse: %s\n", ex.what());
|
||||
if (!is_partial) {
|
||||
builder.clear_tools();
|
||||
builder.move_to(0);
|
||||
common_chat_parse_content_only(builder);
|
||||
}
|
||||
}
|
||||
auto msg = builder.result();
|
||||
if (!is_partial) {
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "chat.h"
|
||||
#include "chat-parser-xml-toolcall.h"
|
||||
#include "json-partial.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
@@ -119,5 +120,14 @@ class common_chat_msg_parser {
|
||||
const std::vector<std::vector<std::string>> & content_paths = {}
|
||||
);
|
||||
|
||||
/**
|
||||
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
||||
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||
*/
|
||||
bool try_consume_xml_tool_calls(const struct xml_tool_call_format & form);
|
||||
|
||||
// Parse content uses reasoning and XML-Style tool call
|
||||
void consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>");
|
||||
|
||||
void clear_tools();
|
||||
};
|
||||
|
||||
1214
common/chat.cpp
1214
common/chat.cpp
File diff suppressed because it is too large
Load Diff
@@ -117,6 +117,12 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_NEMOTRON_V2,
|
||||
COMMON_CHAT_FORMAT_APERTUS,
|
||||
COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
|
||||
COMMON_CHAT_FORMAT_GLM_4_5,
|
||||
COMMON_CHAT_FORMAT_MINIMAX_M2,
|
||||
COMMON_CHAT_FORMAT_KIMI_K2,
|
||||
COMMON_CHAT_FORMAT_QWEN3_CODER_XML,
|
||||
COMMON_CHAT_FORMAT_APRIEL_1_5,
|
||||
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
#include "sampling.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
@@ -26,7 +27,6 @@
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
@@ -60,6 +60,14 @@
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
common_time_meas::common_time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
|
||||
|
||||
common_time_meas::~common_time_meas() {
|
||||
if (t_start_us >= 0) {
|
||||
t_acc += ggml_time_us() - t_start_us;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// CPU utils
|
||||
//
|
||||
@@ -355,11 +363,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
|
||||
}
|
||||
|
||||
void common_init() {
|
||||
llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
|
||||
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
|
||||
common_log_add(common_log_main(), level, "%s", text);
|
||||
}
|
||||
}, NULL);
|
||||
llama_log_set(common_log_default_callback, NULL);
|
||||
|
||||
#ifdef NDEBUG
|
||||
const char * build_type = "";
|
||||
@@ -908,7 +912,7 @@ std::string fs_get_cache_file(const std::string & filename) {
|
||||
return cache_directory + filename;
|
||||
}
|
||||
|
||||
std::vector<common_file_info> fs_list_files(const std::string & path) {
|
||||
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories) {
|
||||
std::vector<common_file_info> files;
|
||||
if (path.empty()) return files;
|
||||
|
||||
@@ -923,14 +927,22 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
|
||||
const auto & p = entry.path();
|
||||
if (std::filesystem::is_regular_file(p)) {
|
||||
common_file_info info;
|
||||
info.path = p.string();
|
||||
info.name = p.filename().string();
|
||||
info.path = p.string();
|
||||
info.name = p.filename().string();
|
||||
info.is_dir = false;
|
||||
try {
|
||||
info.size = static_cast<size_t>(std::filesystem::file_size(p));
|
||||
} catch (const std::filesystem::filesystem_error &) {
|
||||
info.size = 0;
|
||||
}
|
||||
files.push_back(std::move(info));
|
||||
} else if (include_directories && std::filesystem::is_directory(p)) {
|
||||
common_file_info info;
|
||||
info.path = p.string();
|
||||
info.name = p.filename().string();
|
||||
info.size = 0; // Directories have no size
|
||||
info.is_dir = true;
|
||||
files.push_back(std::move(info));
|
||||
}
|
||||
} catch (const std::filesystem::filesystem_error &) {
|
||||
// skip entries we cannot inspect
|
||||
@@ -946,6 +958,58 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
|
||||
// Model utils
|
||||
//
|
||||
|
||||
static inline void common_init_sampler_from_model(
|
||||
const llama_model * model,
|
||||
common_params_sampling & sparams) {
|
||||
|
||||
const uint64_t config = sparams.user_sampling_config;
|
||||
|
||||
auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
|
||||
if (config & user_config) return;
|
||||
|
||||
char buf[64] = {0};
|
||||
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
|
||||
char * end = nullptr;
|
||||
int32_t v = strtol(buf, &end, 10);
|
||||
if (end && end != buf) dst = v;
|
||||
}
|
||||
};
|
||||
|
||||
auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
|
||||
if (config & user_config) return;
|
||||
|
||||
char buf[128] = {0};
|
||||
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
|
||||
char * end = nullptr;
|
||||
float v = strtof(buf, &end);
|
||||
if (end && end != buf) dst = v;
|
||||
}
|
||||
};
|
||||
|
||||
// Sampling sequence
|
||||
if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) {
|
||||
char buf[512] = {0};
|
||||
if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) {
|
||||
const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';');
|
||||
if (!sampler_names.empty()) {
|
||||
sparams.samplers = common_sampler_types_from_names(sampler_names, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP);
|
||||
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT);
|
||||
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
|
||||
}
|
||||
|
||||
struct common_init_result common_init_from_params(common_params & params) {
|
||||
common_init_result iparams;
|
||||
auto mparams = common_model_params_to_llama(params);
|
||||
@@ -957,6 +1021,8 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
return iparams;
|
||||
}
|
||||
|
||||
common_init_sampler_from_model(model, params.sampling);
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
auto cparams = common_context_params_to_llama(params);
|
||||
|
||||
@@ -2,17 +2,15 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ggml-opt.h"
|
||||
#include "llama-cpp.h"
|
||||
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <cmath>
|
||||
|
||||
#include "ggml-opt.h"
|
||||
#include "llama-cpp.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
#define DIRECTORY_SEPARATOR '\\'
|
||||
@@ -28,7 +26,14 @@
|
||||
fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
|
||||
} while(0)
|
||||
|
||||
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
|
||||
struct common_time_meas {
|
||||
common_time_meas(int64_t & t_acc, bool disable = false);
|
||||
~common_time_meas();
|
||||
|
||||
const int64_t t_start_us;
|
||||
|
||||
int64_t & t_acc;
|
||||
};
|
||||
|
||||
struct common_adapter_lora_info {
|
||||
std::string path;
|
||||
@@ -133,6 +138,22 @@ struct common_grammar_trigger {
|
||||
llama_token token = LLAMA_TOKEN_NULL;
|
||||
};
|
||||
|
||||
enum common_params_sampling_config : uint64_t {
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS = 1 << 0,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_TOP_K = 1 << 1,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_TOP_P = 1 << 2,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_MIN_P = 1 << 3,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1 << 4,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD = 1 << 5,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_TEMP = 1 << 6,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N = 1 << 7,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT = 1 << 8,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT = 1 << 9,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU = 1 << 10,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
|
||||
};
|
||||
|
||||
|
||||
// sampling parameters
|
||||
struct common_params_sampling {
|
||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
||||
@@ -165,6 +186,8 @@ struct common_params_sampling {
|
||||
bool no_perf = false; // disable performance metrics
|
||||
bool timing_per_token = false;
|
||||
|
||||
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
|
||||
|
||||
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
|
||||
|
||||
|
||||
@@ -198,6 +221,7 @@ struct common_params_model {
|
||||
std::string hf_repo = ""; // HF repo // NOLINT
|
||||
std::string hf_file = ""; // HF file // NOLINT
|
||||
std::string docker_repo = ""; // Docker repo // NOLINT
|
||||
std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
|
||||
};
|
||||
|
||||
struct common_params_speculative {
|
||||
@@ -344,7 +368,7 @@ struct common_params {
|
||||
|
||||
std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale
|
||||
|
||||
int32_t verbosity = 0;
|
||||
int32_t verbosity = 3; // LOG_LEVEL_INFO
|
||||
int32_t control_vector_layer_start = -1; // layer range for control vector
|
||||
int32_t control_vector_layer_end = -1; // layer range for control vector
|
||||
bool offline = false;
|
||||
@@ -453,6 +477,11 @@ struct common_params {
|
||||
bool endpoint_props = false; // only control POST requests, not GET
|
||||
bool endpoint_metrics = false;
|
||||
|
||||
// router server configs
|
||||
std::string models_dir = ""; // directory containing models for the router server
|
||||
int models_max = 4; // maximum number of models to load simultaneously
|
||||
bool models_autoload = true; // automatically load models when requested via the router server
|
||||
|
||||
bool log_json = false;
|
||||
|
||||
std::string slot_save_path;
|
||||
@@ -616,8 +645,9 @@ struct common_file_info {
|
||||
std::string path;
|
||||
std::string name;
|
||||
size_t size = 0; // in bytes
|
||||
bool is_dir = false;
|
||||
};
|
||||
std::vector<common_file_info> fs_list_files(const std::string & path);
|
||||
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories);
|
||||
|
||||
//
|
||||
// Model utils
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
#if defined(LLAMA_USE_CURL)
|
||||
#include <curl/curl.h>
|
||||
#include <curl/easy.h>
|
||||
#else
|
||||
#elif defined(LLAMA_USE_HTTPLIB)
|
||||
#include "http.h"
|
||||
#endif
|
||||
|
||||
@@ -430,7 +430,7 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
|
||||
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
|
||||
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
|
||||
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
|
||||
curl_easy_setopt(curl.get(), CURLOPT_VERBOSE, 1L);
|
||||
curl_easy_setopt(curl.get(), CURLOPT_VERBOSE, 0L);
|
||||
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
|
||||
auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
|
||||
auto data_vec = static_cast<std::vector<char> *>(data);
|
||||
@@ -467,7 +467,7 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
|
||||
return { res_code, std::move(res_buffer) };
|
||||
}
|
||||
|
||||
#else
|
||||
#elif defined(LLAMA_USE_HTTPLIB)
|
||||
|
||||
static bool is_output_a_tty() {
|
||||
#if defined(_WIN32)
|
||||
@@ -517,16 +517,18 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-");
|
||||
}
|
||||
|
||||
std::atomic<size_t> downloaded{existing_size};
|
||||
const char * func = __func__; // avoid __func__ inside a lambda
|
||||
size_t downloaded = existing_size;
|
||||
size_t progress_step = 0;
|
||||
|
||||
auto res = cli.Get(resolve_path, headers,
|
||||
[&](const httplib::Response &response) {
|
||||
if (existing_size > 0 && response.status != 206) {
|
||||
LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", __func__, response.status);
|
||||
LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", func, response.status);
|
||||
return false;
|
||||
}
|
||||
if (existing_size == 0 && response.status != 200) {
|
||||
LOG_WRN("%s: download received non-successful status code: %d\n", __func__, response.status);
|
||||
LOG_WRN("%s: download received non-successful status code: %d\n", func, response.status);
|
||||
return false;
|
||||
}
|
||||
if (total_size == 0 && response.has_header("Content-Length")) {
|
||||
@@ -534,7 +536,7 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
size_t content_length = std::stoull(response.get_header_value("Content-Length"));
|
||||
total_size = existing_size + content_length;
|
||||
} catch (const std::exception &e) {
|
||||
LOG_WRN("%s: invalid Content-Length header: %s\n", __func__, e.what());
|
||||
LOG_WRN("%s: invalid Content-Length header: %s\n", func, e.what());
|
||||
}
|
||||
}
|
||||
return true;
|
||||
@@ -542,11 +544,16 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
[&](const char *data, size_t len) {
|
||||
ofs.write(data, len);
|
||||
if (!ofs) {
|
||||
LOG_ERR("%s: error writing to file: %s\n", __func__, path_tmp.c_str());
|
||||
LOG_ERR("%s: error writing to file: %s\n", func, path_tmp.c_str());
|
||||
return false;
|
||||
}
|
||||
downloaded += len;
|
||||
print_progress(downloaded, total_size);
|
||||
progress_step += len;
|
||||
|
||||
if (progress_step >= total_size / 1000 || downloaded == total_size) {
|
||||
print_progress(downloaded, total_size);
|
||||
progress_step = 0;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
nullptr
|
||||
@@ -713,6 +720,8 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
|
||||
|
||||
#endif // LLAMA_USE_CURL
|
||||
|
||||
#if defined(LLAMA_USE_CURL) || defined(LLAMA_USE_HTTPLIB)
|
||||
|
||||
static bool common_download_file_single(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
@@ -907,33 +916,6 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons
|
||||
return { hf_repo, ggufFile, mmprojFile };
|
||||
}
|
||||
|
||||
std::vector<common_cached_model_info> common_list_cached_models() {
|
||||
std::vector<common_cached_model_info> models;
|
||||
const std::string cache_dir = fs_get_cache_directory();
|
||||
const std::vector<common_file_info> files = fs_list_files(cache_dir);
|
||||
for (const auto & file : files) {
|
||||
if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) {
|
||||
common_cached_model_info model_info;
|
||||
model_info.manifest_path = file.path;
|
||||
std::string fname = file.name;
|
||||
string_replace_all(fname, ".json", ""); // remove extension
|
||||
auto parts = string_split<std::string>(fname, '=');
|
||||
if (parts.size() == 4) {
|
||||
// expect format: manifest=<user>=<model>=<tag>=<other>
|
||||
model_info.user = parts[1];
|
||||
model_info.model = parts[2];
|
||||
model_info.tag = parts[3];
|
||||
} else {
|
||||
// invalid format
|
||||
continue;
|
||||
}
|
||||
model_info.size = 0; // TODO: get GGUF size, not manifest size
|
||||
models.push_back(model_info);
|
||||
}
|
||||
}
|
||||
return models;
|
||||
}
|
||||
|
||||
//
|
||||
// Docker registry functions
|
||||
//
|
||||
@@ -1052,3 +1034,46 @@ std::string common_docker_resolve_model(const std::string & docker) {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
bool common_download_model(const common_params_model &, const std::string &, bool) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
std::string common_docker_resolve_model(const std::string &) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
#endif // LLAMA_USE_CURL || LLAMA_USE_HTTPLIB
|
||||
|
||||
std::vector<common_cached_model_info> common_list_cached_models() {
|
||||
std::vector<common_cached_model_info> models;
|
||||
const std::string cache_dir = fs_get_cache_directory();
|
||||
const std::vector<common_file_info> files = fs_list(cache_dir, false);
|
||||
for (const auto & file : files) {
|
||||
if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) {
|
||||
common_cached_model_info model_info;
|
||||
model_info.manifest_path = file.path;
|
||||
std::string fname = file.name;
|
||||
string_replace_all(fname, ".json", ""); // remove extension
|
||||
auto parts = string_split<std::string>(fname, '=');
|
||||
if (parts.size() == 4) {
|
||||
// expect format: manifest=<user>=<model>=<tag>=<other>
|
||||
model_info.user = parts[1];
|
||||
model_info.model = parts[2];
|
||||
model_info.tag = parts[3];
|
||||
} else {
|
||||
// invalid format
|
||||
continue;
|
||||
}
|
||||
model_info.size = 0; // TODO: get GGUF size, not manifest size
|
||||
models.push_back(model_info);
|
||||
}
|
||||
}
|
||||
return models;
|
||||
}
|
||||
|
||||
@@ -14,8 +14,10 @@ struct common_cached_model_info {
|
||||
std::string model;
|
||||
std::string tag;
|
||||
size_t size = 0; // GGUF size in bytes
|
||||
// return string representation like "user/model:tag"
|
||||
// if tag is "latest", it will be omitted
|
||||
std::string to_string() const {
|
||||
return user + "/" + model + ":" + tag;
|
||||
return user + "/" + model + (tag == "latest" ? "" : ":" + tag);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -297,8 +297,25 @@ bool common_json_parse(
|
||||
it = temptative_end;
|
||||
return true;
|
||||
}
|
||||
// TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
|
||||
// fprintf(stderr, "Closing: TODO\n");
|
||||
// handle unclosed top-level primitive
|
||||
if (err_loc.position != 0 && !healing_marker.empty() && err_loc.stack.empty()) {
|
||||
std::string str(it, temptative_end);
|
||||
const auto & magic_seed = out.healing_marker.marker = healing_marker;
|
||||
if (can_parse(str + "\"")) {
|
||||
// Was inside an string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"";
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"")) {
|
||||
// Was inside an string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"";
|
||||
} else {
|
||||
// TODO: handle more unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
|
||||
// fprintf(stderr, "Closing: TODO\n");
|
||||
return false;
|
||||
}
|
||||
out.json = json::parse(str);
|
||||
it = temptative_end;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
out.json = json::parse(it, end);
|
||||
|
||||
@@ -268,10 +268,10 @@ static bool is_reserved_name(const std::string & name) {
|
||||
}
|
||||
|
||||
std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
|
||||
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]");
|
||||
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]");
|
||||
std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
|
||||
std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
|
||||
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}
|
||||
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"}
|
||||
};
|
||||
|
||||
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
|
||||
@@ -303,6 +303,8 @@ static std::string format_literal(const std::string & literal) {
|
||||
return "\"" + escaped + "\"";
|
||||
}
|
||||
|
||||
std::string gbnf_format_literal(const std::string & literal) { return format_literal(literal); }
|
||||
|
||||
class SchemaConverter {
|
||||
private:
|
||||
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
|
||||
|
||||
@@ -18,4 +18,6 @@ struct common_grammar_options {
|
||||
bool dotall = false;
|
||||
};
|
||||
|
||||
std::string gbnf_format_literal(const std::string & literal);
|
||||
|
||||
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
|
||||
|
||||
@@ -442,3 +442,23 @@ void common_log_set_prefix(struct common_log * log, bool prefix) {
|
||||
void common_log_set_timestamps(struct common_log * log, bool timestamps) {
|
||||
log->set_timestamps(timestamps);
|
||||
}
|
||||
|
||||
static int common_get_verbosity(enum ggml_log_level level) {
|
||||
switch (level) {
|
||||
case GGML_LOG_LEVEL_DEBUG: return LOG_LEVEL_DEBUG;
|
||||
case GGML_LOG_LEVEL_INFO: return LOG_LEVEL_INFO;
|
||||
case GGML_LOG_LEVEL_WARN: return LOG_LEVEL_WARN;
|
||||
case GGML_LOG_LEVEL_ERROR: return LOG_LEVEL_ERROR;
|
||||
case GGML_LOG_LEVEL_CONT: return LOG_LEVEL_INFO; // same as INFO
|
||||
case GGML_LOG_LEVEL_NONE:
|
||||
default:
|
||||
return LOG_LEVEL_OUTPUT;
|
||||
}
|
||||
}
|
||||
|
||||
void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
|
||||
auto verbosity = common_get_verbosity(level);
|
||||
if (verbosity <= common_log_verbosity_thold) {
|
||||
common_log_add(common_log_main(), level, "%s", text);
|
||||
}
|
||||
}
|
||||
|
||||
33
common/log.h
33
common/log.h
@@ -21,8 +21,14 @@
|
||||
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
||||
#endif
|
||||
|
||||
#define LOG_DEFAULT_DEBUG 1
|
||||
#define LOG_DEFAULT_LLAMA 0
|
||||
#define LOG_LEVEL_DEBUG 4
|
||||
#define LOG_LEVEL_INFO 3
|
||||
#define LOG_LEVEL_WARN 2
|
||||
#define LOG_LEVEL_ERROR 1
|
||||
#define LOG_LEVEL_OUTPUT 0 // output data from tools
|
||||
|
||||
#define LOG_DEFAULT_DEBUG LOG_LEVEL_DEBUG
|
||||
#define LOG_DEFAULT_LLAMA LOG_LEVEL_INFO
|
||||
|
||||
enum log_colors {
|
||||
LOG_COLORS_AUTO = -1,
|
||||
@@ -36,6 +42,8 @@ extern int common_log_verbosity_thold;
|
||||
|
||||
void common_log_set_verbosity_thold(int verbosity); // not thread-safe
|
||||
|
||||
void common_log_default_callback(enum ggml_log_level level, const char * text, void * user_data);
|
||||
|
||||
// the common_log uses an internal worker thread to print/write log messages
|
||||
// when the worker thread is paused, incoming log messages are discarded
|
||||
struct common_log;
|
||||
@@ -65,10 +73,11 @@ void common_log_add(struct common_log * log, enum ggml_log_level level, const ch
|
||||
// 0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU
|
||||
// 0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU
|
||||
//
|
||||
// I - info (stdout, V = 0)
|
||||
// W - warning (stderr, V = 0)
|
||||
// E - error (stderr, V = 0)
|
||||
// D - debug (stderr, V = LOG_DEFAULT_DEBUG)
|
||||
// I - info (stdout, V = LOG_DEFAULT_INFO)
|
||||
// W - warning (stderr, V = LOG_DEFAULT_WARN)
|
||||
// E - error (stderr, V = LOG_DEFAULT_ERROR)
|
||||
// O - output (stdout, V = LOG_DEFAULT_OUTPUT)
|
||||
//
|
||||
|
||||
void common_log_set_file (struct common_log * log, const char * file); // not thread-safe
|
||||
@@ -93,14 +102,14 @@ void common_log_set_timestamps(struct common_log * log, bool timestamps); // w
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, 0, __VA_ARGS__)
|
||||
#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__)
|
||||
#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, LOG_LEVEL_OUTPUT, __VA_ARGS__)
|
||||
#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__)
|
||||
|
||||
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, 0, __VA_ARGS__)
|
||||
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, 0, __VA_ARGS__)
|
||||
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, 0, __VA_ARGS__)
|
||||
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_DEFAULT_DEBUG, __VA_ARGS__)
|
||||
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, 0, __VA_ARGS__)
|
||||
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_LEVEL_DEBUG, __VA_ARGS__)
|
||||
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, LOG_LEVEL_INFO, __VA_ARGS__)
|
||||
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, LOG_LEVEL_WARN, __VA_ARGS__)
|
||||
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, LOG_LEVEL_INFO, __VA_ARGS__) // same as INFO
|
||||
|
||||
#define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__)
|
||||
#define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__)
|
||||
|
||||
@@ -3,9 +3,10 @@
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <unordered_map>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <unordered_map>
|
||||
|
||||
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
||||
// TODO: deduplicate with llama-impl.h
|
||||
@@ -112,6 +113,13 @@ struct common_sampler {
|
||||
|
||||
llama_token_data_array cur_p;
|
||||
|
||||
void reset() {
|
||||
prev.clear();
|
||||
|
||||
llama_sampler_reset(grmr);
|
||||
llama_sampler_reset(chain);
|
||||
}
|
||||
|
||||
void set_logits(struct llama_context * ctx, int idx) {
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
|
||||
@@ -128,6 +136,12 @@ struct common_sampler {
|
||||
|
||||
cur_p = { cur.data(), cur.size(), -1, false };
|
||||
}
|
||||
|
||||
common_time_meas tm() {
|
||||
return common_time_meas(t_total_us, params.no_perf);
|
||||
}
|
||||
|
||||
mutable int64_t t_total_us = 0;
|
||||
};
|
||||
|
||||
std::string common_params_sampling::print() const {
|
||||
@@ -298,6 +312,8 @@ void common_sampler_free(struct common_sampler * gsmpl) {
|
||||
}
|
||||
|
||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
||||
const auto tm = gsmpl->tm();
|
||||
|
||||
if (accept_grammar) {
|
||||
llama_sampler_accept(gsmpl->grmr, token);
|
||||
}
|
||||
@@ -308,9 +324,7 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
|
||||
}
|
||||
|
||||
void common_sampler_reset(struct common_sampler * gsmpl) {
|
||||
llama_sampler_reset(gsmpl->grmr);
|
||||
|
||||
llama_sampler_reset(gsmpl->chain);
|
||||
gsmpl->reset();
|
||||
}
|
||||
|
||||
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
||||
@@ -327,16 +341,54 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
||||
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
|
||||
// TODO: measure grammar performance
|
||||
|
||||
const double t_sampling_ms = gsmpl ? 1e-3*gsmpl->t_total_us : 0;
|
||||
|
||||
llama_perf_sampler_data data_smpl;
|
||||
llama_perf_context_data data_ctx;
|
||||
|
||||
memset(&data_smpl, 0, sizeof(data_smpl));
|
||||
memset(&data_ctx, 0, sizeof(data_ctx));
|
||||
|
||||
if (gsmpl) {
|
||||
llama_perf_sampler_print(gsmpl->chain);
|
||||
auto & data = data_smpl;
|
||||
|
||||
data = llama_perf_sampler(gsmpl->chain);
|
||||
|
||||
// note: the sampling time includes the samplers time + extra time spent in common/sampling
|
||||
LOG_INF("%s: sampling time = %10.2f ms\n", __func__, t_sampling_ms);
|
||||
LOG_INF("%s: samplers time = %10.2f ms / %5d tokens\n", __func__, data.t_sample_ms, data.n_sample);
|
||||
}
|
||||
|
||||
if (ctx) {
|
||||
llama_perf_context_print(ctx);
|
||||
auto & data = data_ctx;
|
||||
|
||||
data = llama_perf_context(ctx);
|
||||
|
||||
const double t_end_ms = 1e-3 * ggml_time_us();
|
||||
|
||||
const double t_total_ms = t_end_ms - data.t_start_ms;
|
||||
const double t_unacc_ms = t_total_ms - (t_sampling_ms + data.t_p_eval_ms + data.t_eval_ms);
|
||||
const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms;
|
||||
|
||||
LOG_INF("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
|
||||
LOG_INF("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
__func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
|
||||
LOG_INF("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
|
||||
LOG_INF("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
|
||||
LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %% (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc);
|
||||
LOG_INF("%s: graphs reused = %10d\n", __func__, data.n_reused);
|
||||
|
||||
llama_memory_breakdown_print(ctx);
|
||||
}
|
||||
}
|
||||
|
||||
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
||||
llama_synchronize(ctx);
|
||||
|
||||
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
|
||||
const auto tm = gsmpl->tm();
|
||||
|
||||
gsmpl->set_logits(ctx, idx);
|
||||
|
||||
auto & grmr = gsmpl->grmr;
|
||||
@@ -428,6 +480,8 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
||||
// helpers
|
||||
|
||||
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
|
||||
const auto tm = gsmpl->tm();
|
||||
|
||||
auto * res = &gsmpl->cur_p;
|
||||
|
||||
if (do_sort && !res->sorted) {
|
||||
|
||||
@@ -189,10 +189,10 @@ class ModelBase:
|
||||
return tensors
|
||||
|
||||
prefix = "model" if not self.is_mistral_format else "consolidated"
|
||||
part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
|
||||
part_names: set[str] = set(ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors"))
|
||||
is_safetensors: bool = len(part_names) > 0
|
||||
if not is_safetensors:
|
||||
part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
|
||||
part_names = set(ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin"))
|
||||
|
||||
tensor_names_from_index: set[str] = set()
|
||||
|
||||
@@ -209,6 +209,7 @@ class ModelBase:
|
||||
if weight_map is None or not isinstance(weight_map, dict):
|
||||
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
|
||||
tensor_names_from_index.update(weight_map.keys())
|
||||
part_names |= set(weight_map.values())
|
||||
else:
|
||||
weight_map = {}
|
||||
else:
|
||||
@@ -564,7 +565,7 @@ class ModelBase:
|
||||
gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
|
||||
)
|
||||
)
|
||||
or not new_name.endswith(".weight")
|
||||
or new_name[-7:] not in (".weight", ".lora_a", ".lora_b")
|
||||
):
|
||||
data_qtype = gguf.GGMLQuantizationType.F32
|
||||
|
||||
@@ -825,6 +826,15 @@ class TextModel(ModelBase):
|
||||
self.gguf_writer.add_expert_group_used_count(n_group_used)
|
||||
logger.info(f"gguf: expert groups used count = {n_group_used}")
|
||||
|
||||
if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func"], optional=True)) is not None:
|
||||
if score_func == "sigmoid":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
elif score_func == "softmax":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
else:
|
||||
raise ValueError(f"Unsupported expert score gating function value: {score_func}")
|
||||
logger.info(f"gguf: expert score gating function = {score_func}")
|
||||
|
||||
if (head_dim := self.hparams.get("head_dim")) is not None:
|
||||
self.gguf_writer.add_key_length(head_dim)
|
||||
self.gguf_writer.add_value_length(head_dim)
|
||||
@@ -1124,6 +1134,9 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
|
||||
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
|
||||
res = "mellum"
|
||||
if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df":
|
||||
# ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer
|
||||
res = "afmoe"
|
||||
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
|
||||
# ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
|
||||
res = "bailingmoe2"
|
||||
@@ -1568,10 +1581,27 @@ class MmprojModel(ModelBase):
|
||||
|
||||
# load preprocessor config
|
||||
self.preprocessor_config = {}
|
||||
if not self.is_mistral_format:
|
||||
with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
|
||||
|
||||
# prefer preprocessor_config.json if possible
|
||||
preprocessor_config_path = self.dir_model / "preprocessor_config.json"
|
||||
if preprocessor_config_path.is_file():
|
||||
with open(preprocessor_config_path, "r", encoding="utf-8") as f:
|
||||
self.preprocessor_config = json.load(f)
|
||||
|
||||
# prefer processor_config.json if possible
|
||||
processor_config_path = self.dir_model / "processor_config.json"
|
||||
if processor_config_path.is_file():
|
||||
with open(processor_config_path, "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
# move image_processor to root level for compat
|
||||
if "image_processor" in cfg:
|
||||
cfg = {
|
||||
**cfg,
|
||||
**cfg["image_processor"],
|
||||
}
|
||||
# merge configs
|
||||
self.preprocessor_config = {**self.preprocessor_config, **cfg}
|
||||
|
||||
def get_vision_config(self) -> dict[str, Any] | None:
|
||||
config_name = "vision_config" if not self.is_mistral_format else "vision_encoder"
|
||||
return self.global_config.get(config_name)
|
||||
@@ -1660,11 +1690,9 @@ class GPTNeoXModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.GPTNEOX
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
|
||||
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_rope_dimension_count(
|
||||
int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
|
||||
@@ -1722,7 +1750,7 @@ class BloomModel(TextModel):
|
||||
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
|
||||
self.gguf_writer.add_embedding_length(n_embed)
|
||||
self.gguf_writer.add_feed_forward_length(4 * n_embed)
|
||||
self.gguf_writer.add_block_count(self.hparams["n_layer"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head)
|
||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||
@@ -1785,10 +1813,9 @@ class MPTModel(TextModel):
|
||||
self.gguf_writer.add_unk_token_id(0)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["n_layers"]
|
||||
self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"])
|
||||
self.gguf_writer.add_head_count(self.hparams["n_heads"])
|
||||
if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"):
|
||||
@@ -1821,7 +1848,6 @@ class OrionModel(TextModel):
|
||||
self._set_vocab_sentencepiece()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
head_count = self.hparams["num_attention_heads"]
|
||||
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
||||
|
||||
@@ -1839,7 +1865,7 @@ class OrionModel(TextModel):
|
||||
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
|
||||
self.gguf_writer.add_context_length(ctx_length)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_head_count(head_count)
|
||||
self.gguf_writer.add_head_count_kv(head_count_kv)
|
||||
@@ -1856,7 +1882,6 @@ class BaichuanModel(TextModel):
|
||||
self._set_vocab_sentencepiece()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
head_count = self.hparams["num_attention_heads"]
|
||||
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
||||
|
||||
@@ -1873,7 +1898,7 @@ class BaichuanModel(TextModel):
|
||||
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
|
||||
self.gguf_writer.add_context_length(ctx_length)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count(head_count)
|
||||
@@ -1980,7 +2005,6 @@ class XverseModel(TextModel):
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
head_count = self.hparams["num_attention_heads"]
|
||||
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
||||
|
||||
@@ -1997,7 +2021,7 @@ class XverseModel(TextModel):
|
||||
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
|
||||
self.gguf_writer.add_context_length(ctx_length)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count(head_count)
|
||||
@@ -2040,10 +2064,6 @@ class FalconModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.FALCON
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams.get("num_hidden_layers")
|
||||
if block_count is None:
|
||||
block_count = self.hparams["n_layer"] # old name
|
||||
|
||||
n_head = self.hparams.get("num_attention_heads")
|
||||
if n_head is None:
|
||||
n_head = self.hparams["n_head"] # old name
|
||||
@@ -2056,7 +2076,7 @@ class FalconModel(TextModel):
|
||||
self.gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||
@@ -2094,12 +2114,10 @@ class StarCoderModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.STARCODER
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["n_layer"]
|
||||
|
||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
||||
self.gguf_writer.add_head_count_kv(1)
|
||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||
@@ -2129,14 +2147,12 @@ class RefactModel(TextModel):
|
||||
multiple_of = 256
|
||||
ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
block_count = self.hparams["n_layer"]
|
||||
|
||||
# refact uses Alibi. So this is from config.json which might be used by training.
|
||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||
|
||||
self.gguf_writer.add_feed_forward_length(ff_dim)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
||||
self.gguf_writer.add_head_count_kv(1)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
|
||||
@@ -2183,11 +2199,10 @@ class StableLMModel(TextModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
|
||||
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
|
||||
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
|
||||
@@ -2533,6 +2548,72 @@ class ArceeModel(LlamaModel):
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
|
||||
|
||||
|
||||
@ModelBase.register("AfmoeForCausalLM")
|
||||
class AfmoeModel(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.AFMOE
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# MoE parameters
|
||||
if (n_experts := self.hparams.get("num_experts")) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None:
|
||||
self.gguf_writer.add_expert_shared_count(n_shared_experts)
|
||||
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
|
||||
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
|
||||
if (n_dense_layers := self.hparams.get("num_dense_layers")) is not None:
|
||||
self.gguf_writer.add_leading_dense_block_count(n_dense_layers)
|
||||
|
||||
# Route normalization and scaling
|
||||
if (route_norm := self.hparams.get("route_norm")) is not None:
|
||||
self.gguf_writer.add_expert_weights_norm(route_norm)
|
||||
if (route_scale := self.hparams.get("route_scale")) is not None:
|
||||
self.gguf_writer.add_expert_weights_scale(route_scale)
|
||||
|
||||
# Sliding window attention
|
||||
if (sliding_window := self.hparams.get("sliding_window")) is not None:
|
||||
self.gguf_writer.add_sliding_window(sliding_window)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# Handle expert weights - they're already merged in the HF format
|
||||
# process the experts separately
|
||||
if name.find("mlp.experts") != -1:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["gate_proj", "up_proj", "down_proj"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename_to_retrieve])
|
||||
del self._experts[bid][ename_to_retrieve]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
tensors.append((new_name, data_torch))
|
||||
|
||||
return tensors
|
||||
else:
|
||||
return []
|
||||
|
||||
if name.endswith(".expert_bias"):
|
||||
name = name.replace(".expert_bias", ".expert_bias.bias")
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register(
|
||||
"LlavaForConditionalGeneration", # pixtral
|
||||
"Mistral3ForConditionalGeneration", # mistral small 3.1
|
||||
@@ -2733,7 +2814,32 @@ class Llama4VisionModel(MmprojModel):
|
||||
|
||||
@ModelBase.register("Mistral3ForConditionalGeneration")
|
||||
class Mistral3Model(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
model_arch = gguf.MODEL_ARCH.MISTRAL3
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# for compatibility, we use LLAMA arch for older models
|
||||
# TODO: remove this once everyone has migrated to newer version of llama.cpp
|
||||
if self.hparams.get("model_type") != "ministral3":
|
||||
self.model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
|
||||
self.gguf_writer.add_architecture()
|
||||
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
rope_params = self.hparams.get("rope_parameters")
|
||||
if self.hparams.get("model_type") == "ministral3":
|
||||
assert rope_params is not None, "ministral3 must have 'rope_parameters' config"
|
||||
assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'"
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(rope_params["factor"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_params["beta_fast"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_params["beta_slow"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"])
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_params["original_max_position_embeddings"])
|
||||
self.gguf_writer.add_rope_freq_base(rope_params["rope_theta"])
|
||||
self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
name = name.replace("language_model.", "")
|
||||
@@ -3072,7 +3178,7 @@ class DbrxModel(TextModel):
|
||||
def set_gguf_parameters(self):
|
||||
ffn_config = self.hparams["ffn_config"]
|
||||
attn_config = self.hparams["attn_config"]
|
||||
self.gguf_writer.add_block_count(self.hparams["n_layers"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
|
||||
self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||
@@ -3274,7 +3380,7 @@ class QwenModel(TextModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
|
||||
@@ -4119,6 +4225,51 @@ class Qwen3MoeModel(Qwen2MoeModel):
|
||||
super().set_vocab()
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3NextForCausalLM")
|
||||
class Qwen3NextModel(Qwen2MoeModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3NEXT
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_ssm_conv_kernel(self.hparams["linear_conv_kernel_dim"])
|
||||
self.gguf_writer.add_ssm_state_size(self.hparams["linear_key_head_dim"])
|
||||
self.gguf_writer.add_ssm_group_count(self.hparams["linear_num_key_heads"])
|
||||
self.gguf_writer.add_ssm_time_step_rank(self.hparams["linear_num_value_heads"])
|
||||
self.gguf_writer.add_ssm_inner_size(self.hparams["linear_value_head_dim"] * self.hparams["linear_num_value_heads"])
|
||||
if (rope_dim := self.hparams.get("head_dim")) is None:
|
||||
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25)))
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if name.startswith("mtp"):
|
||||
return [] # ignore MTP layers for now
|
||||
if name.endswith(".A_log"):
|
||||
data_torch = -torch.exp(data_torch)
|
||||
elif name.endswith(".dt_bias"):
|
||||
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
|
||||
elif "conv1d" in name:
|
||||
data_torch = data_torch.squeeze()
|
||||
elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"):
|
||||
data_torch = data_torch + 1
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("RND1")
|
||||
class RND1Model(Qwen2MoeModel):
|
||||
model_arch = gguf.MODEL_ARCH.RND1
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# RND1 specific parameters
|
||||
# RND1 uses bidirectional attention
|
||||
self.gguf_writer.add_causal_attention(False)
|
||||
|
||||
if (mask_token_id := self.hparams.get("mask_token_id")) is not None:
|
||||
self.gguf_writer.add_mask_token_id(mask_token_id)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
|
||||
class Qwen3VLVisionModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -4305,7 +4456,7 @@ class GPT2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.GPT2
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_block_count(self.hparams["n_layer"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_context_length(self.hparams["n_ctx"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
|
||||
@@ -4337,8 +4488,6 @@ class Phi2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.PHI2
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
|
||||
|
||||
rot_pct = self.find_hparam(["partial_rotary_factor"])
|
||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||
@@ -4347,7 +4496,7 @@ class Phi2Model(TextModel):
|
||||
|
||||
self.gguf_writer.add_embedding_length(n_embd)
|
||||
self.gguf_writer.add_feed_forward_length(4 * n_embd)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head)
|
||||
self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"]))
|
||||
@@ -4465,8 +4614,6 @@ class Phi3MiniModel(TextModel):
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
|
||||
|
||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||
n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
|
||||
@@ -4480,7 +4627,7 @@ class Phi3MiniModel(TextModel):
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds)
|
||||
self.gguf_writer.add_embedding_length(n_embd)
|
||||
self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"]))
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(rms_eps)
|
||||
@@ -4600,12 +4747,11 @@ class PlamoModel(TextModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
|
||||
self.gguf_writer.add_context_length(4096) # not in config.json
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
|
||||
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
|
||||
@@ -4728,7 +4874,6 @@ class Plamo2Model(TextModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
|
||||
|
||||
# Which layers are Mamba layers
|
||||
@@ -4740,10 +4885,10 @@ class Plamo2Model(TextModel):
|
||||
num_attention_heads = []
|
||||
|
||||
if mamba_enabled:
|
||||
for i in range(block_count):
|
||||
if block_count <= (mamba_step // 2):
|
||||
for i in range(self.block_count):
|
||||
if self.block_count <= (mamba_step // 2):
|
||||
# use attention in last layer
|
||||
is_mamba = (i != block_count - 1)
|
||||
is_mamba = (i != self.block_count - 1)
|
||||
else:
|
||||
is_mamba = (i % mamba_step) != (mamba_step // 2)
|
||||
if is_mamba:
|
||||
@@ -4761,7 +4906,7 @@ class Plamo2Model(TextModel):
|
||||
self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
|
||||
self.gguf_writer.add_key_length(hparams.get("hidden_size_per_head", 128))
|
||||
self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128))
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
|
||||
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))
|
||||
|
||||
@@ -4818,12 +4963,10 @@ class CodeShellModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.CODESHELL
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["n_layer"]
|
||||
|
||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
||||
self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"])
|
||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||
@@ -4965,7 +5108,7 @@ class InternLM2Model(TextModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
|
||||
@@ -5586,11 +5729,10 @@ class GemmaModel(TextModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
|
||||
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
|
||||
@@ -5626,11 +5768,10 @@ class Gemma2Model(TextModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
|
||||
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
|
||||
@@ -5674,12 +5815,11 @@ class Gemma3Model(TextModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
|
||||
# some default values are not specified in the hparams
|
||||
self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072))
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8))
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))
|
||||
@@ -5955,7 +6095,6 @@ class Rwkv6Model(TextModel):
|
||||
self._set_vocab_rwkv_world()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
head_size = self.hparams["head_size"]
|
||||
hidden_size = self.hparams["hidden_size"]
|
||||
layer_norm_eps = self.hparams["layer_norm_epsilon"]
|
||||
@@ -5967,7 +6106,7 @@ class Rwkv6Model(TextModel):
|
||||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
|
||||
self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
@@ -6031,7 +6170,6 @@ class RWKV6Qwen2Model(Rwkv6Model):
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
num_attention_heads = self.hparams["num_attention_heads"]
|
||||
num_key_value_heads = self.hparams["num_key_value_heads"]
|
||||
hidden_size = self.hparams["hidden_size"]
|
||||
@@ -6044,7 +6182,7 @@ class RWKV6Qwen2Model(Rwkv6Model):
|
||||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim)
|
||||
self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim)
|
||||
@@ -6085,7 +6223,6 @@ class Rwkv7Model(TextModel):
|
||||
return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
try:
|
||||
head_size = self.hparams["head_size"]
|
||||
layer_norm_eps = self.hparams["layer_norm_epsilon"]
|
||||
@@ -6110,7 +6247,7 @@ class Rwkv7Model(TextModel):
|
||||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
|
||||
@@ -6204,7 +6341,6 @@ class ARwkv7Model(Rwkv7Model):
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
hidden_size = self.hparams["hidden_size"]
|
||||
head_size = self.hparams["head_size"]
|
||||
rms_norm_eps = self.hparams["rms_norm_eps"]
|
||||
@@ -6221,7 +6357,7 @@ class ARwkv7Model(Rwkv7Model):
|
||||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
|
||||
@@ -7104,13 +7240,6 @@ class DeepseekV2Model(TextModel):
|
||||
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
|
||||
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
|
||||
|
||||
if hparams["scoring_func"] == "sigmoid":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
elif hparams["scoring_func"] == "softmax":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")
|
||||
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||
|
||||
rope_scaling = self.hparams.get("rope_scaling") or {}
|
||||
@@ -7216,12 +7345,6 @@ class MiniMaxM2Model(TextModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if self.hparams["scoring_func"] == "sigmoid":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
elif self.hparams["scoring_func"] == "softmax":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
|
||||
|
||||
self.gguf_writer.add_expert_feed_forward_length(self.find_hparam(["intermediate_size"]))
|
||||
self.gguf_writer.add_rope_dimension_count(self.find_hparam(["rotary_dim"]))
|
||||
@@ -7314,11 +7437,6 @@ class Dots1Model(Qwen2MoeModel):
|
||||
self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])
|
||||
self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"])
|
||||
|
||||
if self.hparams["scoring_func"] == "noaux_tc":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
if name.endswith("e_score_correction_bias"):
|
||||
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
|
||||
@@ -7463,7 +7581,7 @@ class T5Model(TextModel):
|
||||
self.gguf_writer.add_context_length(n_ctx)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
|
||||
self.gguf_writer.add_block_count(self.hparams["num_layers"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None:
|
||||
self.gguf_writer.add_decoder_block_count(dec_n_layer)
|
||||
self.gguf_writer.add_head_count(self.hparams["num_heads"])
|
||||
@@ -7602,7 +7720,7 @@ class T5EncoderModel(TextModel):
|
||||
self.gguf_writer.add_context_length(n_ctx)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
|
||||
self.gguf_writer.add_block_count(self.hparams["num_layers"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(self.hparams["num_heads"])
|
||||
self.gguf_writer.add_key_length(self.hparams["d_kv"])
|
||||
self.gguf_writer.add_value_length(self.hparams["d_kv"])
|
||||
@@ -7665,7 +7783,7 @@ class JaisModel(TextModel):
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_block_count(self.hparams["n_layer"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"])
|
||||
@@ -7779,12 +7897,6 @@ class Glm4MoeModel(TextModel):
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
|
||||
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
|
||||
|
||||
# Patch broken chat template
|
||||
if isinstance(special_vocab.chat_template, str) and "visible_text(m.content).endswith" in special_vocab.chat_template:
|
||||
special_vocab.chat_template = special_vocab.chat_template.replace(
|
||||
"""{{ visible_text(m.content) }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}""",
|
||||
"""{% set content = visible_text(m.content) %}{{ content }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not content.endswith("/nothink")) else '' -}}""")
|
||||
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
@@ -8013,7 +8125,7 @@ class ChatGLMModel(TextModel):
|
||||
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
|
||||
self.gguf_writer.add_embedding_length(n_embed)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed)))
|
||||
self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams["num_hidden_layers"]))
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5))
|
||||
@@ -8095,7 +8207,6 @@ class ExaoneModel(TextModel):
|
||||
num_kv_heads = hparams.get("num_key_value_heads", num_heads)
|
||||
layer_norm_eps = hparams["layer_norm_epsilon"]
|
||||
intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim
|
||||
num_layers = hparams["num_layers"]
|
||||
# ignore for now as EXAONE-3.0-7.8B-Instruct attentino_dropout is 0.0
|
||||
# attention_dropout_rate = hparams["attention_dropout"]
|
||||
# ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0
|
||||
@@ -8106,7 +8217,7 @@ class ExaoneModel(TextModel):
|
||||
self.gguf_writer.add_context_length(max_position_embeddings)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps)
|
||||
self.gguf_writer.add_feed_forward_length(intermediate_size)
|
||||
self.gguf_writer.add_block_count(num_layers)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
if (rope_theta := self.hparams.get("rope_theta")) is not None:
|
||||
@@ -8639,13 +8750,6 @@ class BailingMoeV2Model(TextModel):
|
||||
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
|
||||
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
|
||||
|
||||
if hparams["score_function"] == "sigmoid":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
elif hparams["score_function"] == "softmax":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
else:
|
||||
raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")
|
||||
|
||||
if (nextn_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
|
||||
self.gguf_writer.add_nextn_predict_layers(nextn_layers)
|
||||
|
||||
@@ -9341,16 +9445,6 @@ class HunYuanModel(TextModel):
|
||||
class SmolLM3Model(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.SMOLLM3
|
||||
|
||||
def set_vocab(self):
|
||||
super().set_vocab()
|
||||
# remove unsupported array slicing in chat template
|
||||
# ref: https://huggingface.co/ggml-org/SmolLM3-3B-GGUF/discussions/1
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
|
||||
if tokenizer.chat_template is not None:
|
||||
chat_template = tokenizer.chat_template.replace("[:]", "")
|
||||
self.gguf_writer.add_chat_template(chat_template)
|
||||
|
||||
|
||||
@ModelBase.register("GptOssForCausalLM")
|
||||
class GptOssModel(TextModel):
|
||||
@@ -9757,12 +9851,22 @@ class ApertusModel(LlamaModel):
|
||||
|
||||
|
||||
class MistralModel(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
model_arch = gguf.MODEL_ARCH.MISTRAL3
|
||||
model_name = "Mistral"
|
||||
hf_arch = ""
|
||||
is_mistral_format = True
|
||||
undo_permute = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# for compatibility, we use LLAMA arch for older models
|
||||
# TODO: remove this once everyone migrates to newer version of llama.cpp
|
||||
if "llama_4_scaling" not in self.hparams:
|
||||
self.model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
|
||||
self.gguf_writer.add_architecture()
|
||||
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||
|
||||
@staticmethod
|
||||
def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool):
|
||||
assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg
|
||||
@@ -9802,6 +9906,20 @@ class MistralModel(LlamaModel):
|
||||
|
||||
return template
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if "yarn" in self.hparams:
|
||||
yarn_params = self.hparams["yarn"]
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
|
||||
|
||||
if "llama_4_scaling" in self.hparams:
|
||||
self.gguf_writer.add_attn_temperature_scale(self.hparams["llama_4_scaling"]["beta"])
|
||||
|
||||
|
||||
class PixtralModel(LlavaVisionModel):
|
||||
model_name = "Pixtral"
|
||||
@@ -10039,6 +10157,25 @@ class LazyTorchTensor(gguf.LazyBase):
|
||||
torch.uint8: np.uint8,
|
||||
}
|
||||
|
||||
# only used when byteswapping data. Only correct size is needed
|
||||
_dtype_byteswap_map: dict[torch.dtype, type] = {
|
||||
torch.float64: np.float64,
|
||||
torch.float32: np.float32,
|
||||
torch.bfloat16: np.float16,
|
||||
torch.float16: np.float16,
|
||||
torch.int64: np.int64,
|
||||
torch.uint64: np.uint64,
|
||||
torch.int32: np.int32,
|
||||
torch.uint32: np.uint32,
|
||||
torch.int16: np.int16,
|
||||
torch.uint16: np.uint16,
|
||||
torch.int8: np.int8,
|
||||
torch.uint8: np.uint8,
|
||||
torch.bool: np.uint8,
|
||||
torch.float8_e4m3fn: np.uint8,
|
||||
torch.float8_e5m2: np.uint8,
|
||||
}
|
||||
|
||||
# used for safetensors slices
|
||||
# ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
|
||||
# TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
|
||||
@@ -10082,8 +10219,14 @@ class LazyTorchTensor(gguf.LazyBase):
|
||||
@classmethod
|
||||
def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor:
|
||||
def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
|
||||
def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray:
|
||||
if sys.byteorder == 'big':
|
||||
# switch data back to big endian
|
||||
tensor = tensor.view(dtype).byteswap(inplace=False)
|
||||
return tensor
|
||||
dtype = cls._dtype_str_map[tensor.dtype]
|
||||
return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)
|
||||
numpy_dtype = cls._dtype_byteswap_map[dtype]
|
||||
return torch.from_numpy(byteswap_tensor(tensor.mmap_bytes(), numpy_dtype)).view(dtype).reshape(tensor.shape)
|
||||
dtype = cls._dtype_str_map[t.dtype]
|
||||
shape = t.shape
|
||||
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
|
||||
@@ -10091,10 +10234,16 @@ class LazyTorchTensor(gguf.LazyBase):
|
||||
|
||||
@classmethod
|
||||
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
|
||||
def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray:
|
||||
if sys.byteorder == 'big':
|
||||
# switch data back to big endian
|
||||
tensor = tensor.view(dtype).byteswap(inplace=False)
|
||||
return tensor
|
||||
dtype = cls._dtype_str_map[remote_tensor.dtype]
|
||||
numpy_dtype = cls._dtype_byteswap_map[dtype]
|
||||
shape = remote_tensor.shape
|
||||
meta = cls.meta_with_dtype_and_shape(dtype, shape)
|
||||
lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape))
|
||||
lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.from_numpy(byteswap_tensor(np.frombuffer(r.data(), dtype=numpy_dtype), numpy_dtype)).view(dtype).reshape(shape))
|
||||
return cast(torch.Tensor, lazy)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -139,6 +139,7 @@ models = [
|
||||
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
|
||||
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
|
||||
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
|
||||
{"name": "afmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/Trinity-Tokenizer", },
|
||||
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
|
||||
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
|
||||
{"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
|
||||
|
||||
@@ -242,7 +242,7 @@ def parse_args() -> argparse.Namespace:
|
||||
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
|
||||
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f32",
|
||||
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -277,10 +277,15 @@ def parse_args() -> argparse.Namespace:
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
|
||||
def load_hparams_from_hf(hf_model_id: str) -> tuple[dict[str, Any], Path | None]:
|
||||
from huggingface_hub import try_to_load_from_cache
|
||||
|
||||
# normally, adapter does not come with base model config, we need to load it from AutoConfig
|
||||
config = AutoConfig.from_pretrained(hf_model_id)
|
||||
return config.to_dict()
|
||||
cache_dir = try_to_load_from_cache(hf_model_id, "config.json")
|
||||
cache_dir = Path(cache_dir).parent if isinstance(cache_dir, str) else None
|
||||
|
||||
return config.to_dict(), cache_dir
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -325,13 +330,13 @@ if __name__ == '__main__':
|
||||
# load base model
|
||||
if base_model_id is not None:
|
||||
logger.info(f"Loading base model from Hugging Face: {base_model_id}")
|
||||
hparams = load_hparams_from_hf(base_model_id)
|
||||
hparams, dir_base_model = load_hparams_from_hf(base_model_id)
|
||||
elif dir_base_model is None:
|
||||
if "base_model_name_or_path" in lparams:
|
||||
model_id = lparams["base_model_name_or_path"]
|
||||
logger.info(f"Loading base model from Hugging Face: {model_id}")
|
||||
try:
|
||||
hparams = load_hparams_from_hf(model_id)
|
||||
hparams, dir_base_model = load_hparams_from_hf(model_id)
|
||||
except OSError as e:
|
||||
logger.error(f"Failed to load base model config: {e}")
|
||||
logger.error("Please try downloading the base model and add its path to --base")
|
||||
@@ -480,6 +485,7 @@ if __name__ == '__main__':
|
||||
dir_lora_model=dir_lora,
|
||||
lora_alpha=alpha,
|
||||
hparams=hparams,
|
||||
remote_hf_model_id=base_model_id,
|
||||
)
|
||||
|
||||
logger.info("Exporting model...")
|
||||
|
||||
@@ -42,6 +42,9 @@ The following releases are verified and recommended:
|
||||
|
||||
## News
|
||||
|
||||
- 2025.11
|
||||
- Support malloc memory on device more than 4GB.
|
||||
|
||||
- 2025.2
|
||||
- Optimize MUL_MAT Q4_0 on Intel GPU for all dGPUs and built-in GPUs since MTL. Increase the performance of LLM (llama-2-7b.Q4_0.gguf) 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC).
|
||||
|GPU|Base tokens/s|Increased tokens/s|Percent|
|
||||
@@ -789,6 +792,8 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
||||
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
|
||||
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
|
||||
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
|
||||
| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Support malloc device memory more than 4GB.|
|
||||
|
||||
|
||||
|
||||
## Known Issues
|
||||
@@ -835,6 +840,14 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
||||
| The default context is too big. It leads to excessive memory usage.|Set `-c 8192` or a smaller value.|
|
||||
| The model is too big and requires more memory than what is available.|Choose a smaller model or change to a smaller quantization, like Q5 -> Q4;<br>Alternatively, use more than one device to load model.|
|
||||
|
||||
- `ggml_backend_sycl_buffer_type_alloc_buffer: can't allocate 5000000000 Bytes of memory on device`
|
||||
|
||||
You need to enable to support 4GB memory malloc by:
|
||||
```
|
||||
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
|
||||
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
|
||||
```
|
||||
|
||||
### **GitHub contribution**:
|
||||
Please add the `SYCL :` prefix/tag in issues/PRs titles to help the SYCL contributors to check/address them without delay.
|
||||
|
||||
|
||||
109
docs/ops.md
109
docs/ops.md
@@ -14,103 +14,108 @@ Legend:
|
||||
|
||||
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | zDNN |
|
||||
|-----------|------|------|------|------|------|------|------|------|------|
|
||||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ADD_ID | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ |
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CONV_2D | ❌ | ❌ | ✅ | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ |
|
||||
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ |
|
||||
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| CONV_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CUMSUM | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
| IM2COL_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | 🟡 | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_MUL_ADD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | 🟡 | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
|
||||
21200
docs/ops/CPU.csv
21200
docs/ops/CPU.csv
File diff suppressed because it is too large
Load Diff
23118
docs/ops/CUDA.csv
23118
docs/ops/CUDA.csv
File diff suppressed because it is too large
Load Diff
7174
docs/ops/SYCL.csv
7174
docs/ops/SYCL.csv
File diff suppressed because it is too large
Load Diff
18908
docs/ops/Vulkan.csv
18908
docs/ops/Vulkan.csv
File diff suppressed because it is too large
Load Diff
@@ -3,7 +3,7 @@
|
||||
The example demonstrates batched generation from a given prompt
|
||||
|
||||
```bash
|
||||
./llama-batched -m ./models/llama-7b-v2/ggml-model-f16.gguf -p "Hello my name is" -np 4
|
||||
./llama-batched -m ./models/llama-7b-v2/ggml-model-f16.gguf -p "Hello my name is" -np 4 --kv-unified
|
||||
|
||||
...
|
||||
|
||||
|
||||
@@ -6,8 +6,54 @@ More Info:
|
||||
- https://github.com/ggml-org/llama.cpp/pull/14644
|
||||
- https://github.com/ggml-org/llama.cpp/pull/14771
|
||||
|
||||
## Parameters
|
||||
The diffusion CLI supports various parameters to control the generation process:
|
||||
|
||||
Example of using Dream architechture: `llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual`
|
||||
### Core Diffusion Parameters
|
||||
- `--diffusion-steps`: Number of diffusion steps (default: 256)
|
||||
- `--diffusion-algorithm`: Algorithm for token selection
|
||||
- `0`: ORIGIN - Token will be generated in a purely random order from https://arxiv.org/abs/2107.03006.
|
||||
- `1`: ENTROPY_BASED - Entropy-based selection
|
||||
- `2`: MARGIN_BASED - Margin-based selection
|
||||
- `3`: RANDOM - Random selection
|
||||
- `4`: CONFIDENCE_BASED - Confidence-based selection (default)
|
||||
- More documentation here https://github.com/DreamLM/Dream
|
||||
- `--diffusion-visual`: Enable live visualization during generation
|
||||
|
||||
Example of using LLaDA architechture: `llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual`
|
||||
### Scheduling Parameters
|
||||
Choose one of the following scheduling methods:
|
||||
|
||||
**Timestep-based scheduling:**
|
||||
- `--diffusion-eps`: Epsilon value for timestep scheduling (e.g., 0.001)
|
||||
|
||||
**Block-based scheduling:**
|
||||
- `--diffusion-block-length`: Block size for block-based scheduling (e.g., 32)
|
||||
|
||||
### Sampling Parameters
|
||||
- `--temp`: Temperature for sampling (0.0 = greedy/deterministic, higher = more random)
|
||||
- `--top-k`: Top-k filtering for sampling
|
||||
- `--top-p`: Top-p (nucleus) filtering for sampling
|
||||
- `--seed`: Random seed for reproducibility
|
||||
|
||||
### Model Parameters
|
||||
- `-m`: Path to the GGUF model file
|
||||
- `-p`: Input prompt text
|
||||
- `-ub`: Maximum sequence length (ubatch size)
|
||||
- `-c`: Context size
|
||||
- `-b`: Batch size
|
||||
|
||||
### Examples
|
||||
#### Dream architechture:
|
||||
```
|
||||
llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual
|
||||
```
|
||||
|
||||
#### LLaDA architechture:
|
||||
```
|
||||
llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual
|
||||
```
|
||||
|
||||
#### RND1 architecture:
|
||||
```
|
||||
llama-diffusion-cli -m RND1-Base-0910.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-algorithm 1 --diffusion-steps 256 --diffusion-visual --temp 0.5 --diffusion-eps 0.001
|
||||
```
|
||||
|
||||
@@ -104,12 +104,16 @@ int main(int argc, char ** argv) {
|
||||
|
||||
params.embedding = true;
|
||||
|
||||
// get max number of sequences per batch
|
||||
const int n_seq_max = llama_max_parallel_sequences();
|
||||
|
||||
// if the number of prompts that would be encoded is known in advance, it's more efficient to specify the
|
||||
// --parallel argument accordingly. for convenience, if not specified, we fallback to unified KV cache
|
||||
// in order to support any number of prompts
|
||||
if (params.n_parallel == 1) {
|
||||
LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__);
|
||||
params.kv_unified = true;
|
||||
params.n_parallel = n_seq_max;
|
||||
}
|
||||
|
||||
// utilize the full context
|
||||
@@ -123,9 +127,6 @@ int main(int argc, char ** argv) {
|
||||
params.n_ubatch = params.n_batch;
|
||||
}
|
||||
|
||||
// get max number of sequences per batch
|
||||
const int n_seq_max = llama_max_parallel_sequences();
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
#include "llama.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
|
||||
/**
|
||||
* This the arbitrary data which will be passed to each callback.
|
||||
@@ -37,23 +37,23 @@ static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
|
||||
return u.f;
|
||||
}
|
||||
|
||||
static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
|
||||
static float ggml_get_float_value(const uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
|
||||
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
|
||||
float v;
|
||||
if (type == GGML_TYPE_F16) {
|
||||
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
|
||||
v = ggml_fp16_to_fp32(*(const ggml_fp16_t *) &data[i]);
|
||||
} else if (type == GGML_TYPE_F32) {
|
||||
v = *(float *) &data[i];
|
||||
v = *(const float *) &data[i];
|
||||
} else if (type == GGML_TYPE_I64) {
|
||||
v = (float) *(int64_t *) &data[i];
|
||||
v = (float) *(const int64_t *) &data[i];
|
||||
} else if (type == GGML_TYPE_I32) {
|
||||
v = (float) *(int32_t *) &data[i];
|
||||
v = (float) *(const int32_t *) &data[i];
|
||||
} else if (type == GGML_TYPE_I16) {
|
||||
v = (float) *(int16_t *) &data[i];
|
||||
v = (float) *(const int16_t *) &data[i];
|
||||
} else if (type == GGML_TYPE_I8) {
|
||||
v = (float) *(int8_t *) &data[i];
|
||||
v = (float) *(const int8_t *) &data[i];
|
||||
} else if (type == GGML_TYPE_BF16) {
|
||||
v = ggml_compute_bf16_to_fp32(*(ggml_bf16_t *) &data[i]);
|
||||
v = ggml_compute_bf16_to_fp32(*(const ggml_bf16_t *) &data[i]);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
@@ -231,9 +231,9 @@ DOT = '[^\\x0A\\x0D]'
|
||||
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])
|
||||
|
||||
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
|
||||
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
|
||||
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]')
|
||||
GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
|
||||
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'}
|
||||
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'}
|
||||
|
||||
NON_LITERAL_SET = set('|.()[]{}*+?')
|
||||
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?')
|
||||
|
||||
@@ -4,6 +4,11 @@ set -e
|
||||
|
||||
# First try command line argument, then environment variable, then file
|
||||
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
||||
MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}"
|
||||
|
||||
if [ -z "$MODEL_TESTING_PROMPT"]; then
|
||||
MODEL_TESTING_PROMPT="Hello, my name is"
|
||||
fi
|
||||
|
||||
# Final check if we have a model path
|
||||
if [ -z "$CONVERTED_MODEL" ]; then
|
||||
@@ -14,7 +19,8 @@ if [ -z "$CONVERTED_MODEL" ]; then
|
||||
fi
|
||||
|
||||
echo $CONVERTED_MODEL
|
||||
echo $MODEL_TESTING_PROMPT
|
||||
|
||||
cmake --build ../../build --target llama-logits -j8
|
||||
|
||||
../../build/bin/llama-logits -m "$CONVERTED_MODEL" "Hello, my name is"
|
||||
../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT"
|
||||
|
||||
@@ -184,8 +184,12 @@ model_name = os.path.basename(model_path)
|
||||
# of using AutoModelForCausalLM.
|
||||
print(f"Model class: {model.__class__.__name__}")
|
||||
|
||||
prompt = "Hello, my name is"
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
device = next(model.parameters()).device
|
||||
if os.getenv("MODEL_TESTING_PROMPT"):
|
||||
prompt = os.getenv("MODEL_TESTING_PROMPT")
|
||||
else:
|
||||
prompt = "Hello, my name is"
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
||||
|
||||
print(f"Input tokens: {input_ids}")
|
||||
print(f"Input text: {repr(prompt)}")
|
||||
|
||||
@@ -15,6 +15,9 @@ MODEL_FILE=models/llama-2-7b.Q4_0.gguf
|
||||
NGL=99
|
||||
CONTEXT=4096
|
||||
|
||||
#support malloc device memory more than 4GB.
|
||||
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
|
||||
|
||||
if [ $# -gt 0 ]; then
|
||||
GGML_SYCL_DEVICE=$1
|
||||
echo "use $GGML_SYCL_DEVICE as main GPU"
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
# If you want more control, DPC++ Allows selecting a specific device through the
|
||||
# following environment variable
|
||||
#export ONEAPI_DEVICE_SELECTOR="level_zero:0"
|
||||
export ONEAPI_DEVICE_SELECTOR="level_zero:0"
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
|
||||
#export GGML_SYCL_DEBUG=1
|
||||
@@ -18,11 +18,14 @@ MODEL_FILE=models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf
|
||||
NGL=99 # Layers offloaded to the GPU. If the device runs out of memory, reduce this value according to the model you are using.
|
||||
CONTEXT=4096
|
||||
|
||||
#support malloc device memory more than 4GB.
|
||||
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
|
||||
|
||||
if [ $# -gt 0 ]; then
|
||||
GGML_SYCL_DEVICE=$1
|
||||
echo "Using $GGML_SYCL_DEVICE as the main GPU"
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
|
||||
else
|
||||
#use multiple GPUs with same max compute units
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT}
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT}
|
||||
fi
|
||||
|
||||
@@ -5,5 +5,7 @@
|
||||
set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
|
||||
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
|
||||
|
||||
:: support malloc device memory more than 4GB.
|
||||
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
|
||||
|
||||
.\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 99 -s 0
|
||||
|
||||
@@ -5,5 +5,7 @@
|
||||
set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
|
||||
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
|
||||
|
||||
:: support malloc device memory more than 4GB.
|
||||
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
|
||||
|
||||
.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -e -ngl 99
|
||||
.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -s 0 -e -ngl 99
|
||||
|
||||
@@ -25,16 +25,17 @@ if(GIT_EXE)
|
||||
)
|
||||
endif()
|
||||
|
||||
# Build the version string with optional dirty flag
|
||||
set(GGML_VERSION "${GGML_VERSION_BASE}")
|
||||
if(GGML_GIT_DIRTY AND NOT GGML_GIT_DIRTY EQUAL 0)
|
||||
set(GGML_VERSION "${GGML_VERSION}-dirty")
|
||||
endif()
|
||||
|
||||
if(NOT GGML_BUILD_COMMIT)
|
||||
set(GGML_BUILD_COMMIT "unknown")
|
||||
endif()
|
||||
|
||||
# Build the commit string with optional dirty flag
|
||||
if(DEFINED GGML_GIT_DIRTY AND GGML_GIT_DIRTY EQUAL 1)
|
||||
set(GGML_BUILD_COMMIT "${GGML_BUILD_COMMIT}-dirty")
|
||||
endif()
|
||||
|
||||
include(CheckIncludeFileCXX)
|
||||
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
@@ -182,6 +183,7 @@ endif()
|
||||
# ggml core
|
||||
set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
|
||||
option(GGML_CPU "ggml: enable CPU backend" ON)
|
||||
option(GGML_SCHED_NO_REALLOC "ggml: disallow reallocations in ggml-alloc (for debugging)" OFF)
|
||||
|
||||
# 3rd party libs / backends
|
||||
option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON)
|
||||
|
||||
@@ -8,7 +8,7 @@ extern "C" {
|
||||
#endif
|
||||
|
||||
#define RPC_PROTO_MAJOR_VERSION 3
|
||||
#define RPC_PROTO_MINOR_VERSION 0
|
||||
#define RPC_PROTO_MINOR_VERSION 5
|
||||
#define RPC_PROTO_PATCH_VERSION 0
|
||||
#define GGML_RPC_MAX_SERVERS 16
|
||||
|
||||
|
||||
@@ -475,6 +475,7 @@ extern "C" {
|
||||
GGML_OP_COS,
|
||||
GGML_OP_SUM,
|
||||
GGML_OP_SUM_ROWS,
|
||||
GGML_OP_CUMSUM,
|
||||
GGML_OP_MEAN,
|
||||
GGML_OP_ARGMAX,
|
||||
GGML_OP_COUNT_EQUAL,
|
||||
@@ -529,7 +530,10 @@ extern "C" {
|
||||
GGML_OP_ARANGE,
|
||||
GGML_OP_TIMESTEP_EMBEDDING,
|
||||
GGML_OP_ARGSORT,
|
||||
GGML_OP_TOP_K,
|
||||
GGML_OP_LEAKY_RELU,
|
||||
GGML_OP_TRI,
|
||||
GGML_OP_FILL,
|
||||
|
||||
GGML_OP_FLASH_ATTN_EXT,
|
||||
GGML_OP_FLASH_ATTN_BACK,
|
||||
@@ -542,6 +546,7 @@ extern "C" {
|
||||
GGML_OP_RWKV_WKV6,
|
||||
GGML_OP_GATED_LINEAR_ATTN,
|
||||
GGML_OP_RWKV_WKV7,
|
||||
GGML_OP_SOLVE_TRI,
|
||||
|
||||
GGML_OP_UNARY,
|
||||
|
||||
@@ -576,6 +581,8 @@ extern "C" {
|
||||
GGML_UNARY_OP_HARDSWISH,
|
||||
GGML_UNARY_OP_HARDSIGMOID,
|
||||
GGML_UNARY_OP_EXP,
|
||||
GGML_UNARY_OP_EXPM1,
|
||||
GGML_UNARY_OP_SOFTPLUS,
|
||||
GGML_UNARY_OP_GELU_ERF,
|
||||
GGML_UNARY_OP_XIELU,
|
||||
GGML_UNARY_OP_FLOOR,
|
||||
@@ -620,6 +627,13 @@ extern "C" {
|
||||
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
|
||||
};
|
||||
|
||||
enum ggml_tri_type {
|
||||
GGML_TRI_TYPE_UPPER_DIAG = 0,
|
||||
GGML_TRI_TYPE_UPPER = 1,
|
||||
GGML_TRI_TYPE_LOWER_DIAG = 2,
|
||||
GGML_TRI_TYPE_LOWER = 3
|
||||
};
|
||||
|
||||
struct ggml_init_params {
|
||||
// memory pool
|
||||
size_t mem_size; // bytes
|
||||
@@ -957,6 +971,22 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_expm1(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_expm1_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_softplus(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_softplus_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_sin(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
@@ -983,6 +1013,10 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_cumsum(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// mean along rows
|
||||
GGML_API struct ggml_tensor * ggml_mean(
|
||||
struct ggml_context * ctx,
|
||||
@@ -2114,7 +2148,8 @@ extern "C" {
|
||||
};
|
||||
|
||||
enum ggml_scale_flag {
|
||||
GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8)
|
||||
GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8),
|
||||
GGML_SCALE_FLAG_ANTIALIAS = (1 << 9),
|
||||
};
|
||||
|
||||
// interpolate
|
||||
@@ -2187,6 +2222,23 @@ extern "C" {
|
||||
int shift2,
|
||||
int shift3);
|
||||
|
||||
// Convert matrix into a triangular one (upper, strict upper, lower or strict lower) by writing
|
||||
// zeroes everywhere outside the masked area
|
||||
GGML_API struct ggml_tensor * ggml_tri(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_tri_type type);
|
||||
|
||||
// Fill tensor a with constant c
|
||||
GGML_API struct ggml_tensor * ggml_fill(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float c);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_fill_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float c);
|
||||
|
||||
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
|
||||
// timesteps: [N,]
|
||||
@@ -2208,18 +2260,25 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_sort_order order);
|
||||
|
||||
// similar to ggml_top_k but implemented as `argsort` + `view`
|
||||
GGML_API struct ggml_tensor * ggml_argsort_top_k(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int k);
|
||||
|
||||
// top k elements per row
|
||||
// note: the resulting top k indices are in no particular order
|
||||
GGML_API struct ggml_tensor * ggml_top_k(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int k);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_arange(
|
||||
struct ggml_context * ctx,
|
||||
float start,
|
||||
float stop,
|
||||
float step);
|
||||
|
||||
// top k elements per row
|
||||
GGML_API struct ggml_tensor * ggml_top_k(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int k);
|
||||
|
||||
#define GGML_KQ_MASK_PAD 64
|
||||
|
||||
// q: [n_embd_k, n_batch, n_head, ne3 ]
|
||||
@@ -2356,6 +2415,27 @@ extern "C" {
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * state);
|
||||
|
||||
/* Solves a specific equation of the form Ax=B, where A is a triangular matrix
|
||||
* without zeroes on the diagonal (i.e. invertible).
|
||||
* B can have any number of columns, but must have the same number of rows as A
|
||||
* If A is [n, n] and B is [n, m], then the result will be [n, m] as well
|
||||
* Has O(n^3) complexity (unlike most matrix ops out there), so use on cases
|
||||
* where n > 100 sparingly, pre-chunk if necessary.
|
||||
*
|
||||
* If left = false, solves xA=B instead
|
||||
* If lower = false, assumes upper triangular instead
|
||||
* If uni = true, assumes diagonal of A to be all ones (will override actual values)
|
||||
*
|
||||
* TODO: currently only lower, right, non-unitriangular variant is implemented
|
||||
*/
|
||||
GGML_API struct ggml_tensor * ggml_solve_tri(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
bool left,
|
||||
bool lower,
|
||||
bool uni);
|
||||
|
||||
// custom operators
|
||||
|
||||
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
|
||||
|
||||
@@ -221,6 +221,10 @@ if (GGML_BACKEND_DL)
|
||||
target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL)
|
||||
endif()
|
||||
|
||||
if (GGML_SCHED_NO_REALLOC)
|
||||
target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC)
|
||||
endif()
|
||||
|
||||
add_library(ggml
|
||||
ggml-backend-reg.cpp)
|
||||
add_library(ggml::ggml ALIAS ggml)
|
||||
@@ -270,10 +274,13 @@ function(ggml_add_backend_library backend)
|
||||
endif()
|
||||
|
||||
# Set versioning properties for all backend libraries
|
||||
set_target_properties(${backend} PROPERTIES
|
||||
VERSION ${GGML_VERSION}
|
||||
SOVERSION ${GGML_VERSION_MAJOR}
|
||||
)
|
||||
# Building a MODULE library with a version is not supported on macOS (https://gitlab.kitware.com/cmake/cmake/-/issues/20782)
|
||||
if (NOT (APPLE AND GGML_BACKEND_DL))
|
||||
set_target_properties(${backend} PROPERTIES
|
||||
VERSION ${GGML_VERSION}
|
||||
SOVERSION ${GGML_VERSION_MAJOR}
|
||||
)
|
||||
endif()
|
||||
|
||||
if(NOT GGML_AVAILABLE_BACKENDS)
|
||||
set(GGML_AVAILABLE_BACKENDS "${backend}"
|
||||
@@ -328,6 +335,14 @@ function(ggml_add_cpu_backend_variant tag_name)
|
||||
set(GGML_INTERNAL_${feat} OFF)
|
||||
endforeach()
|
||||
|
||||
foreach (feat ${ARGN})
|
||||
set(GGML_INTERNAL_${feat} ON)
|
||||
endforeach()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64")
|
||||
foreach (feat RVV)
|
||||
set(GGML_INTERNAL_${feat} OFF)
|
||||
endforeach()
|
||||
|
||||
foreach (feat ${ARGN})
|
||||
set(GGML_INTERNAL_${feat} ON)
|
||||
endforeach()
|
||||
@@ -402,6 +417,13 @@ if (GGML_CPU_ALL_VARIANTS)
|
||||
else()
|
||||
message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64")
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
ggml_add_cpu_backend_variant(riscv64_0)
|
||||
ggml_add_cpu_backend_variant(riscv64_v RVV)
|
||||
else()
|
||||
message(FATAL_ERROR "Unsupported RISC-V target OS: ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
else()
|
||||
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
|
||||
@@ -921,10 +921,15 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
|
||||
}
|
||||
if (realloc) {
|
||||
#ifndef NDEBUG
|
||||
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
|
||||
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
|
||||
{
|
||||
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
|
||||
if (cur_size > 0) {
|
||||
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n",
|
||||
__func__, ggml_backend_buft_name(galloc->bufts[i]),
|
||||
cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
ggml_vbuffer_free(galloc->buffers[i]);
|
||||
galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
|
||||
if (galloc->buffers[i] == NULL) {
|
||||
|
||||
@@ -723,6 +723,12 @@ struct ggml_backend_sched {
|
||||
bool op_offload;
|
||||
|
||||
int debug;
|
||||
|
||||
// used for debugging graph reallocations [GGML_SCHED_DEBUG_REALLOC]
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/17617
|
||||
int debug_realloc;
|
||||
int debug_graph_size;
|
||||
int debug_prev_graph_size;
|
||||
};
|
||||
|
||||
#define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor)
|
||||
@@ -1289,6 +1295,11 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra
|
||||
}
|
||||
|
||||
int graph_size = std::max(graph->n_nodes, graph->n_leafs) + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sched->n_copies;
|
||||
|
||||
// remember the actual graph_size for performing reallocation checks later [GGML_SCHED_DEBUG_REALLOC]
|
||||
sched->debug_prev_graph_size = sched->debug_graph_size;
|
||||
sched->debug_graph_size = graph_size;
|
||||
|
||||
if (sched->graph.size < graph_size) {
|
||||
sched->graph.size = graph_size;
|
||||
sched->graph.nodes = (ggml_tensor **) realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *));
|
||||
@@ -1395,14 +1406,27 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
|
||||
|
||||
// allocate graph
|
||||
if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
|
||||
#endif
|
||||
|
||||
if (sched->debug_realloc > 0) {
|
||||
// we are interested only in situations where the graph was reallocated even though its size remained the same [GGML_SCHED_DEBUG_REALLOC]
|
||||
// example: https://github.com/ggml-org/llama.cpp/pull/17143
|
||||
const bool unexpected = !backend_ids_changed && sched->debug_prev_graph_size == sched->debug_graph_size;
|
||||
|
||||
if (unexpected || sched->debug_realloc > 1) {
|
||||
GGML_ABORT("%s: unexpected graph reallocation (graph size = %d, nodes = %d, leafs = %d), debug_realloc = %d\n", __func__,
|
||||
sched->debug_graph_size, sched->graph.n_nodes, sched->graph.n_leafs, sched->debug_realloc);
|
||||
}
|
||||
}
|
||||
|
||||
// the re-allocation may cause the split inputs to be moved to a different address
|
||||
// synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy
|
||||
for (int i = 0; i < sched->n_backends; i++) {
|
||||
ggml_backend_synchronize(sched->backends[i]);
|
||||
}
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
|
||||
#endif
|
||||
|
||||
ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids);
|
||||
if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate graph\n", __func__);
|
||||
@@ -1614,6 +1638,14 @@ ggml_backend_sched_t ggml_backend_sched_new(
|
||||
|
||||
const char * GGML_SCHED_DEBUG = getenv("GGML_SCHED_DEBUG");
|
||||
sched->debug = GGML_SCHED_DEBUG ? atoi(GGML_SCHED_DEBUG) : 0;
|
||||
|
||||
sched->debug_realloc = 0;
|
||||
#ifdef GGML_SCHED_NO_REALLOC
|
||||
sched->debug_realloc = 1;
|
||||
#endif
|
||||
const char * GGML_SCHED_DEBUG_REALLOC = getenv("GGML_SCHED_DEBUG_REALLOC");
|
||||
sched->debug_realloc = GGML_SCHED_DEBUG_REALLOC ? atoi(GGML_SCHED_DEBUG_REALLOC) : sched->debug_realloc;
|
||||
|
||||
sched->n_backends = n_backends;
|
||||
sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1;
|
||||
|
||||
@@ -1630,6 +1662,9 @@ ggml_backend_sched_t ggml_backend_sched_new(
|
||||
sched->prev_node_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0]));
|
||||
sched->prev_leaf_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0]));
|
||||
|
||||
sched->debug_graph_size = 0;
|
||||
sched->debug_prev_graph_size = 0;
|
||||
|
||||
sched->context_buffer_size = ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false);
|
||||
sched->context_buffer = (char *) malloc(sched->context_buffer_size);
|
||||
|
||||
@@ -1698,8 +1733,6 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
|
||||
GGML_ASSERT(sched);
|
||||
GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
|
||||
|
||||
ggml_backend_sched_reset(sched);
|
||||
|
||||
ggml_backend_sched_synchronize(sched);
|
||||
|
||||
ggml_backend_sched_split_graph(sched, measure_graph);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -48,15 +48,14 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
|
||||
default:
|
||||
return ACL_DT_UNDEFINED;
|
||||
}
|
||||
return ACL_DT_UNDEFINED;
|
||||
}
|
||||
|
||||
aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne,
|
||||
size_t * nb,
|
||||
int64_t dims,
|
||||
aclFormat format,
|
||||
size_t offset) {
|
||||
acl_tensor_ptr ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne,
|
||||
size_t * nb,
|
||||
int64_t dims,
|
||||
aclFormat format,
|
||||
size_t offset) {
|
||||
// If tensor is bcasted, Up to GGML_MAX_DIMS additional dimensions will be
|
||||
// added.
|
||||
int64_t acl_ne[GGML_MAX_DIMS * 2], acl_stride[GGML_MAX_DIMS * 2];
|
||||
@@ -87,10 +86,20 @@ aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
std::reverse(acl_ne, acl_ne + final_dims);
|
||||
std::reverse(acl_stride, acl_stride + final_dims);
|
||||
|
||||
aclTensor * acl_tensor = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
|
||||
elem_offset, format, &acl_storage_len, 1, tensor->data);
|
||||
aclTensor * raw = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride, elem_offset,
|
||||
format, &acl_storage_len, 1, tensor->data);
|
||||
|
||||
return acl_tensor;
|
||||
return acl_tensor_ptr(raw);
|
||||
}
|
||||
|
||||
acl_int_array_ptr ggml_cann_create_int_array(const int64_t * value, uint64_t size) {
|
||||
aclIntArray * raw = aclCreateIntArray(value, size);
|
||||
return acl_int_array_ptr(raw);
|
||||
}
|
||||
|
||||
acl_scalar_ptr ggml_cann_create_scalar(void * value, aclDataType dataType) {
|
||||
aclScalar * raw = aclCreateScalar(value, dataType);
|
||||
return acl_scalar_ptr(raw);
|
||||
}
|
||||
|
||||
bool ggml_cann_need_bcast(const ggml_tensor * t0, const ggml_tensor * t1) {
|
||||
|
||||
@@ -23,11 +23,12 @@
|
||||
#ifndef CANN_ACL_TENSOR_H
|
||||
#define CANN_ACL_TENSOR_H
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include "common.h"
|
||||
|
||||
#include <aclnn/aclnn_base.h>
|
||||
#include "common.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
||||
/**
|
||||
* @brief Maps a ggml_type to its corresponding aclDataType.
|
||||
@@ -43,6 +44,20 @@
|
||||
*/
|
||||
aclDataType ggml_cann_type_mapping(ggml_type type);
|
||||
|
||||
// Deleter for acl objects.
|
||||
template <typename T, aclError (*DestroyFunc)(const T *)> struct acl_deleter {
|
||||
void operator()(T * ptr) const noexcept {
|
||||
if (ptr) {
|
||||
ACL_CHECK(DestroyFunc(ptr));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using acl_tensor_ptr = std::unique_ptr<aclTensor, acl_deleter<aclTensor, aclDestroyTensor>>;
|
||||
using acl_int_array_ptr = std::unique_ptr<aclIntArray, acl_deleter<aclIntArray, aclDestroyIntArray>>;
|
||||
using acl_scalar_ptr = std::unique_ptr<aclScalar, acl_deleter<aclScalar, aclDestroyScalar>>;
|
||||
using acl_tensor_list_ptr = std::unique_ptr<aclTensorList, acl_deleter<aclTensorList, aclDestroyTensorList>>;
|
||||
|
||||
/**
|
||||
* @brief Creates an ACL tensor from a ggml_tensor with optional shape.
|
||||
*
|
||||
@@ -62,12 +77,12 @@ aclDataType ggml_cann_type_mapping(ggml_type type);
|
||||
* @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
|
||||
* @return Pointer to the created ACL tensor.
|
||||
*/
|
||||
aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne = nullptr,
|
||||
size_t * nb = nullptr,
|
||||
int64_t dims = 0,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0);
|
||||
acl_tensor_ptr ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne = nullptr,
|
||||
size_t * nb = nullptr,
|
||||
int64_t dims = 0,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0);
|
||||
|
||||
/**
|
||||
* @brief Template for creating an ACL tensor from provided parameters. typename TYPE
|
||||
@@ -90,14 +105,14 @@ aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
* @return Pointer to the created ACL tensor.
|
||||
*/
|
||||
template <typename TYPE>
|
||||
aclTensor * ggml_cann_create_tensor(void * data_ptr,
|
||||
aclDataType dtype,
|
||||
TYPE type_size,
|
||||
int64_t * ne,
|
||||
TYPE * nb,
|
||||
int64_t dims,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0) {
|
||||
acl_tensor_ptr ggml_cann_create_tensor(void * data_ptr,
|
||||
aclDataType dtype,
|
||||
TYPE type_size,
|
||||
int64_t * ne,
|
||||
TYPE * nb,
|
||||
int64_t dims,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0) {
|
||||
int64_t tmp_ne[GGML_MAX_DIMS * 2];
|
||||
int64_t tmp_stride[GGML_MAX_DIMS * 2];
|
||||
|
||||
@@ -114,10 +129,75 @@ aclTensor * ggml_cann_create_tensor(void * data_ptr,
|
||||
std::reverse(tmp_ne, tmp_ne + dims);
|
||||
std::reverse(tmp_stride, tmp_stride + dims);
|
||||
|
||||
aclTensor * acl_tensor =
|
||||
aclTensor * raw =
|
||||
aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size, format, &acl_storage_len, 1, data_ptr);
|
||||
|
||||
return acl_tensor;
|
||||
return acl_tensor_ptr(raw);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create an ACL int array resource wrapped in a smart pointer.
|
||||
*
|
||||
* This function constructs an aclIntArray from the provided int64_t values
|
||||
* and returns it as an acl_int_array_ptr (a std::unique_ptr with a custom
|
||||
* deleter). The returned pointer owns the ACL resource and will automatically
|
||||
* destroy it via aclDestroyIntArray().
|
||||
*
|
||||
* @param value Pointer to the int64_t elements.
|
||||
* @param size Number of elements in value.
|
||||
*
|
||||
* @return A smart pointer managing the created ACL int array.
|
||||
*/
|
||||
acl_int_array_ptr ggml_cann_create_int_array(const int64_t * value, uint64_t size);
|
||||
|
||||
/**
|
||||
* @brief Create an ACL scalar resource wrapped in a smart pointer.
|
||||
*
|
||||
* This function constructs an aclScalar from the raw value pointer and ACL
|
||||
* data type, then returns it as an acl_scalar_ptr (a std::unique_ptr with
|
||||
* a custom deleter). The returned pointer owns the ACL scalar and will
|
||||
* automatically destroy it via aclDestroyScalar().
|
||||
*
|
||||
* @param value Pointer to the raw scalar memory.
|
||||
* @param dataType ACL data type of the scalar.
|
||||
*
|
||||
* @return A smart pointer managing the created ACL scalar.
|
||||
*/
|
||||
acl_scalar_ptr ggml_cann_create_scalar(void * value, aclDataType dataType);
|
||||
|
||||
/**
|
||||
* @brief Create an ACL tensor list from multiple tensor smart pointers.
|
||||
*
|
||||
* This function accepts a variadic list of acl_tensor_ptr (a unique_ptr with
|
||||
* custom deleter) and produces an aclTensorList using aclCreateTensorList().
|
||||
*
|
||||
* The lifecycle management of the tensor objects changes as follows:
|
||||
* - aclCreateTensorList() takes ownership of the tensors
|
||||
* - Each input smart pointer releases ownership using release()
|
||||
* - As a result, the tensors will NOT be destroyed by unique_ptr
|
||||
* - Instead, they will be destroyed when aclDestroyTensorList() is called
|
||||
*
|
||||
* This ensures correct ownership transfer and prevents double-free situations.
|
||||
*
|
||||
* @param acl_tensor_ptr Variadic template parameter; each argument must be
|
||||
* a unique_ptr-like type supporting get() and release().
|
||||
*
|
||||
* @param tensors Variadic list of acl_tensor_ptr objects. Ownership of
|
||||
* each tensor is transferred away from these smart pointers.
|
||||
*
|
||||
* @return A smart pointer (acl_tensor_list_ptr) owning the created ACL tensor list.
|
||||
*
|
||||
* @note This implementation is C++11 compatible. The ownership-release process is
|
||||
* executed using a pack expansion inside an initializer list.
|
||||
*/
|
||||
template <typename... acl_tensor_ptr> acl_tensor_list_ptr ggml_cann_create_tensor_list(acl_tensor_ptr &&... tensors) {
|
||||
aclTensor * raw_tensors[] = { tensors.get()... };
|
||||
aclTensorList * raw = aclCreateTensorList(raw_tensors, sizeof...(tensors));
|
||||
// aclTensor will release by aclTensorList, so release ownership without
|
||||
// destroying the tensor
|
||||
int dummy[] = { (tensors.release(), 0)... };
|
||||
GGML_UNUSED(dummy);
|
||||
return acl_tensor_list_ptr(raw);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,32 +23,35 @@
|
||||
#ifndef CANN_ACLNN_OPS
|
||||
#define CANN_ACLNN_OPS
|
||||
|
||||
#include <unordered_set>
|
||||
#include <functional>
|
||||
#include "acl_tensor.h"
|
||||
#include "common.h"
|
||||
|
||||
#include <aclnnop/aclnn_abs.h>
|
||||
#include <aclnnop/aclnn_neg.h>
|
||||
#include <aclnnop/aclnn_exp.h>
|
||||
#include <aclnnop/aclnn_arange.h>
|
||||
#include <aclnnop/aclnn_argsort.h>
|
||||
#include <aclnnop/aclnn_cat.h>
|
||||
#include <aclnnop/aclnn_clamp.h>
|
||||
#include <aclnnop/aclnn_cos.h>
|
||||
#include <aclnnop/aclnn_exp.h>
|
||||
#include <aclnnop/aclnn_gelu.h>
|
||||
#include <aclnnop/aclnn_gelu_v2.h>
|
||||
#include <aclnnop/aclnn_sigmoid.h>
|
||||
#include <aclnnop/aclnn_hardsigmoid.h>
|
||||
#include <aclnnop/aclnn_hardswish.h>
|
||||
#include <aclnnop/aclnn_leaky_relu.h>
|
||||
#include <aclnnop/aclnn_relu.h>
|
||||
#include <aclnnop/aclnn_silu.h>
|
||||
#include <aclnnop/aclnn_tanh.h>
|
||||
#include <aclnnop/aclnn_sqrt.h>
|
||||
#include <aclnnop/aclnn_sin.h>
|
||||
#include <aclnnop/aclnn_cos.h>
|
||||
#include <aclnnop/aclnn_log.h>
|
||||
#include <aclnnop/aclnn_sign.h>
|
||||
#include <aclnnop/aclnn_logsoftmax.h>
|
||||
#include <aclnnop/aclnn_neg.h>
|
||||
#include <aclnnop/aclnn_norm.h>
|
||||
#include "acl_tensor.h"
|
||||
#include "common.h"
|
||||
#include <aclnnop/aclnn_relu.h>
|
||||
#include <aclnnop/aclnn_sigmoid.h>
|
||||
#include <aclnnop/aclnn_sign.h>
|
||||
#include <aclnnop/aclnn_silu.h>
|
||||
#include <aclnnop/aclnn_sin.h>
|
||||
#include <aclnnop/aclnn_sqrt.h>
|
||||
#include <aclnnop/aclnn_tanh.h>
|
||||
|
||||
#include <functional>
|
||||
#include <unordered_set>
|
||||
|
||||
/**
|
||||
* @brief Repeats a ggml tensor along each dimension to match the dimensions
|
||||
@@ -211,6 +214,43 @@ void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
*/
|
||||
void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the Cross Entropy Loss for a ggml tensor using the CANN
|
||||
* backend.
|
||||
*
|
||||
* @details This function computes the cross entropy loss between the predicted
|
||||
* logits and target probability distributions. The operation follows
|
||||
* the same computation pattern as the CPU implementation:
|
||||
* 1. Applies log_softmax to the logits along the class dimension
|
||||
* 2. Element-wise multiplication with target distributions
|
||||
* 3. Summation along the class dimension to get per-sample losses
|
||||
* 4. Global summation and scaling by -1/nr to get final loss
|
||||
*
|
||||
* The computation can be expressed as:
|
||||
* \f[
|
||||
* \text{loss} = -\frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{C} y_{ij} \cdot \log(\text{softmax}(x_{ij}))
|
||||
* \f]
|
||||
* where \f$N\f$ is the total number of samples, \f$C\f$ is the number
|
||||
* of classes, \f$x\f$ are the logits, and \f$y\f$ are the target
|
||||
* probability distributions.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor where the computed loss will be stored.
|
||||
* This should be a scalar tensor containing the final loss value.
|
||||
*
|
||||
* @note This implementation computes cross entropy between probability
|
||||
* distributions, not the typical classification cross entropy that
|
||||
* expects class indices as targets. Both input tensors (src0 and src1)
|
||||
* should have the same shape and represent probability distributions
|
||||
* over the class dimension.
|
||||
* @note The function expects two source tensors:
|
||||
* - dst->src[0]: Logits tensor (before softmax)
|
||||
* - dst->src[1]: Target probability distributions tensor
|
||||
* @note The computation is performed using CANN backend operators including
|
||||
* LogSoftmax, Mul, ReduceSum, and Muls for the final scaling.
|
||||
*/
|
||||
void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the Group Normalization for a ggml tensor using the CANN
|
||||
* backend.
|
||||
@@ -650,12 +690,12 @@ void aclnn_sin(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor *
|
||||
* @param acl_src1 Output pointer to the created ACL tensor corresponding to src1.
|
||||
* @param acl_dst Output pointer to the created ACL tensor corresponding to dst.
|
||||
*/
|
||||
void bcast_shape(ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst,
|
||||
aclTensor ** acl_src0,
|
||||
aclTensor ** acl_src1,
|
||||
aclTensor ** acl_dst);
|
||||
void bcast_shape(ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst,
|
||||
acl_tensor_ptr & acl_src0,
|
||||
acl_tensor_ptr & acl_src1,
|
||||
acl_tensor_ptr & acl_dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the 1D transposed convolution (deconvolution) of a ggml
|
||||
@@ -835,83 +875,6 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_reso
|
||||
(vec.emplace_back(make_acl_resource(args)), ...);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Task class that wraps the execution of an aclnn function call.
|
||||
*/
|
||||
class aclnn_task : public cann_task {
|
||||
public:
|
||||
aclnn_task(aclnn_func_t aclnn_func,
|
||||
void * workspace_addr,
|
||||
uint64_t workspace_size,
|
||||
aclOpExecutor * executor,
|
||||
aclrtStream stream) :
|
||||
aclnn_func_(aclnn_func),
|
||||
workspace_addr_(workspace_addr),
|
||||
workspace_size_(workspace_size),
|
||||
executor_(executor),
|
||||
stream_(stream) {}
|
||||
|
||||
virtual void run_task() override { ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_)); }
|
||||
private:
|
||||
aclnn_func_t aclnn_func_;
|
||||
void * workspace_addr_;
|
||||
uint64_t workspace_size_;
|
||||
aclOpExecutor * executor_;
|
||||
aclrtStream stream_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Task class that releases ACL resources after usage.
|
||||
*/
|
||||
class release_resource_task : public cann_task {
|
||||
public:
|
||||
release_resource_task(std::vector<any_acl_resource> && resources) { resource_ = std::move(resources); }
|
||||
|
||||
virtual void run_task() override { resource_.clear(); }
|
||||
private:
|
||||
std::vector<any_acl_resource> resource_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Task class for performing asynchronous memory copy operations.
|
||||
*/
|
||||
class async_memcpy_task : public cann_task {
|
||||
public:
|
||||
async_memcpy_task(void * dst, const void * src, size_t size, aclrtMemcpyKind kind, aclrtStream stream) :
|
||||
dst_(dst),
|
||||
src_(src),
|
||||
size_(size),
|
||||
kind_(kind),
|
||||
stream_(stream) {}
|
||||
|
||||
virtual void run_task() override { ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_)); }
|
||||
private:
|
||||
void * dst_;
|
||||
const void * src_;
|
||||
size_t size_;
|
||||
aclrtMemcpyKind kind_;
|
||||
aclrtStream stream_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Task class for performing asynchronous memory set operations.
|
||||
*/
|
||||
class async_memset_task : public cann_task {
|
||||
public:
|
||||
async_memset_task(void * buffer, size_t size, int32_t value, aclrtStream stream) :
|
||||
buffer_(buffer),
|
||||
size_(size),
|
||||
value_(value),
|
||||
stream_(stream) {}
|
||||
|
||||
virtual void run_task() override { ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_)); }
|
||||
private:
|
||||
void * buffer_;
|
||||
size_t size_;
|
||||
int32_t value_;
|
||||
aclrtStream stream_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Launches an asynchronous task using the memory allocator.
|
||||
*
|
||||
@@ -930,95 +893,20 @@ class async_memset_task : public cann_task {
|
||||
* same stream are executed in queue order.
|
||||
*/
|
||||
|
||||
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
|
||||
do { \
|
||||
uint64_t workspaceSize = 0; \
|
||||
aclOpExecutor * executor; \
|
||||
void * workspaceAddr = nullptr; \
|
||||
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
|
||||
/* workspace should alloced in main thread to keep malloc order when using vmm. */ \
|
||||
if (workspaceSize > 0) { \
|
||||
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
|
||||
workspaceAddr = workspace_allocator.get(); \
|
||||
} \
|
||||
if (CTX.async_mode) { \
|
||||
auto task = \
|
||||
std::make_unique<aclnn_task>(aclnn##OP_NAME, workspaceAddr, workspaceSize, executor, CTX.stream()); \
|
||||
CTX.task_queue.submit_task(std::move(task)); \
|
||||
} else { \
|
||||
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
|
||||
} \
|
||||
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
|
||||
do { \
|
||||
uint64_t workspaceSize = 0; \
|
||||
aclOpExecutor * executor; \
|
||||
void * workspaceAddr = nullptr; \
|
||||
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
|
||||
/* workspace should alloced in main thread to keep malloc order when using vmm. */ \
|
||||
if (workspaceSize > 0) { \
|
||||
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
|
||||
workspaceAddr = workspace_allocator.get(); \
|
||||
} \
|
||||
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
|
||||
} while (0)
|
||||
|
||||
/**
|
||||
* @brief Registers and releases multiple ACL resources, optionally deferring the release
|
||||
* using a task.
|
||||
*
|
||||
* @tparam Args Types of the ACL resources.
|
||||
* @param ctx Backend context which manages task submission and async mode.
|
||||
* @param args Pointers to ACL resources to be released.
|
||||
*/
|
||||
template <typename... Args> void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
|
||||
std::vector<any_acl_resource> resources;
|
||||
register_acl_resources(resources, std::forward<Args>(args)...);
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<release_resource_task>(std::move(resources));
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs an asynchronous memory copy operation, optionally deferred via task submission.
|
||||
*
|
||||
* @param ctx Backend context containing stream and async configuration.
|
||||
* @param dst Destination memory address.
|
||||
* @param src Source memory address.
|
||||
* @param len Size of memory to copy (in bytes).
|
||||
* @param kind Type of memory copy (host-to-device, device-to-host, etc).
|
||||
*/
|
||||
inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx,
|
||||
void * dst,
|
||||
const void * src,
|
||||
size_t len,
|
||||
aclrtMemcpyKind kind) {
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx.stream());
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
} else {
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx.stream()));
|
||||
}
|
||||
}
|
||||
|
||||
inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx,
|
||||
void * dst,
|
||||
const void * src,
|
||||
size_t len,
|
||||
aclrtMemcpyKind kind) {
|
||||
if (ctx->async_mode) {
|
||||
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx->stream());
|
||||
ctx->task_queue.submit_task(std::move(task));
|
||||
} else {
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx->stream()));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs an asynchronous memory set operation, optionally deferred via task submission.
|
||||
*
|
||||
* @param ctx Backend context containing stream and async configuration.
|
||||
* @param buffer Memory buffer to be set.
|
||||
* @param size Size of the memory buffer (in bytes).
|
||||
* @param value Value to set in the buffer.
|
||||
*/
|
||||
inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer, size_t size, int value) {
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<async_memset_task>(buffer, size, value, ctx.stream());
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
} else {
|
||||
ACL_CHECK(aclrtMemsetAsync(buffer, size, value, size, ctx.stream()));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs sparse expert-based matrix multiplication using the CANN backend.
|
||||
*
|
||||
@@ -1091,15 +979,11 @@ template <auto binary_op> void ggml_cann_binary_op(ggml_backend_cann_context & c
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
aclTensor * acl_src0;
|
||||
aclTensor * acl_src1;
|
||||
aclTensor * acl_dst;
|
||||
acl_tensor_ptr acl_src0, acl_src1, acl_dst;
|
||||
|
||||
// Need bcast
|
||||
bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst);
|
||||
binary_op(ctx, acl_src0, acl_src1, acl_dst);
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
|
||||
bcast_shape(src0, src1, dst, acl_src0, acl_src1, acl_dst);
|
||||
binary_op(ctx, acl_src0.get(), acl_src1.get(), acl_dst.get());
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1109,7 +993,7 @@ template <auto binary_op> void ggml_cann_binary_op(ggml_backend_cann_context & c
|
||||
* and stores the result in the destination tensor.
|
||||
*
|
||||
* @tparam unary_op A callable with the signature:
|
||||
* void(ggml_backend_cann_context&, aclTensor*, aclTensor*)
|
||||
* void(ggml_backend_cann_context&, aclTensor *, aclTensor *)
|
||||
* where the first aclTensor is the source and the second is the destination.
|
||||
* @param ctx The CANN backend context for managing resources and execution.
|
||||
* @param dst The destination tensor. Its src[0] is treated as the input tensor.
|
||||
@@ -1118,11 +1002,10 @@ template <void unary_op(ggml_backend_cann_context &, aclTensor *, aclTensor *)>
|
||||
void ggml_cann_op_unary(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src = dst->src[0];
|
||||
|
||||
aclTensor * acl_src = ggml_cann_create_tensor(src);
|
||||
aclTensor * acl_dst = ggml_cann_create_tensor(dst);
|
||||
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
|
||||
|
||||
unary_op(ctx, acl_src, acl_dst);
|
||||
ggml_cann_release_resources(ctx, acl_src, acl_dst);
|
||||
unary_op(ctx, acl_src.get(), acl_dst.get());
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1242,3 +1125,23 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac
|
||||
} while (0)
|
||||
|
||||
#endif // CANN_ACLNN_OPS
|
||||
|
||||
/**
|
||||
* @brief Performs outer product operation on two ggml tensors using the CANN backend.
|
||||
*
|
||||
* @details This function computes the outer product of two input tensors (src0 and src1)
|
||||
* and stores the result in the destination tensor. The outer product operation is defined as:
|
||||
* dst[i,j,k,l] = sum_m (src0[i,m,k,l] * src1[j,m,k,l])
|
||||
*
|
||||
* The function supports multiple data types including F32, F16. For floating-point
|
||||
* types, it uses batch matrix multiplication for efficient computation.
|
||||
*
|
||||
* The implementation handles 4D tensor broadcasting and batch processing automatically.
|
||||
*
|
||||
* @param ctx The CANN backend context for operation execution and memory management.
|
||||
* @param dst The destination ggml_tensor where the outer product result will be stored.
|
||||
* The input tensors are assumed to be `dst->src[0]` and `dst->src[1]`.
|
||||
*
|
||||
* @see GGML_CANN_CALL_ACLNN_OP for CANN operator invocation
|
||||
*/
|
||||
void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@@ -23,26 +23,26 @@
|
||||
#ifndef CANN_COMMON_H
|
||||
#define CANN_COMMON_H
|
||||
|
||||
#include <acl/acl.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <unistd.h>
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <list>
|
||||
|
||||
#include "../ggml-impl.h"
|
||||
#include "../include/ggml-cann.h"
|
||||
#include "../include/ggml.h"
|
||||
#include "../ggml-impl.h"
|
||||
|
||||
#include <acl/acl.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <cstdio>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#define MATRIX_ROW_PADDING 512
|
||||
#define GGML_CANN_MAX_STREAMS 8
|
||||
@@ -214,130 +214,6 @@ struct ggml_cann_pool_alloc {
|
||||
ggml_cann_pool_alloc & operator=(ggml_cann_pool_alloc &&) = delete;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Function pointer type for ACLNN operator calls.
|
||||
*/
|
||||
using aclnn_func_t = aclnnStatus (*)(void *, uint64_t, aclOpExecutor *, aclrtStream);
|
||||
|
||||
/**
|
||||
* @brief Base class for all CANN tasks to be submitted to the task queue.
|
||||
*
|
||||
* Users should override the run_task() method with actual task logic.
|
||||
*/
|
||||
class cann_task {
|
||||
public:
|
||||
virtual void run_task() {}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief A lock-free ring-buffer based task queue for asynchronously executing cann_task instances.
|
||||
*/
|
||||
class cann_task_queue {
|
||||
public:
|
||||
/**
|
||||
* @brief Constructs a task queue with a fixed power-of-two capacity for a specific device.
|
||||
*
|
||||
* @param capacity Queue capacity. Must be a power of 2.
|
||||
* @param device Target device ID (used for context setting).
|
||||
*/
|
||||
explicit cann_task_queue(size_t capacity, int32_t device) :
|
||||
buffer_(capacity),
|
||||
capacity_(capacity),
|
||||
head_(0),
|
||||
tail_(0),
|
||||
running_(false),
|
||||
device_(device) {
|
||||
GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2");
|
||||
mask_ = capacity_ - 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Attempts to enqueue a task into the queue.
|
||||
*
|
||||
* @param item Unique pointer to the task.
|
||||
* @return true if the task was successfully enqueued, false if the queue was full.
|
||||
*/
|
||||
bool enqueue(std::unique_ptr<cann_task> && item) {
|
||||
size_t next_tail = (tail_ + 1) & mask_;
|
||||
|
||||
if (next_tail == head_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
buffer_[tail_] = std::move(item);
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
tail_ = next_tail;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Submits a task to the queue, and starts the worker thread if not already running.
|
||||
*
|
||||
* @param task Task to be submitted.
|
||||
*/
|
||||
void submit_task(std::unique_ptr<cann_task> && task) {
|
||||
while (!enqueue(std::move(task))) {
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!running_) {
|
||||
running_ = true;
|
||||
thread_ = std::thread(&cann_task_queue::execute, this);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Waits until the queue is completely empty and no tasks are being processed.
|
||||
*/
|
||||
void wait() {
|
||||
while (running_ && head_ != tail_) {
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Stops the task queue and joins the worker thread.
|
||||
*/
|
||||
void stop() {
|
||||
running_ = false;
|
||||
if (thread_.joinable()) {
|
||||
thread_.join();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Worker thread function that continuously dequeues and executes tasks.
|
||||
*/
|
||||
void execute() {
|
||||
ggml_cann_set_device(device_);
|
||||
|
||||
while (running_) {
|
||||
if (head_ == tail_) {
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
|
||||
std::atomic_thread_fence(std::memory_order_acquire);
|
||||
buffer_[head_]->run_task();
|
||||
buffer_[head_].reset();
|
||||
head_ = (head_ + 1) & mask_;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<cann_task>> buffer_;
|
||||
const size_t capacity_;
|
||||
size_t mask_;
|
||||
size_t head_;
|
||||
size_t tail_;
|
||||
bool running_;
|
||||
std::thread thread_;
|
||||
int32_t device_;
|
||||
};
|
||||
|
||||
#ifdef USE_ACL_GRAPH
|
||||
struct ggml_graph_node_properties {
|
||||
// dst tensor
|
||||
@@ -424,30 +300,92 @@ struct ggml_cann_graph_lru_cache {
|
||||
|
||||
struct ggml_cann_rope_cache {
|
||||
~ggml_cann_rope_cache() {
|
||||
if (theta_scale_cache != nullptr) {
|
||||
if (theta_scale_cache) {
|
||||
ACL_CHECK(aclrtFree(theta_scale_cache));
|
||||
}
|
||||
if (sin_cache != nullptr) {
|
||||
if (sin_cache) {
|
||||
ACL_CHECK(aclrtFree(sin_cache));
|
||||
}
|
||||
if (cos_cache != nullptr) {
|
||||
if (cos_cache) {
|
||||
ACL_CHECK(aclrtFree(cos_cache));
|
||||
}
|
||||
if (position_select_index) {
|
||||
ACL_CHECK(aclrtFree(position_select_index));
|
||||
}
|
||||
if (theta_scale_exp_host) {
|
||||
free(theta_scale_exp_host);
|
||||
}
|
||||
if(position_select_index_host) {
|
||||
free(position_select_index_host);
|
||||
}
|
||||
}
|
||||
|
||||
void * theta_scale_cache = nullptr;
|
||||
int64_t theta_scale_length = 0;
|
||||
bool equal(int64_t theta_scale_length,
|
||||
int64_t position_length,
|
||||
float ext_factor,
|
||||
float theta_scale,
|
||||
float freq_scale,
|
||||
float attn_factor,
|
||||
bool is_neox,
|
||||
bool indep_sects,
|
||||
bool mrope_used,
|
||||
bool is_imrope,
|
||||
int sections[4]) {
|
||||
return this->theta_scale_length == theta_scale_length && this->position_length == position_length &&
|
||||
this->ext_factor == ext_factor && this->theta_scale == theta_scale && this->freq_scale == freq_scale &&
|
||||
this->attn_factor == attn_factor && this->is_neox == is_neox && this->indep_sects == indep_sects &&
|
||||
this->mrope_used == mrope_used && this->is_imrope == is_imrope && this->sections[0] == sections[0] &&
|
||||
this->sections[1] == sections[1] && this->sections[2] == sections[2] && this->sections[3] == sections[3];
|
||||
}
|
||||
|
||||
void set(int64_t theta_scale_length,
|
||||
int64_t position_length,
|
||||
float ext_factor,
|
||||
float theta_scale,
|
||||
float freq_scale,
|
||||
float attn_factor,
|
||||
bool is_neox,
|
||||
bool indep_sects,
|
||||
bool mrope_used,
|
||||
bool is_imrope,
|
||||
int sections[4]) {
|
||||
this->theta_scale_length = theta_scale_length;
|
||||
this->position_length = position_length;
|
||||
this->ext_factor = ext_factor;
|
||||
this->theta_scale = theta_scale;
|
||||
this->freq_scale = freq_scale;
|
||||
this->attn_factor = attn_factor;
|
||||
this->is_neox = is_neox;
|
||||
this->indep_sects = indep_sects;
|
||||
this->mrope_used = mrope_used;
|
||||
this->is_imrope = is_imrope;
|
||||
this->sections[0] = sections[0];
|
||||
this->sections[1] = sections[1];
|
||||
this->sections[2] = sections[2];
|
||||
this->sections[3] = sections[3];
|
||||
}
|
||||
|
||||
// memory cache, prepare before inferencing.
|
||||
void * theta_scale_cache = nullptr;
|
||||
float * theta_scale_exp_host = nullptr;
|
||||
int * position_select_index_host = nullptr;
|
||||
void * position_select_index = nullptr;
|
||||
// sin/cos cache, used only to accelerate first layer on each device
|
||||
void * sin_cache = nullptr;
|
||||
void * cos_cache = nullptr;
|
||||
int64_t position_length = 0;
|
||||
void * sin_cache = nullptr;
|
||||
void * cos_cache = nullptr;
|
||||
// Properties to check before reusing the sincos cache
|
||||
bool cached = false;
|
||||
float ext_factor = 0.0f;
|
||||
float theta_scale = 0.0f;
|
||||
float freq_scale = 0.0f;
|
||||
float attn_factor = 0.0f;
|
||||
bool is_neox = false;
|
||||
int64_t theta_scale_length = 0;
|
||||
int64_t position_length = 0;
|
||||
bool cached = false;
|
||||
float ext_factor = 0.0f;
|
||||
float theta_scale = 0.0f;
|
||||
float freq_scale = 0.0f;
|
||||
float attn_factor = 0.0f;
|
||||
bool is_neox = false;
|
||||
bool indep_sects = false;
|
||||
bool mrope_used = false;
|
||||
int sections[4] = { 0, 0, 0, 0 };
|
||||
bool is_imrope = false;
|
||||
};
|
||||
|
||||
struct ggml_cann_tensor_cache {
|
||||
@@ -474,7 +412,6 @@ struct ggml_backend_cann_context {
|
||||
ggml_cann_graph_lru_cache graph_lru_cache;
|
||||
bool acl_graph_mode = true;
|
||||
#endif
|
||||
cann_task_queue task_queue;
|
||||
bool async_mode;
|
||||
// Rope Cache
|
||||
ggml_cann_rope_cache rope_cache;
|
||||
@@ -488,15 +425,10 @@ struct ggml_backend_cann_context {
|
||||
* @brief Constructor for initializing the context with a given device.
|
||||
* @param device Device ID.
|
||||
*/
|
||||
explicit ggml_backend_cann_context(int device) :
|
||||
device(device),
|
||||
name("CANN" + std::to_string(device)),
|
||||
task_queue(1024, device) {
|
||||
explicit ggml_backend_cann_context(int device) : device(device), name("CANN" + std::to_string(device)) {
|
||||
ggml_cann_set_device(device);
|
||||
description = aclrtGetSocName();
|
||||
|
||||
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
|
||||
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, device, async_mode ? "ON" : "OFF");
|
||||
#ifdef USE_ACL_GRAPH
|
||||
acl_graph_mode = parse_bool(get_env("GGML_CANN_ACL_GRAPH").value_or("on"));
|
||||
GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER",
|
||||
@@ -509,7 +441,6 @@ struct ggml_backend_cann_context {
|
||||
*/
|
||||
~ggml_backend_cann_context() {
|
||||
ggml_cann_set_device(device);
|
||||
task_queue.stop();
|
||||
if (copy_event != nullptr) {
|
||||
ACL_CHECK(aclrtDestroyEvent(copy_event));
|
||||
}
|
||||
|
||||
@@ -22,24 +22,24 @@
|
||||
|
||||
#include "ggml-cann.h"
|
||||
|
||||
#include <acl/acl.h>
|
||||
#include <stdarg.h>
|
||||
#include <aclnnop/aclnn_trans_matmul_weight.h>
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-cann/aclnn_ops.h"
|
||||
#include "ggml-cann/common.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <acl/acl.h>
|
||||
#include <aclnnop/aclnn_trans_matmul_weight.h>
|
||||
#include <stdarg.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <chrono>
|
||||
#include <unordered_set>
|
||||
#include <optional>
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-cann/aclnn_ops.h"
|
||||
#include "ggml-cann/common.h"
|
||||
#include "ggml.h"
|
||||
#include <queue>
|
||||
#include <unordered_set>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
|
||||
@@ -1177,19 +1177,18 @@ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
|
||||
* across calls. This reduces overhead from repeated memory allocation and deallocation.
|
||||
*/
|
||||
static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) {
|
||||
aclTensor * weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
|
||||
uint64_t workspaceSize = 0;
|
||||
acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor * executor;
|
||||
|
||||
// TransMatmulWeight
|
||||
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed, &workspaceSize, &executor));
|
||||
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed.get(), &workspaceSize, &executor));
|
||||
// Avoid frequent malloc/free of the workspace.
|
||||
g_nz_workspaces[device].realloc(workspaceSize);
|
||||
|
||||
void * g_nz_workspace = g_nz_workspaces[device].get();
|
||||
|
||||
ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
|
||||
ACL_CHECK(aclDestroyTensor(weightTransposed));
|
||||
}
|
||||
|
||||
// TODO: need handle tensor which has paddings.
|
||||
@@ -1641,7 +1640,7 @@ ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
|
||||
/* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
|
||||
},
|
||||
/* .device = */
|
||||
ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
|
||||
ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
|
||||
/* .context = */ nullptr,
|
||||
};
|
||||
|
||||
@@ -1780,6 +1779,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
|
||||
case GGML_OP_L2_NORM:
|
||||
ggml_cann_l2_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
ggml_cann_cross_entropy_loss(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONCAT:
|
||||
ggml_cann_concat(ctx, dst);
|
||||
break;
|
||||
@@ -1884,6 +1886,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
ggml_cann_flash_attn_ext(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_OUT_PROD:
|
||||
ggml_cann_out_prod(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -1946,7 +1951,8 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
|
||||
GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
|
||||
GGML_ASSERT(!ggml_is_quantized(tensor->type));
|
||||
|
||||
ggml_cann_async_memcpy(cann_ctx, (char *) tensor->data + offset, data, size, ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
ACL_CHECK(aclrtMemcpyAsync((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE,
|
||||
cann_ctx->stream()));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1971,7 +1977,8 @@ static void ggml_backend_cann_get_tensor_async(ggml_backend_t backend,
|
||||
GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
|
||||
GGML_ASSERT(!ggml_is_quantized(tensor->type));
|
||||
|
||||
ggml_cann_async_memcpy(cann_ctx, data, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST);
|
||||
ACL_CHECK(aclrtMemcpyAsync(data, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST,
|
||||
cann_ctx->stream()));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -2032,7 +2039,6 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
|
||||
ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
|
||||
|
||||
// wait for task_queue empty to keep task order.
|
||||
cann_ctx_src->task_queue.wait();
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,
|
||||
cann_ctx_src->stream()));
|
||||
// record event on src stream after the copy
|
||||
@@ -2065,7 +2071,6 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
|
||||
*/
|
||||
static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
|
||||
ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
|
||||
cann_ctx->task_queue.wait();
|
||||
ggml_cann_set_device(cann_ctx->device);
|
||||
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
|
||||
}
|
||||
@@ -2244,8 +2249,7 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
|
||||
bool & use_cann_graph,
|
||||
bool & cann_graph_update_required) {
|
||||
#ifdef USE_ACL_GRAPH
|
||||
ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
|
||||
if (use_cann_graph && cann_graph_update_required) {
|
||||
if (use_cann_graph && cann_graph_update_required) { // Begin CANN graph capture
|
||||
ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
|
||||
}
|
||||
#endif // USE_ACL_GRAPH
|
||||
@@ -2269,12 +2273,14 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
|
||||
}
|
||||
|
||||
#ifdef USE_ACL_GRAPH
|
||||
if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture
|
||||
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
|
||||
}
|
||||
|
||||
if (use_cann_graph) {
|
||||
// Execute graph
|
||||
ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
|
||||
|
||||
if (cann_graph_update_required) { // End CANN graph capture
|
||||
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
|
||||
}
|
||||
|
||||
// Execute CANN graph
|
||||
ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream()));
|
||||
}
|
||||
#endif // USE_ACL_GRAPH
|
||||
@@ -2300,9 +2306,9 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
|
||||
// calculate rope cache for fist layer in current device.
|
||||
cann_ctx->rope_cache.cached = false;
|
||||
|
||||
bool cann_graph_update_required = false;
|
||||
#ifdef USE_ACL_GRAPH
|
||||
bool use_cann_graph = true;
|
||||
bool cann_graph_update_required = false;
|
||||
|
||||
static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
|
||||
if (!prefill_use_graph) {
|
||||
@@ -2333,7 +2339,6 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
|
||||
}
|
||||
#else
|
||||
bool use_cann_graph = false;
|
||||
bool cann_graph_update_required = false;
|
||||
#endif // USE_ACL_GRAPH
|
||||
evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, cann_graph_update_required);
|
||||
|
||||
@@ -2475,11 +2480,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
return false;
|
||||
}
|
||||
|
||||
const int mode = ((const int32_t *) op->op_params)[2];
|
||||
if (mode & GGML_ROPE_TYPE_MROPE) {
|
||||
return false;
|
||||
}
|
||||
if (mode & GGML_ROPE_TYPE_VISION) {
|
||||
if (op->src[0]->ne[0] > 896) {
|
||||
return false;
|
||||
}
|
||||
#ifdef ASCEND_310P
|
||||
@@ -2499,6 +2500,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) {
|
||||
return false;
|
||||
}
|
||||
if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case GGML_OP_POOL_2D:
|
||||
@@ -2518,9 +2522,11 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
// value of paddingW should be at most half of kernelW
|
||||
return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
|
||||
}
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_SUM:
|
||||
return ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_REPEAT:
|
||||
@@ -2556,6 +2562,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
return true;
|
||||
case GGML_OP_OUT_PROD:
|
||||
{
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
// TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255.
|
||||
return (op->src[0]->ne[0] - 1) <= 255;
|
||||
|
||||
@@ -145,26 +145,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
|
||||
include(CheckCXXSourceRuns)
|
||||
|
||||
function(check_arm_feature tag code)
|
||||
macro(check_arm_feature tag feature code)
|
||||
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
|
||||
set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+${tag}")
|
||||
check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag})
|
||||
if (GGML_MACHINE_SUPPORTS_${tag})
|
||||
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+${tag}" PARENT_SCOPE)
|
||||
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+${tag}")
|
||||
else()
|
||||
set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+no${tag}")
|
||||
check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag})
|
||||
if (GGML_MACHINE_SUPPORTS_no${tag})
|
||||
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+no${tag}" PARENT_SCOPE)
|
||||
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+no${tag}")
|
||||
list(APPEND ARCH_FLAGS -U__ARM_FEATURE_${feature})
|
||||
endif()
|
||||
endif()
|
||||
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
|
||||
endfunction()
|
||||
endmacro()
|
||||
|
||||
check_arm_feature(dotprod "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
|
||||
check_arm_feature(i8mm "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
|
||||
check_arm_feature(sve "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
|
||||
check_arm_feature(sme "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
|
||||
check_arm_feature(dotprod DOTPROD "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
|
||||
check_arm_feature(i8mm MATMUL_INT8 "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
|
||||
check_arm_feature(sve SVE "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
|
||||
check_arm_feature(sme SME "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
|
||||
|
||||
list(APPEND ARCH_FLAGS "${ARM_NATIVE_FLAG}${ARM_NATIVE_FLAG_FIX}")
|
||||
else()
|
||||
@@ -216,35 +217,28 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# show enabled features
|
||||
if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
|
||||
set(FEAT_INPUT_FILE "NUL")
|
||||
else()
|
||||
set(FEAT_INPUT_FILE "/dev/null")
|
||||
endif()
|
||||
message(STATUS "Checking for ARM features using flags:")
|
||||
foreach(flag IN LISTS ARCH_FLAGS)
|
||||
message(STATUS " ${flag}")
|
||||
endforeach()
|
||||
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_C_COMPILER} ${ARCH_FLAGS} -dM -E -
|
||||
INPUT_FILE ${FEAT_INPUT_FILE}
|
||||
OUTPUT_VARIABLE ARM_FEATURE
|
||||
RESULT_VARIABLE ARM_FEATURE_RESULT
|
||||
)
|
||||
if (ARM_FEATURE_RESULT)
|
||||
message(WARNING "Failed to get ARM features")
|
||||
else()
|
||||
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
|
||||
string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
|
||||
if (NOT ${feature_pos} EQUAL -1)
|
||||
# Special handling for MATMUL_INT8 when machine doesn't support i8mm
|
||||
if ("${feature}" STREQUAL "MATMUL_INT8" AND GGML_MACHINE_SUPPORTS_noi8mm)
|
||||
message(STATUS "ARM feature ${feature} detected but unsetting due to machine not supporting i8mm")
|
||||
list(APPEND ARCH_FLAGS -U__ARM_FEATURE_MATMUL_INT8)
|
||||
else()
|
||||
message(STATUS "ARM feature ${feature} enabled")
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
include(CheckCXXSourceCompiles)
|
||||
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
|
||||
string(REPLACE ";" " " ARCH_FLAGS_STR "${ARCH_FLAGS}")
|
||||
set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS_STR}")
|
||||
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
|
||||
set(ARM_FEATURE "HAVE_${feature}")
|
||||
check_cxx_source_compiles(
|
||||
"
|
||||
#if !defined(__ARM_FEATURE_${feature})
|
||||
# error \"Feature ${feature} is not defined\"
|
||||
#endif
|
||||
int main() { return 0; }
|
||||
"
|
||||
${ARM_FEATURE}
|
||||
)
|
||||
endforeach()
|
||||
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
|
||||
endif()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "x86")
|
||||
message(STATUS "x86 detected")
|
||||
@@ -399,9 +393,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
|
||||
|
||||
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power10)
|
||||
elseif (EXTRACTED_NUMBER EQUAL 9)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power9)
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
|
||||
list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
|
||||
else()
|
||||
@@ -459,22 +453,35 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
ggml-cpu/spacemit/ime_kernels.h
|
||||
)
|
||||
endif()
|
||||
set(MARCH_STR "rv64gc")
|
||||
if (GGML_RV_ZFH)
|
||||
string(APPEND MARCH_STR "_zfh")
|
||||
endif()
|
||||
if (GGML_XTHEADVECTOR)
|
||||
string(APPEND MARCH_STR "_xtheadvector")
|
||||
elseif (GGML_RVV)
|
||||
string(APPEND MARCH_STR "_v")
|
||||
if (GGML_RV_ZVFH)
|
||||
string(APPEND MARCH_STR "_zvfh")
|
||||
if(NOT GGML_CPU_ALL_VARIANTS)
|
||||
set(MARCH_STR "rv64gc")
|
||||
if (GGML_RV_ZFH)
|
||||
string(APPEND MARCH_STR "_zfh")
|
||||
endif()
|
||||
if (GGML_XTHEADVECTOR)
|
||||
string(APPEND MARCH_STR "_xtheadvector")
|
||||
elseif (GGML_RVV)
|
||||
string(APPEND MARCH_STR "_v")
|
||||
if (GGML_RV_ZVFH)
|
||||
string(APPEND MARCH_STR "_zvfh")
|
||||
endif()
|
||||
endif()
|
||||
if (GGML_RV_ZICBOP)
|
||||
string(APPEND MARCH_STR "_zicbop")
|
||||
endif()
|
||||
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
|
||||
else()
|
||||
# Begin with the lowest baseline
|
||||
set(ARCH_DEFINITIONS "")
|
||||
|
||||
if (GGML_INTERNAL_RVV)
|
||||
message(STATUS "RVV enabled")
|
||||
list(APPEND ARCH_DEFINITIONS GGML_USE_RVV)
|
||||
list(APPEND ARCH_FLAGS -march=rv64gc_v -mabi=lp64d)
|
||||
endif()
|
||||
|
||||
ggml_add_cpu_backend_features(${GGML_CPU_NAME} riscv ${ARCH_DEFINITIONS})
|
||||
endif()
|
||||
if (GGML_RV_ZICBOP)
|
||||
string(APPEND MARCH_STR "_zicbop")
|
||||
endif()
|
||||
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
|
||||
message(STATUS "s390x detected")
|
||||
list(APPEND GGML_CPU_SOURCES
|
||||
|
||||
@@ -33,10 +33,12 @@
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
@@ -44,27 +46,30 @@
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#elif defined(__POWERPC__) || defined(__powerpc__)
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679
|
||||
@@ -76,10 +81,12 @@
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
@@ -87,6 +94,7 @@
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
@@ -101,10 +109,12 @@
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
@@ -112,6 +122,7 @@
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
@@ -134,15 +145,18 @@
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
@@ -163,10 +177,12 @@
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
@@ -174,6 +190,7 @@
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
@@ -196,10 +213,12 @@
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
@@ -207,6 +226,7 @@
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
|
||||
@@ -24,6 +24,29 @@
|
||||
|
||||
#define UNUSED GGML_UNUSED
|
||||
|
||||
static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
|
||||
int16x8_t * out_mins,
|
||||
int8_t * out_scales) {
|
||||
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
||||
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
||||
constexpr uint32_t kmask3 = 0x03030303;
|
||||
constexpr uint8_t scales_size = 12;
|
||||
|
||||
uint32_t sm[3];
|
||||
memcpy(sm, scales_in, scales_size);
|
||||
|
||||
const uint32_t mins_0_3 = sm[1] & kmask1;
|
||||
const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
|
||||
const uint32x2_t mins_u32 = { mins_0_3, mins_4_7 };
|
||||
|
||||
*out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32)));
|
||||
|
||||
uint32_t scales_u32[2];
|
||||
scales_u32[0] = sm[0] & kmask1;
|
||||
scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
|
||||
memcpy(out_scales, scales_u32, 8);
|
||||
}
|
||||
|
||||
void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
||||
assert(QK8_0 == 32);
|
||||
assert(k % QK8_0 == 0);
|
||||
@@ -474,6 +497,295 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
|
||||
// 1x8 tile = 2 x 4
|
||||
float32x4_t acc_f32[col_groups];
|
||||
|
||||
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < col_groups; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // d0 d1 d2 d3
|
||||
float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7
|
||||
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
|
||||
float32x4_t sb_scale_0123 = vmulq_f32(q4_d_0, q8_d);
|
||||
float32x4_t sb_scale_4567 = vmulq_f32(q4_d_1, q8_d);
|
||||
float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin 0..3
|
||||
float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7
|
||||
float32x4_t sb_min_0123 = vmulq_f32(q4_dmin_0, q8_d);
|
||||
float32x4_t sb_min_4567 = vmulq_f32(q4_dmin_1, q8_d);
|
||||
|
||||
// interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
|
||||
int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
|
||||
int32x4_t acc_lo[col_groups];
|
||||
int32x4_t acc_hi[col_groups];
|
||||
|
||||
// Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
|
||||
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
|
||||
int16_t bsums_arr[8];
|
||||
vst1q_s16(bsums_arr, bsums);
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
for (int i = 0; i < col_groups; i++) {
|
||||
acc_lo[i] = vdupq_n_s32(0);
|
||||
acc_hi[i] = vdupq_n_s32(0);
|
||||
}
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
int16x8_t q4sb_mins[2];
|
||||
int16x8_t q4sb_scales[2];
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q4sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
||||
}
|
||||
|
||||
int8x16_t q8_qs[64 / 16];
|
||||
for (int i = 0; i < 64 / 16; i++) {
|
||||
q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
|
||||
}
|
||||
|
||||
for (int c = 0; c < col_groups; c++) {
|
||||
uint8x16_t q4_cols[8];
|
||||
for (int i = 0; i < 8; i++) {
|
||||
q4_cols[i] = vld1q_u8(q4_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
|
||||
}
|
||||
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[0], m4b)), q8_qs[0], 0);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[1], m4b)), q8_qs[0], 1);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[2], m4b)), q8_qs[0], 2);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[3], m4b)), q8_qs[0], 3);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[4], m4b)), q8_qs[1], 0);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[5], m4b)), q8_qs[1], 1);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[6], m4b)), q8_qs[1], 2);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[7], m4b)), q8_qs[1], 3);
|
||||
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[0], 4)), q8_qs[2], 0);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[1], 4)), q8_qs[2], 1);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[2], 4)), q8_qs[2], 2);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[3], 4)), q8_qs[2], 3);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[4], 4)), q8_qs[3], 0);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[5], 4)), q8_qs[3], 1);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[6], 4)), q8_qs[3], 2);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[7], 4)), q8_qs[3], 3);
|
||||
}
|
||||
|
||||
// Scales
|
||||
// row c0123 blk0 and blk1
|
||||
const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
|
||||
const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
|
||||
const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
|
||||
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
|
||||
acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
|
||||
// row c4567 blk0 and blk1
|
||||
const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
|
||||
const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
|
||||
const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
|
||||
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
|
||||
acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
|
||||
|
||||
// Bias Correction
|
||||
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
|
||||
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
|
||||
|
||||
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
|
||||
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
|
||||
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
|
||||
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
|
||||
} // for sb
|
||||
|
||||
acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
|
||||
acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
|
||||
} // for b
|
||||
|
||||
int base = x * ncols_interleaved;
|
||||
vst1q_f32(s + base, acc_f32[0]);
|
||||
vst1q_f32(s + base + 4, acc_f32[1]);
|
||||
} // for x
|
||||
return;
|
||||
#endif // #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
ggml_gemv_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q4_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
constexpr int col_pairs = ncols_interleaved / 2;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
|
||||
// 1x8 tile = 2 x 4
|
||||
float32x4_t acc_f32[ncols_interleaved / 4];
|
||||
|
||||
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < ncols_interleaved / 4; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // d0 d1 d2 d3
|
||||
float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7
|
||||
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
|
||||
float32x4_t sb_scale_0 = vmulq_f32(q4_d_0, q8_d);
|
||||
float32x4_t sb_scale_1 = vmulq_f32(q4_d_1, q8_d);
|
||||
float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin 0..3
|
||||
float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7
|
||||
float32x4_t sb_min_0 = vmulq_f32(q4_dmin_0, q8_d);
|
||||
float32x4_t sb_min_1 = vmulq_f32(q4_dmin_1, q8_d);
|
||||
|
||||
// interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
|
||||
int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
|
||||
// 2 sb each iteration
|
||||
int32x4_t acc_lo[col_pairs];
|
||||
int32x4_t acc_hi[col_pairs];
|
||||
|
||||
// Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
|
||||
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
|
||||
int16_t bsums_arr[8];
|
||||
vst1q_s16(bsums_arr, bsums);
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
for (int i = 0; i < col_pairs; i++) {
|
||||
acc_lo[i] = vdupq_n_s32(0);
|
||||
acc_hi[i] = vdupq_n_s32(0);
|
||||
}
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later
|
||||
int16x8_t q4sb_scales[2];
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q4sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
||||
}
|
||||
|
||||
const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;
|
||||
|
||||
// Load the 64 quants from q8K duplicated to use vecdots with the interelaved columns
|
||||
// but still need the qs to use the low and hi bits from q4
|
||||
const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
|
||||
int8x16_t q8_qs[8];
|
||||
for (int i = 0; i < 8; i++) {
|
||||
q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
|
||||
}
|
||||
|
||||
// Q4s columns iterated in pairs (01, 23, 45, 67)
|
||||
for (int cp = 0; cp < col_pairs; cp++) {
|
||||
uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_base + 16 * cp);
|
||||
uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_base + 16 * cp + 64);
|
||||
uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_base + 16 * cp + 128);
|
||||
uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_base + 16 * cp + 192);
|
||||
|
||||
acc_lo[cp] =
|
||||
ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), q8_qs[0]); // 0 .. 7
|
||||
acc_lo[cp] =
|
||||
ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), q8_qs[1]); // 8 ..15
|
||||
acc_lo[cp] =
|
||||
ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), q8_qs[2]); // 16..23
|
||||
acc_lo[cp] =
|
||||
ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), q8_qs[3]); // 24..31
|
||||
|
||||
acc_hi[cp] =
|
||||
ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), q8_qs[4]); // 32..39
|
||||
acc_hi[cp] =
|
||||
ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), q8_qs[5]); // 40..47
|
||||
acc_hi[cp] =
|
||||
ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), q8_qs[6]); // 48..55
|
||||
acc_hi[cp] =
|
||||
ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), q8_qs[7]); // 56..63
|
||||
}
|
||||
|
||||
// Iterates over a pair of column pairs (4 columns) to use a single 128 register
|
||||
// p = 0 -> 0123 p2 -> 4567
|
||||
for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
|
||||
int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q4sb_scales[0]) : vget_high_s16(q4sb_scales[0]);
|
||||
int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q4sb_scales[1]) : vget_high_s16(q4sb_scales[1]);
|
||||
float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
|
||||
|
||||
// 0123 or 4567
|
||||
float32x4_t sumf_0 =
|
||||
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
|
||||
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
|
||||
|
||||
float32x4_t sumf_1 =
|
||||
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
|
||||
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
|
||||
}
|
||||
|
||||
// Multiply Acc bsum + mins
|
||||
// Each pair of subblocks share the same bsums
|
||||
// Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
|
||||
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
|
||||
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
|
||||
|
||||
// cols 0-3 bias
|
||||
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
|
||||
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
|
||||
|
||||
// cols 4-7 bias
|
||||
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
|
||||
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
|
||||
} // for sb
|
||||
|
||||
acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0);
|
||||
acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1);
|
||||
} // for b
|
||||
|
||||
int base = x * ncols_interleaved;
|
||||
vst1q_f32(s + base, acc_f32[0]);
|
||||
vst1q_f32(s + base + 4, acc_f32[1]);
|
||||
} // for x
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
@@ -1889,3 +2201,412 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
|
||||
ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 4;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
|
||||
// 8 accumulators: 2 row pairs × 4 col pairs
|
||||
float32x4_t acc_f32[acc_size];
|
||||
|
||||
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < acc_size; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
// d4 0 1 2 3, 4 5 6 7
|
||||
float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));
|
||||
float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
|
||||
// d8 0 1 2 3
|
||||
float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
|
||||
// mins
|
||||
float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
|
||||
float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
|
||||
|
||||
// Precomputation of scales and mins
|
||||
float32x4_t sbd_scale_0123[q8_k_blocklen];
|
||||
float32x4_t sbd_scale_4567[q8_k_blocklen];
|
||||
float32x4_t sbd_min_0123[q8_k_blocklen];
|
||||
float32x4_t sbd_min_4567[q8_k_blocklen];
|
||||
|
||||
sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
|
||||
sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
|
||||
sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
|
||||
sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
|
||||
|
||||
sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
|
||||
sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
|
||||
sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
|
||||
sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
|
||||
|
||||
sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
|
||||
sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
|
||||
sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
|
||||
sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
|
||||
|
||||
sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
|
||||
sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
|
||||
sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
|
||||
sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
|
||||
|
||||
// Precomputation of bsums, each vpaddq calcs all the bsums for each row
|
||||
const int16x8_t bsums[q8_k_blocklen] = {
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
||||
};
|
||||
int16_t bsums_arr[QK_K / 64][8];
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
|
||||
}
|
||||
|
||||
// interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
|
||||
int32x4_t bias_acc[acc_size];
|
||||
for (int i = 0; i < acc_size; i++) {
|
||||
bias_acc[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
// Int accumulators for qs vecdot (4 row x 2 col quartets)
|
||||
int32x4_t acc_lo[acc_size];
|
||||
int32x4_t acc_hi[acc_size];
|
||||
for (int i = 0; i < acc_size; i++) {
|
||||
acc_lo[i] = vdupq_n_s32(0);
|
||||
acc_hi[i] = vdupq_n_s32(0);
|
||||
}
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
int16x8_t q4sb_scales[2];
|
||||
int16x8_t q4sb_mins[2];
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q4sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
||||
}
|
||||
|
||||
constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
|
||||
for (int k = 0; k < reads_per_sb; k++) {
|
||||
const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
|
||||
const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
|
||||
|
||||
// 0..3 & 32..35
|
||||
const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);
|
||||
const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
|
||||
|
||||
const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));
|
||||
const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));
|
||||
|
||||
acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
|
||||
acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
|
||||
acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
|
||||
acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
|
||||
|
||||
acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
|
||||
acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
|
||||
acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
|
||||
acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
|
||||
|
||||
const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));
|
||||
const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));
|
||||
|
||||
acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
|
||||
acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
|
||||
acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
|
||||
acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
|
||||
|
||||
acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
|
||||
acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
|
||||
acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
|
||||
acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
|
||||
}
|
||||
|
||||
// Scale and bias application
|
||||
// acc is stored interleaved to match output layout
|
||||
const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
|
||||
const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
|
||||
const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
|
||||
const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
|
||||
for (int row = 0; row < q8_k_blocklen; row++) {
|
||||
// Bias correction
|
||||
// row c0123 blk0 and blk1
|
||||
const float32x4_t sumf_0123 =
|
||||
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
|
||||
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
|
||||
acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
|
||||
|
||||
// row c4567 blk0 and blk1
|
||||
const float32x4_t sumf_4567 =
|
||||
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
|
||||
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
|
||||
acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
|
||||
|
||||
// Bias
|
||||
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
|
||||
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
|
||||
|
||||
// row c0123 blk0 and blk1
|
||||
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
|
||||
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
|
||||
|
||||
// row c4567 blk0 and blk1
|
||||
bias_acc[2 * row + 1] =
|
||||
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
|
||||
bias_acc[2 * row + 1] =
|
||||
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
|
||||
}
|
||||
} // for sb
|
||||
|
||||
for (int row = 0; row < q8_k_blocklen; row++) {
|
||||
acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
|
||||
acc_f32[2 * row + 1] =
|
||||
vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
|
||||
}
|
||||
} // for b
|
||||
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
int row = y * q8_k_blocklen + i;
|
||||
for (int j = 0; j < 2; j++) {
|
||||
int col = x * ncols_interleaved + j * 4;
|
||||
int offset = row * bs + col;
|
||||
vst1q_f32(s + offset, acc_f32[2 * i + j]);
|
||||
}
|
||||
}
|
||||
} // for x
|
||||
} // for y
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
|
||||
// 8 accumulators: 2 row pairs × 4 col pairs
|
||||
float32x4_t acc_f32[blocklen];
|
||||
|
||||
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
// bsums pairs belongs to the same q8_k subblock
|
||||
const int16x8_t bsums[4]{
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
||||
};
|
||||
int16_t bsums_arr[4][8];
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
|
||||
}
|
||||
|
||||
int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results
|
||||
int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
|
||||
int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
|
||||
for (int i = 0; i < 8; i++) {
|
||||
acc[i] = vdupq_n_s32(0);
|
||||
bias_acc[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
int8_t q4sb_scales[2][8];
|
||||
int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later
|
||||
for (int i = 0; i < 2; i++) {
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
|
||||
}
|
||||
|
||||
// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
|
||||
const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
|
||||
|
||||
int8x16_t q8_qs_01[8];
|
||||
int8x16_t q8_qs_23[8];
|
||||
|
||||
// Load 32-byte per row pair, 1 subblock each time
|
||||
for (int i = 0; i < 8; i++) {
|
||||
const int offset = i * 32; // 16 for row 01, 16 for row 23
|
||||
q8_qs_01[i] = vld1q_s8(q8_base + offset);
|
||||
q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
|
||||
}
|
||||
|
||||
const int8x16_t q8s[2][8] = {
|
||||
{ q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
|
||||
q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
|
||||
{ q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
|
||||
q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
|
||||
};
|
||||
|
||||
// Q4s columns iterated in pairs (01, 23, 45, 67)
|
||||
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
sb_acc[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
|
||||
uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
|
||||
uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
|
||||
uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
|
||||
const int8x16_t q4_nibbles[2][4] = {
|
||||
{
|
||||
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
|
||||
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
|
||||
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
|
||||
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
|
||||
},
|
||||
{
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
|
||||
}
|
||||
};
|
||||
|
||||
// Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8
|
||||
// for each of the internal 32 qs subblock (blk)
|
||||
for (int rp = 0; rp < 2; rp++) {
|
||||
for (int blk = 0; blk < 2; blk++) {
|
||||
const int8x16_t * q8 = &q8s[rp][4 * blk];
|
||||
const int8x16_t * q4 = q4_nibbles[blk];
|
||||
int32x4_t acc = sb_acc[2 * rp + blk];
|
||||
// mul add for each qs in the same subblock
|
||||
for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
|
||||
acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
|
||||
}
|
||||
sb_acc[2 * rp + blk] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
// Scales[i] corresponds to column i
|
||||
const int scale_offset = cp * 2;
|
||||
for (int blk = 0; blk < 2; blk++) {
|
||||
const int32x4_t block_scale = {
|
||||
(int32_t) q4sb_scales[blk][scale_offset],
|
||||
(int32_t) q4sb_scales[blk][scale_offset],
|
||||
(int32_t) q4sb_scales[blk][scale_offset + 1],
|
||||
(int32_t) q4sb_scales[blk][scale_offset + 1],
|
||||
};
|
||||
acc[cp] = vmlaq_s32(acc[cp], sb_acc[blk], block_scale);
|
||||
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale);
|
||||
}
|
||||
}
|
||||
|
||||
// Multiply Acc bsum + mins
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
// Each pair of subblocks share the same bsums
|
||||
// Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
|
||||
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
|
||||
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
|
||||
|
||||
bias_acc[2 * q8_row] =
|
||||
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
|
||||
bias_acc[2 * q8_row] =
|
||||
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
|
||||
bias_acc[2 * q8_row + 1] =
|
||||
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
|
||||
bias_acc[2 * q8_row + 1] =
|
||||
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
|
||||
}
|
||||
} // for sb
|
||||
|
||||
// Reorder of i8mm output with bias and output layout
|
||||
for (int i = 0; i < 8; i++) {
|
||||
int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
|
||||
acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
|
||||
}
|
||||
int32x4_t reorder_acc[8] = {
|
||||
vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
|
||||
vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
|
||||
vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
|
||||
vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
|
||||
vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
|
||||
vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
|
||||
vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
|
||||
vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
|
||||
};
|
||||
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
|
||||
float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
|
||||
const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d);
|
||||
|
||||
float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
|
||||
const float32x4_t scale = vmulq_f32(q4_d, q8_d);
|
||||
|
||||
acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
|
||||
acc_f32[2 * i + j] =
|
||||
vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
|
||||
}
|
||||
}
|
||||
} // for b
|
||||
|
||||
// With the previous reorder, the tile is already in the correct memory layout.
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
int row = y * q8_k_blocklen + i;
|
||||
for (int j = 0; j < 2; j++) {
|
||||
int col = x * ncols_interleaved + j * 4;
|
||||
int offset = row * bs + col;
|
||||
vst1q_f32(s + offset, acc_f32[2 * i + j]);
|
||||
}
|
||||
}
|
||||
} // for x
|
||||
} // for y
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
38
ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp
Normal file
38
ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp
Normal file
@@ -0,0 +1,38 @@
|
||||
#include "ggml-backend-impl.h"
|
||||
|
||||
#if defined(__riscv) && __riscv_xlen == 64
|
||||
#include <asm/hwprobe.h>
|
||||
#include <asm/unistd.h>
|
||||
#include <unistd.h>
|
||||
|
||||
struct riscv64_features {
|
||||
bool has_rvv = false;
|
||||
|
||||
riscv64_features() {
|
||||
struct riscv_hwprobe probe;
|
||||
probe.key = RISCV_HWPROBE_KEY_IMA_EXT_0;
|
||||
probe.value = 0;
|
||||
|
||||
int ret = syscall(__NR_riscv_hwprobe, &probe, 1, 0, NULL, 0);
|
||||
|
||||
if (0 == ret) {
|
||||
has_rvv = !!(probe.value & RISCV_HWPROBE_IMA_V);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static int ggml_backend_cpu_riscv64_score() {
|
||||
int score = 1;
|
||||
riscv64_features rf;
|
||||
|
||||
#ifdef GGML_USE_RVV
|
||||
if (!rf.has_rvv) { return 0; }
|
||||
score += 1 << 1;
|
||||
#endif
|
||||
|
||||
return score;
|
||||
}
|
||||
|
||||
GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_riscv64_score)
|
||||
|
||||
#endif // __riscv && __riscv_xlen == 64
|
||||
@@ -646,7 +646,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
|
||||
int64_t xstart = 0;
|
||||
int anr = nr - nr%16; // Used to align nr with boundary of 16
|
||||
#ifdef __AVX512F__
|
||||
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||
int anc = nc - nc%16; // Used to align nc with boundary of 16
|
||||
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
|
||||
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
|
||||
@@ -1041,7 +1041,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||
xstart = anc/8;
|
||||
y = 0;
|
||||
}
|
||||
#endif // __AVX512F__
|
||||
#endif // __AVX512BW__ && __AVX512DQ__
|
||||
|
||||
// Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
|
||||
|
||||
@@ -1989,7 +1989,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
|
||||
int64_t xstart = 0;
|
||||
int anr = nr - nr % 16;; // Used to align nr with boundary of 16
|
||||
#ifdef __AVX512F__
|
||||
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||
int anc = nc - nc % 16; // Used to align nc with boundary of 16
|
||||
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
|
||||
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
|
||||
@@ -2727,7 +2727,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
xstart = anc/8;
|
||||
y = 0;
|
||||
}
|
||||
#endif //AVX512F
|
||||
#endif // __AVX512BW__ && __AVX512DQ__
|
||||
|
||||
// Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
|
||||
for (; y < anr / 4; y += 4) {
|
||||
@@ -3467,7 +3467,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
__m256i scalesmask2 = _mm256_castsi128_si256(scalesmask2_sse);
|
||||
scalesmask2 = _mm256_permute2f128_si256(scalesmask2, scalesmask2, 0);
|
||||
|
||||
#ifdef __AVX512F__
|
||||
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||
|
||||
int anc = nc - nc % 16; // Used to align nc with boundary of 16
|
||||
|
||||
@@ -4947,7 +4947,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
y = 0;
|
||||
}
|
||||
|
||||
#endif //AVX512F
|
||||
#endif // __AVX512BW__ && __AVX512DQ__
|
||||
|
||||
// Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
|
||||
for (; y < anr / 4; y += 4) {
|
||||
|
||||
@@ -1731,6 +1731,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_sum_rows(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_CUMSUM:
|
||||
{
|
||||
ggml_compute_forward_cumsum(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MEAN:
|
||||
{
|
||||
ggml_compute_forward_mean(params, tensor);
|
||||
@@ -1923,10 +1927,22 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_argsort(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_TOP_K:
|
||||
{
|
||||
ggml_compute_forward_top_k(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
{
|
||||
ggml_compute_forward_leaky_relu(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_TRI:
|
||||
{
|
||||
ggml_compute_forward_tri(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_FILL:
|
||||
{
|
||||
ggml_compute_forward_fill(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
ggml_compute_forward_flash_attn_ext(params, tensor);
|
||||
@@ -1982,6 +1998,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_rwkv_wkv7(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
{
|
||||
ggml_compute_forward_solve_tri(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MAP_CUSTOM1:
|
||||
{
|
||||
ggml_compute_forward_map_custom1(params, tensor);
|
||||
@@ -2140,6 +2160,9 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_ADD_ID:
|
||||
case GGML_OP_ADD1:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_CUMSUM:
|
||||
case GGML_OP_TRI:
|
||||
case GGML_OP_FILL:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
@@ -2157,6 +2180,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
n_tasks = 1;
|
||||
} break;
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
@@ -2179,6 +2203,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_UNARY_OP_HARDSWISH:
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
@@ -2289,6 +2315,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_ARANGE:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_TOP_K:
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
case GGML_OP_FLASH_ATTN_BACK:
|
||||
case GGML_OP_SSM_CONV:
|
||||
@@ -2812,6 +2839,10 @@ struct ggml_cplan ggml_graph_plan(
|
||||
cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
|
||||
cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
|
||||
} break;
|
||||
case GGML_OP_TOP_K:
|
||||
{
|
||||
cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks;
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
const int64_t ne10 = node->src[1]->ne[0]; // DK
|
||||
|
||||
@@ -39,7 +39,7 @@
|
||||
|
||||
#include "kernels.h"
|
||||
|
||||
#define NELEMS(x) sizeof(x) / sizeof(*x)
|
||||
#define NELEMS(x) (sizeof(x) / sizeof(*x))
|
||||
|
||||
template<size_t(*Fn)(size_t,size_t,size_t)>
|
||||
static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) {
|
||||
@@ -635,6 +635,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
},
|
||||
#endif
|
||||
#endif
|
||||
{ /* Sentinel */ }
|
||||
};
|
||||
|
||||
static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
|
||||
@@ -803,6 +804,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
},
|
||||
#endif
|
||||
{ /* Sentinel */ }
|
||||
};
|
||||
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
|
||||
@@ -810,7 +812,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
||||
|
||||
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
|
||||
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
|
||||
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
|
||||
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
|
||||
gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
|
||||
@@ -820,7 +822,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
||||
}
|
||||
}
|
||||
if (!kernel) {
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
|
||||
if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
|
||||
gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
|
||||
gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
|
||||
@@ -830,6 +832,10 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(gemm_gemv_kernels);
|
||||
GGML_UNUSED(gemm_gemv_kernels_q8);
|
||||
GGML_UNUSED(cpu_features);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -840,12 +846,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features)
|
||||
ggml_kleidiai_kernels * kernels = nullptr;
|
||||
|
||||
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
|
||||
if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
|
||||
kernels = &gemm_gemv_kernels[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(features);
|
||||
#endif
|
||||
|
||||
return kernels;
|
||||
@@ -855,12 +863,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features)
|
||||
ggml_kleidiai_kernels * kernels = nullptr;
|
||||
|
||||
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
|
||||
if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) {
|
||||
kernels = &gemm_gemv_kernels_q8[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(features);
|
||||
#endif
|
||||
|
||||
return kernels;
|
||||
|
||||
@@ -7,8 +7,10 @@
|
||||
#include "unary-ops.h"
|
||||
#include "vec.h"
|
||||
|
||||
#include <float.h>
|
||||
#include <cfloat>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
|
||||
// ggml_compute_forward_dup
|
||||
|
||||
@@ -1394,6 +1396,56 @@ void ggml_compute_forward_sum(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_cumsum
|
||||
|
||||
static void ggml_compute_forward_cumsum_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(dst->nb[0] == sizeof(float));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT(ne0 == ne00);
|
||||
GGML_ASSERT(ne1 == ne01);
|
||||
GGML_ASSERT(ne2 == ne02);
|
||||
GGML_ASSERT(ne3 == ne03);
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, src0);
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
float * dst_row = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
ggml_vec_cumsum_f32(ne00, dst_row, src_row);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_cumsum(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_cumsum_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_sum_rows
|
||||
|
||||
static void ggml_compute_forward_sum_rows_f32(
|
||||
@@ -2140,6 +2192,83 @@ static void ggml_compute_forward_gelu(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_fill
|
||||
|
||||
static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const float c = ggml_get_op_params_f32(dst, 0);
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, dst);
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne2*ne1);
|
||||
const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
|
||||
const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
|
||||
|
||||
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
|
||||
|
||||
ggml_vec_set_f32(ne0, dst_ptr, c);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
ggml_compute_forward_fill_f32(params, dst);
|
||||
}
|
||||
|
||||
// ggml_compute_tri
|
||||
|
||||
static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, src0);
|
||||
|
||||
bool (*bipred)(int, int);
|
||||
|
||||
switch (ttype) {
|
||||
case GGML_TRI_TYPE_LOWER: bipred = [](int i, int r) { return i < r; }; break;
|
||||
case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
|
||||
case GGML_TRI_TYPE_UPPER: bipred = [](int i, int r) { return i > r; }; break;
|
||||
case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
|
||||
default: GGML_ABORT("invalid tri type");
|
||||
}
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
float * dst_ptr = ( float *) (( char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
|
||||
|
||||
for (int i0 = 0; i0 < ne0; ++i0) {
|
||||
dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_tri_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_gelu_erf
|
||||
|
||||
static void ggml_compute_forward_gelu_erf_f32(
|
||||
@@ -7291,6 +7420,65 @@ static void ggml_compute_forward_upscale_f32(
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {
|
||||
// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
|
||||
// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
|
||||
auto triangle_filter = [](float x) -> float {
|
||||
return std::max(1.0f - fabsf(x), 0.0f);
|
||||
};
|
||||
|
||||
// support and invscale, minimum 1 pixel for bilinear
|
||||
const float support1 = std::max(1.0f, 1.0f / sf1);
|
||||
const float invscale1 = 1.0f / support1;
|
||||
const float support0 = std::max(1.0f, 1.0f / sf0);
|
||||
const float invscale0 = 1.0f / support0;
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
const int64_t i03 = i3 / sf3;
|
||||
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
|
||||
const int64_t i02 = i2 / sf2;
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
const float y = ((float) i1 + pixel_offset) / sf1;
|
||||
for (int64_t i0 = 0; i0 < ne0; i0++) {
|
||||
const float x = ((float) i0 + pixel_offset) / sf0;
|
||||
|
||||
// the range of source pixels that contribute
|
||||
const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);
|
||||
const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);
|
||||
const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);
|
||||
const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);
|
||||
|
||||
// bilinear filter with antialiasing
|
||||
float val = 0.0f;
|
||||
float total_weight = 0.0f;
|
||||
|
||||
for (int64_t sy = y_min; sy < y_max; sy++) {
|
||||
const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
|
||||
|
||||
for (int64_t sx = x_min; sx < x_max; sx++) {
|
||||
const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
|
||||
const float weight = weight_x * weight_y;
|
||||
|
||||
if (weight <= 0.0f) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);
|
||||
val += pixel * weight;
|
||||
total_weight += weight;
|
||||
}
|
||||
}
|
||||
|
||||
if (total_weight > 0.0f) {
|
||||
val /= total_weight;
|
||||
}
|
||||
|
||||
float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
*dst_ptr = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
const int64_t i03 = i3 / sf3;
|
||||
@@ -7664,6 +7852,18 @@ void ggml_compute_forward_timestep_embedding(
|
||||
|
||||
// ggml_compute_forward_argsort
|
||||
|
||||
template<enum ggml_sort_order order>
|
||||
struct cmp_argsort {
|
||||
const float * data;
|
||||
bool operator()(int32_t a, int32_t b) const {
|
||||
if constexpr (order == GGML_SORT_ORDER_ASC) {
|
||||
return data[a] < data[b];
|
||||
} else {
|
||||
return data[a] > data[b];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static void ggml_compute_forward_argsort_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
@@ -7682,23 +7882,25 @@ static void ggml_compute_forward_argsort_f32(
|
||||
ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
for (int64_t i = ith; i < nr; i += nth) {
|
||||
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
||||
const float * src_data = (float *)((char *) src0->data + i*nb01);
|
||||
|
||||
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
||||
|
||||
for (int64_t j = 0; j < ne0; j++) {
|
||||
dst_data[j] = j;
|
||||
}
|
||||
|
||||
// C doesn't have a functional sort, so we do a bubble sort instead
|
||||
for (int64_t j = 0; j < ne0; j++) {
|
||||
for (int64_t k = j + 1; k < ne0; k++) {
|
||||
if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
|
||||
(order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
|
||||
int32_t tmp = dst_data[j];
|
||||
dst_data[j] = dst_data[k];
|
||||
dst_data[k] = tmp;
|
||||
}
|
||||
}
|
||||
switch (order) {
|
||||
case GGML_SORT_ORDER_ASC:
|
||||
std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
|
||||
break;
|
||||
|
||||
case GGML_SORT_ORDER_DESC:
|
||||
std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
|
||||
break;
|
||||
|
||||
default:
|
||||
GGML_ABORT("invalid sort order");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7721,6 +7923,72 @@ void ggml_compute_forward_argsort(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_top_k
|
||||
|
||||
struct cmp_top_k {
|
||||
const float * data;
|
||||
bool operator()(int32_t a, int32_t b) const {
|
||||
return data[a] > data[b];
|
||||
}
|
||||
};
|
||||
|
||||
static void ggml_compute_forward_top_k_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int64_t nr = ggml_nrows(src0);
|
||||
|
||||
const int top_k = ne0;
|
||||
|
||||
int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
||||
|
||||
for (int64_t i = ith; i < nr; i += nth) {
|
||||
const float * src_data = (float *)((char *) src0->data + i*nb01);
|
||||
|
||||
for (int64_t j = 0; j < ne00; j++) {
|
||||
tmp[j] = j;
|
||||
}
|
||||
|
||||
std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
|
||||
|
||||
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
||||
|
||||
std::copy(tmp, tmp + top_k, dst_data);
|
||||
|
||||
// emphasize that the order is not important
|
||||
if (top_k > 1) {
|
||||
std::swap(dst_data[0], dst_data[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_top_k(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_top_k_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_flash_attn_ext
|
||||
|
||||
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
||||
@@ -8521,7 +8789,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
// n_head
|
||||
for (int h = ih0; h < ih1; ++h) {
|
||||
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
||||
const float dt_soft_plus = ggml_softplus(dt[h]);
|
||||
const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
|
||||
const float dA = expf(dt_soft_plus * A[h]);
|
||||
const int g = h / (nh / ng); // repeat_interleave
|
||||
|
||||
@@ -8618,7 +8886,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
// n_head
|
||||
for (int h = ih0; h < ih1; ++h) {
|
||||
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
||||
const float dt_soft_plus = ggml_softplus(dt[h]);
|
||||
const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
|
||||
const int g = h / (nh / ng); // repeat_interleave
|
||||
|
||||
// dim
|
||||
@@ -8901,6 +9169,14 @@ void ggml_compute_forward_unary(
|
||||
{
|
||||
ggml_compute_forward_xielu(params, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
{
|
||||
ggml_compute_forward_expm1(params, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
{
|
||||
ggml_compute_forward_softplus(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -9497,6 +9773,76 @@ void ggml_compute_forward_gla(
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
const struct ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
|
||||
const struct ggml_tensor * src1 = dst->src[1]; // B (RHS)
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS;
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ne00 == ne01); // A must be square
|
||||
GGML_ASSERT(ne0 == ne10); // solution cols == B cols
|
||||
GGML_ASSERT(ne1 == ne11); // solution rows == B rows
|
||||
|
||||
GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
|
||||
GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int64_t k = ne10; // number of RHS columns
|
||||
const int64_t n = ne11; // A is n×n
|
||||
const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
|
||||
|
||||
// chunks per thread
|
||||
const int64_t dr = (nr + nth - 1)/nth;
|
||||
|
||||
// chunk range for this thread
|
||||
const int64_t ir0 = dr*ith;
|
||||
const int64_t ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
const float * A = (const float *) src0->data; // [n, n, B1, B2]
|
||||
const float * B = (const float *) src1->data; // [n, k, B1, B2]
|
||||
float * X = ( float *) dst->data; // [n, k, B1, B2]
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne02*k);
|
||||
const int64_t i02 = (ir - i03*ne02*k)/k;
|
||||
const int64_t i01 = (ir - i03*ne02*k - i02*k);
|
||||
|
||||
const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
|
||||
const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
|
||||
|
||||
float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
|
||||
|
||||
for (int64_t i00 = 0; i00 < n; ++i00) {
|
||||
float sum = 0.0f;
|
||||
for (int64_t t = 0; t < i00; ++t) {
|
||||
sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
|
||||
}
|
||||
|
||||
const float diag = A_batch[i00 * n + i00];
|
||||
assert(diag != 0.0f && "Zero diagonal in triangular matrix");
|
||||
|
||||
X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_compute_forward_solve_tri_f32(params, dst);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_rwkv_wkv7
|
||||
|
||||
static void ggml_compute_forward_rwkv_wkv7_f32(
|
||||
|
||||
@@ -34,6 +34,7 @@ void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct
|
||||
void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_sum_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_cumsum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_mean(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_argmax(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_count_equal(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
@@ -80,7 +81,10 @@ void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct
|
||||
void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_flash_attn_back(
|
||||
const struct ggml_compute_params * params,
|
||||
@@ -96,6 +100,7 @@ void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params,
|
||||
void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
||||
@@ -124,6 +124,58 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
||||
assert(QK_K == 256);
|
||||
assert(k % QK_K == 0);
|
||||
const int nb = k / QK_K;
|
||||
|
||||
block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
|
||||
|
||||
// scalar
|
||||
const int blck_size_interleave = 4;
|
||||
float srcv[4][QK_K];
|
||||
float iscale[4];
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
for (int row_iter = 0; row_iter < 4; row_iter++) {
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0;
|
||||
|
||||
for (int j = 0; j < QK_K; j++) {
|
||||
srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
|
||||
// Update the maximum value of the corresponding super block
|
||||
if(amax < fabsf(srcv[row_iter][j])) {
|
||||
amax = fabsf(srcv[row_iter][j]);
|
||||
max = srcv[row_iter][j];
|
||||
}
|
||||
}
|
||||
|
||||
iscale[row_iter] = amax ? -127.f/max : 0;
|
||||
|
||||
y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
|
||||
}
|
||||
|
||||
for (int j = 0; j < QK_K / 4; j++) {
|
||||
y[i].bsums[j] = 0;
|
||||
}
|
||||
|
||||
// Quants values are interleaved in sequence of four bytes from corresponding super blocks
|
||||
// Bsums values are interleaved in sequence of four bsums from each super block taken for interleaving
|
||||
// i.e first four bsums from the first super block, followed by first four bsums from second super block and so on
|
||||
for (int j = 0; j < QK_K * 4; j++) {
|
||||
int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
|
||||
int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
|
||||
src_offset += (j % blck_size_interleave);
|
||||
int index = (((j & 15) >> 2) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
|
||||
|
||||
float x0 = srcv[src_id][src_offset] * iscale[src_id];
|
||||
y[i].qs[j] = nearest_int(x0);
|
||||
y[i].bsums[index] += y[i].qs[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
||||
assert(QK_K == 256);
|
||||
assert(k % QK_K == 0);
|
||||
@@ -192,6 +244,12 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTR
|
||||
ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
|
||||
}
|
||||
|
||||
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
||||
assert(nrow == 4);
|
||||
UNUSED(nrow);
|
||||
ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row);
|
||||
}
|
||||
|
||||
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
||||
assert(nrow == 4);
|
||||
UNUSED(nrow);
|
||||
@@ -333,6 +391,77 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 4;
|
||||
static const uint32_t kmask1 = 0x3f3f3f3f;
|
||||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
static const uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert (n % qk == 0);
|
||||
assert (nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[8];
|
||||
float sum_minf[8];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[j] = 0.0;
|
||||
sum_minf[j] = 0.0;
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
|
||||
const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
|
||||
sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i + 32]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
@@ -727,6 +856,89 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 4;
|
||||
static const uint32_t kmask1 = 0x3f3f3f3f;
|
||||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
static const uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert (n % qk == 0);
|
||||
assert (nr % 4 == 0);
|
||||
assert (nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
float sumf[4][8];
|
||||
float sum_minf[4][8];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[m][j] = 0.0;
|
||||
sum_minf[m][j] = 0.0;
|
||||
}
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
|
||||
const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
|
||||
sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i + 128]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for(int m = 0; m < 4; m++) {
|
||||
const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
|
||||
for(int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
@@ -1228,9 +1440,10 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
|
||||
|
||||
GGML_UNUSED(data_size);
|
||||
}
|
||||
|
||||
static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
|
||||
GGML_ASSERT(interleave_block == 8);
|
||||
GGML_ASSERT(interleave_block == 8 || interleave_block == 4);
|
||||
constexpr int nrows_interleaved = 8;
|
||||
|
||||
block_q4_Kx8 * dst = (block_q4_Kx8*)t->data;
|
||||
@@ -1468,6 +1681,10 @@ template <> int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * da
|
||||
return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q4_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q4_K_to_q4_K_8_bl(t, 4, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
@@ -1501,6 +1718,10 @@ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t
|
||||
ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
@@ -1529,6 +1750,10 @@ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t
|
||||
ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
@@ -1600,29 +1825,52 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
return false;
|
||||
}
|
||||
|
||||
void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end) {
|
||||
void forward_mul_mat_one_chunk(ggml_compute_params * params,
|
||||
ggml_tensor * op,
|
||||
int64_t src0_start,
|
||||
int64_t src0_end,
|
||||
int64_t src1_start,
|
||||
int64_t src1_end) {
|
||||
const ggml_tensor * src0 = op->src[0];
|
||||
const ggml_tensor * src1 = op->src[1];
|
||||
ggml_tensor * dst = op;
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const void * src1_wdata = params->wdata;
|
||||
const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
|
||||
|
||||
GGML_ASSERT(ne03 == 1 && ne13 == 1);
|
||||
GGML_ASSERT(ne12 % ne02 == 0);
|
||||
const int64_t r2 = ne12 / ne02;
|
||||
|
||||
const int64_t i12 = src1_start / ne1;
|
||||
const int64_t i11 = src1_start - i12 * ne1;
|
||||
|
||||
// Determine batch index
|
||||
const int64_t i02 = i12 / r2;
|
||||
|
||||
const int64_t i1 = i11;
|
||||
const int64_t i2 = i12;
|
||||
|
||||
const char * src0_ptr = (const char *) src0->data + i02 * nb02;
|
||||
const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride;
|
||||
char * dst_ptr = ((char *) dst->data + (i1 * nb1 + i2 * nb2));
|
||||
|
||||
const int64_t nrows = src1_end - src1_start;
|
||||
const int64_t ncols = src0_end - src0_start;
|
||||
|
||||
GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize);
|
||||
|
||||
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
|
||||
if (ne11 > 3) {
|
||||
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
||||
(float *) ((char *) dst->data) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01,
|
||||
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
|
||||
if (nrows > 3) {
|
||||
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,
|
||||
src0_ptr + src0_start * nb01, src1_ptr,
|
||||
nrows - (nrows % 4), ncols);
|
||||
}
|
||||
for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
|
||||
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
||||
(float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01,
|
||||
(const char *) src1_wdata + (src1_col_stride * iter), 1,
|
||||
src0_end - src0_start);
|
||||
for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {
|
||||
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start,
|
||||
ne01, src0_ptr + src0_start * nb01,
|
||||
src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1647,6 +1895,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
// TODO: General batched mul mat for 4D tensors
|
||||
// Currently only supports 3D tensors
|
||||
GGML_ASSERT(ne03 == 1);
|
||||
GGML_ASSERT(ne13 == 1);
|
||||
GGML_ASSERT(ne3 == 1);
|
||||
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
|
||||
@@ -1654,47 +1908,65 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
|
||||
char * wdata = static_cast<char *>(params->wdata);
|
||||
const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
|
||||
const size_t nbw2 = nbw1 * ne11;
|
||||
|
||||
assert(params->wsize >= nbw1 * ne11);
|
||||
assert(params->wsize >= nbw2 * ne12);
|
||||
|
||||
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
|
||||
|
||||
int64_t i11_processed = 0;
|
||||
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
|
||||
ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
|
||||
}
|
||||
// INFO: Quantization is done in planes to avoid extra complexity in chunking.
|
||||
// Flattening dimensions not multiple of INTER_SIZE would require extra handling depending on how
|
||||
// the planes are broadcast.
|
||||
for (int64_t i12 = 0; i12 < ne12; i12++) {
|
||||
char * data_ptr = (char *) src1->data + i12 * nb12;
|
||||
char * wdata_ptr = wdata + i12 * nbw2;
|
||||
|
||||
i11_processed = ne11 - ne11 % 4;
|
||||
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
|
||||
from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
|
||||
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
|
||||
ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
|
||||
(void *) (wdata_ptr + i11 * nbw1), 4, ne10);
|
||||
}
|
||||
|
||||
const int64_t i11_processed = ne11 - ne11 % 4;
|
||||
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
|
||||
from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);
|
||||
}
|
||||
}
|
||||
|
||||
// disable for NUMA
|
||||
const bool disable_chunking = ggml_is_numa();
|
||||
|
||||
// 4x chunks per thread
|
||||
int64_t nr = ggml_nrows(op->src[0]);
|
||||
int nth_scaled = nth * 4;
|
||||
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
||||
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
||||
const int64_t nr0 = ggml_nrows(op->src[0]);
|
||||
|
||||
int nth_scaled = nth * 4;
|
||||
int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
|
||||
int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0;
|
||||
|
||||
// src1 is chunked only by full planes.
|
||||
// When we flatten we need to address dimensions not multiple of the q8 INTER_SIZE
|
||||
// to route them thorugh GEMV.
|
||||
// nchunk1 = ne12 also avoids messing the chunking for models with no 3d tensors
|
||||
// to avoid affecting their performance
|
||||
int64_t nchunk1 = ne12;
|
||||
|
||||
// Ensure minimum chunk size to avoid alignment issues with high thread counts
|
||||
// Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
|
||||
const int64_t min_chunk_size = NB_COLS;
|
||||
if (nchunk > 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) {
|
||||
nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
|
||||
if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
|
||||
nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
|
||||
}
|
||||
|
||||
if (nth == 1 || nchunk < nth || disable_chunking) {
|
||||
nchunk = nth;
|
||||
int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
|
||||
// Only increase nchunk0 to nth if it won't make chunks too small
|
||||
if (nth == 1 || ((nchunk0 < nth || disable_chunking) && (nr0 + nth - 1) / nth >= min_chunk_size)) {
|
||||
nchunk0 = nth;
|
||||
dr0 = (nr0 + nchunk0 - 1) / nchunk0;
|
||||
}
|
||||
|
||||
// Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
|
||||
// This prevents creating too many tiny chunks that could overlap after alignment
|
||||
const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
|
||||
if (nchunk > max_nchunk) {
|
||||
nchunk = max_nchunk;
|
||||
}
|
||||
const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
|
||||
nchunk0 = MIN(nchunk0, max_nchunk);
|
||||
|
||||
if (ith == 0) {
|
||||
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
|
||||
@@ -1706,23 +1978,30 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
||||
int current_chunk = ith;
|
||||
|
||||
while (current_chunk < nchunk) {
|
||||
int64_t src0_start = (current_chunk * ne01) / nchunk;
|
||||
int64_t src0_end = ((current_chunk + 1) * ne01) / nchunk;
|
||||
while (current_chunk < nchunk0 * nchunk1) {
|
||||
const int64_t ith0 = current_chunk % nchunk0;
|
||||
const int64_t ith1 = current_chunk / nchunk0;
|
||||
|
||||
int64_t src0_start = dr0 * ith0;
|
||||
int64_t src0_end = MIN(src0_start + dr0, nr0);
|
||||
|
||||
// full-plane range for src1
|
||||
int64_t src1_start = ith1 * ne11;
|
||||
int64_t src1_end = (ith1 + 1) * ne11;
|
||||
|
||||
// Align boundaries to NB_COLS - round up to ensure all data is included
|
||||
// The chunk size limiting above ensures chunks are large enough to prevent overlaps
|
||||
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
|
||||
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
|
||||
if (src0_end > ne01) {
|
||||
src0_end = ne01;
|
||||
}
|
||||
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
|
||||
src0_end = MIN(src0_end, ne01);
|
||||
|
||||
// Make sure current plane is the last one before exiting
|
||||
if (src0_start >= src0_end) {
|
||||
break;
|
||||
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
||||
continue;
|
||||
}
|
||||
|
||||
forward_mul_mat_one_chunk(params, dst, src0_start, src0_end);
|
||||
forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);
|
||||
|
||||
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
||||
}
|
||||
@@ -1877,6 +2156,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
|
||||
|
||||
// instance for Q4_K
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
|
||||
|
||||
// instance for Q2
|
||||
@@ -1908,6 +2190,16 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
return &q4_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
return &q4_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
return &q4_K_8x4_q8_K;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_Q2_K) {
|
||||
if (ggml_cpu_has_avx512()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
|
||||
@@ -80,10 +80,12 @@ extern "C" {
|
||||
|
||||
void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -91,6 +93,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -99,10 +102,12 @@ void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
// Native implementations
|
||||
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -110,6 +115,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
|
||||
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
||||
@@ -160,18 +160,18 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
||||
#define GGML_F32xt svfloat32_t
|
||||
#define GGML_F32xt_ZERO svdup_n_f32(0.0f)
|
||||
#define GGML_F32xt_SET1(x) svdup_n_f32(x)
|
||||
#define GGML_F32xt_LOAD_IMPL(pg, a, ...) svld1_f32(pg, a)
|
||||
#define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
|
||||
#define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
|
||||
#define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
|
||||
#define GGML_F32xt_LOAD_IMPL(pg, a) svld1_f32(pg, a)
|
||||
#define GGML_F32xt_LOAD(a) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, a)
|
||||
#define GGML_F32xt_STORE_IMPL(pg, a, b) svst1_f32(pg, a, b)
|
||||
#define GGML_F32xt_STORE(a, b) GGML_F32xt_STORE_IMPL(DEFAULT_PG, a, b)
|
||||
#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
|
||||
#define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
|
||||
#define GGML_F32xt_FMA(a, b, c) GGML_F32xt_FMA_IMPL(DEFAULT_PG, a, b, c)
|
||||
#define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
|
||||
#define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
|
||||
#define GGML_F32xt_ADD(a, b) GGML_F32xt_ADD_IMPL(DEFAULT_PG, a, b)
|
||||
#define GGML_F32xt_MUL_IMPL(pg, a, b) svmul_f32_m(pg, a, b)
|
||||
#define GGML_F32xt_MUL(...) GGML_F32xt_MUL_IMPL(DEFAULT_PG, __VA_ARGS__)
|
||||
#define GGML_F32xt_MUL(a, b) GGML_F32xt_MUL_IMPL(DEFAULT_PG, a, b)
|
||||
#define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a)
|
||||
#define GGML_F32xt_REDUCE_ONE(...) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, __VA_ARGS__)
|
||||
#define GGML_F32xt_REDUCE_ONE(a) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, a)
|
||||
#define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
|
||||
{ \
|
||||
sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2); \
|
||||
@@ -183,7 +183,8 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
||||
sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5); \
|
||||
(res) = (ggml_float) GGML_F32xt_REDUCE_ONE(sum1); \
|
||||
}
|
||||
#define GGML_F32xt_REDUCE(...) GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, __VA_ARGS__)
|
||||
#define GGML_F32xt_REDUCE(res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
|
||||
GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8)
|
||||
|
||||
#define GGML_F32_VEC GGML_F32xt
|
||||
#define GGML_F32_VEC_ZERO GGML_F32xt_ZERO
|
||||
@@ -206,11 +207,11 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
||||
#define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec))
|
||||
|
||||
#define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a)
|
||||
#define GGML_F32Cxt_FMA(...) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
||||
#define GGML_F32Cxt_FMA(a, b, c) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, a, b, c)
|
||||
#define GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b)
|
||||
#define GGML_F32Cxt_ADD(...) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
||||
#define GGML_F32Cxt_ADD(a, b) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, a, b)
|
||||
#define GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b)
|
||||
#define GGML_F32Cxt_MUL(...) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
||||
#define GGML_F32Cxt_MUL(a, b) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, a, b)
|
||||
#define GGML_F32Cxt_REDUCE GGML_F16xt_REDUCE_MIXED
|
||||
|
||||
#define GGML_F16x_VEC GGML_F32Cxt
|
||||
@@ -224,7 +225,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
||||
#define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE
|
||||
|
||||
#define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a)
|
||||
#define GGML_F16xt_REDUCE_ONE(...) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
||||
#define GGML_F16xt_REDUCE_ONE(a) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, a)
|
||||
|
||||
#define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \
|
||||
{ \
|
||||
@@ -234,7 +235,8 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
||||
__fp16 sum_f16 = svaddv_f16(pg16, sum1); \
|
||||
(res) = (ggml_float) sum_f16; \
|
||||
}
|
||||
#define GGML_F16xt_REDUCE_MIXED(...) GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
||||
#define GGML_F16xt_REDUCE_MIXED(res, sum1, sum2, sum3, sum4) \
|
||||
GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, res, sum1, sum2, sum3, sum4)
|
||||
|
||||
// F16 NEON
|
||||
|
||||
|
||||
@@ -73,6 +73,14 @@ static inline float op_log(float x) {
|
||||
return logf(x);
|
||||
}
|
||||
|
||||
static inline float op_expm1(float x) {
|
||||
return expf(x) - 1.0f;
|
||||
}
|
||||
|
||||
static inline float op_softplus(float x) {
|
||||
return (x > 20.0f) ? x : logf(1.0f + expf(x));
|
||||
}
|
||||
|
||||
static inline float op_floor(float x) {
|
||||
return floorf(x);
|
||||
}
|
||||
@@ -290,6 +298,14 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor *
|
||||
unary_op<op_log>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_expm1(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_expm1>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_softplus(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_softplus>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_floor>(params, dst);
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
|
||||
void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_expm1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_softplus(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_floor(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
||||
@@ -360,6 +360,13 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
|
||||
for (; i + 3 < n; i += 4) {
|
||||
vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
|
||||
}
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
for (int vl; i < n; i += vl) {
|
||||
vl = __riscv_vsetvl_e32m2(n - i);
|
||||
vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
|
||||
vfloat32m2_t vy = ggml_v_silu_m2(vx, vl);
|
||||
__riscv_vse32_v_f32m2(&y[i], vy, vl);
|
||||
}
|
||||
#endif
|
||||
for (; i < n; ++i) {
|
||||
y[i] = ggml_silu_f32(x[i]);
|
||||
@@ -460,6 +467,16 @@ ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const floa
|
||||
val = vec_mul(val, val);
|
||||
sum += (ggml_float)vec_hsum_f32x4(val);
|
||||
}
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
|
||||
for (int vl; i < n; i += vl) {
|
||||
vl = __riscv_vsetvl_e32m2(n - i);
|
||||
vfloat32m2_t val = __riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], vl), mean, vl);
|
||||
__riscv_vse32_v_f32m2(&y[i], val, vl);
|
||||
val = __riscv_vfmul_vv_f32m2(val, val, vl);
|
||||
vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, vl);
|
||||
}
|
||||
sum = (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
|
||||
#endif
|
||||
for (; i < n; ++i) {
|
||||
float val = x[i] - mean;
|
||||
|
||||
@@ -397,119 +397,118 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
|
||||
}
|
||||
|
||||
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
|
||||
#if defined(GGML_SIMD)
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
const int sve_register_length = svcntb() * 8;
|
||||
const int ggml_f16_epr = sve_register_length / 16;
|
||||
const int ggml_f16_step = 8 * ggml_f16_epr;
|
||||
#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
|
||||
const int sve_register_length = svcntb() * 8;
|
||||
const int ggml_f16_epr = sve_register_length / 16;
|
||||
const int ggml_f16_step = 8 * ggml_f16_epr;
|
||||
|
||||
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
|
||||
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
|
||||
|
||||
const int np= (n & ~(ggml_f16_step - 1));
|
||||
int np = (n & ~(ggml_f16_step - 1));
|
||||
|
||||
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
|
||||
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
|
||||
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||
ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
|
||||
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
|
||||
ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
|
||||
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
|
||||
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
|
||||
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||
ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
|
||||
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
|
||||
ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
|
||||
|
||||
GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
|
||||
GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
|
||||
|
||||
ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
|
||||
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
|
||||
ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
|
||||
ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
|
||||
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
|
||||
ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
|
||||
|
||||
GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
|
||||
GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
|
||||
|
||||
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
|
||||
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
|
||||
ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
|
||||
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
|
||||
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
|
||||
ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
|
||||
|
||||
GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
|
||||
GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
|
||||
|
||||
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
|
||||
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
|
||||
ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
|
||||
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
|
||||
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
|
||||
ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
|
||||
|
||||
GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
|
||||
GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
|
||||
|
||||
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
|
||||
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
|
||||
ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
|
||||
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
|
||||
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
|
||||
ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
|
||||
|
||||
GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
|
||||
GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
|
||||
|
||||
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
|
||||
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
|
||||
ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
|
||||
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
|
||||
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
|
||||
ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
|
||||
|
||||
GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
|
||||
GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
|
||||
|
||||
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
|
||||
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
|
||||
ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
|
||||
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
|
||||
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
|
||||
ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
|
||||
|
||||
GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
|
||||
GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
|
||||
|
||||
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
|
||||
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
|
||||
ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
|
||||
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
|
||||
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
|
||||
ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
|
||||
|
||||
GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
|
||||
GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
|
||||
}
|
||||
const int np2 = (n & ~(ggml_f16_epr - 1));
|
||||
for (int k = np; k < np2; k += ggml_f16_epr) {
|
||||
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
|
||||
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
|
||||
ry = GGML_F16x_VEC_FMA(ry, rx, vx);
|
||||
|
||||
GGML_F16x_VEC_STORE(y + k, ry, 0);
|
||||
}
|
||||
|
||||
if (np2 < n) {
|
||||
svbool_t pg = svwhilelt_b16(np2, n);
|
||||
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
|
||||
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
|
||||
hy = svmad_f16_x(pg, hx, vx, hy);
|
||||
svst1_f16(pg, (__fp16 *)(y + np2), hy);
|
||||
}
|
||||
np = n;
|
||||
#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
|
||||
const int np = n;
|
||||
_Float16 hv = (_Float16)v;
|
||||
for (int i = 0, avl; i < n; i += avl) {
|
||||
avl = __riscv_vsetvl_e16m8(n - i);
|
||||
vfloat16m8_t ax = __riscv_vle16_v_f16m8((const _Float16 *)&x[i], avl);
|
||||
vfloat16m8_t ay = __riscv_vle16_v_f16m8((_Float16 *)&y[i], avl);
|
||||
vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl);
|
||||
__riscv_vse16_v_f16m8((_Float16 *)&y[i], ny, avl);
|
||||
}
|
||||
#elif defined(GGML_SIMD)
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
|
||||
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
|
||||
|
||||
GGML_F16_VEC ax[GGML_F16_ARR];
|
||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
|
||||
|
||||
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
|
||||
}
|
||||
const int np2 = (n & ~(ggml_f16_epr - 1));
|
||||
for (int k = np; k < np2; k += ggml_f16_epr) {
|
||||
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
|
||||
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
|
||||
ry = GGML_F16x_VEC_FMA(ry, rx, vx);
|
||||
|
||||
GGML_F16x_VEC_STORE(y + k, ry, 0);
|
||||
}
|
||||
|
||||
if (np2 < n) {
|
||||
svbool_t pg = svwhilelt_b16(np2, n);
|
||||
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
|
||||
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
|
||||
hy = svmad_f16_x(pg, hx, vx, hy);
|
||||
svst1_f16(pg, (__fp16 *)(y + np2), hy);
|
||||
}
|
||||
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
// todo: RVV impl
|
||||
// scalar
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
|
||||
}
|
||||
#else
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
|
||||
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
|
||||
|
||||
GGML_F16_VEC ax[GGML_F16_ARR];
|
||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
|
||||
|
||||
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
|
||||
}
|
||||
}
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
// scalar
|
||||
for (int i = 0; i < n; ++i) {
|
||||
const int np = 0;
|
||||
#endif
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// xs and vs are byte strides of x and v
|
||||
@@ -698,60 +697,61 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
|
||||
}
|
||||
|
||||
inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
|
||||
#if defined(GGML_SIMD)
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
const int sve_register_length = svcntb() * 8;
|
||||
const int ggml_f16_epr = sve_register_length / 16;
|
||||
const int ggml_f16_step = 2 * ggml_f16_epr;
|
||||
#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
|
||||
const int sve_register_length = svcntb() * 8;
|
||||
const int ggml_f16_epr = sve_register_length / 16;
|
||||
const int ggml_f16_step = 2 * ggml_f16_epr;
|
||||
|
||||
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
|
||||
const int np = (n & ~(ggml_f16_step - 1));
|
||||
svfloat16_t ay1, ay2;
|
||||
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
|
||||
const int np = (n & ~(ggml_f16_step - 1));
|
||||
svfloat16_t ay1, ay2;
|
||||
|
||||
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||
ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0);
|
||||
ay1 = GGML_F16x_VEC_MUL(ay1, vx);
|
||||
GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0);
|
||||
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||
ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0);
|
||||
ay1 = GGML_F16x_VEC_MUL(ay1, vx);
|
||||
GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0);
|
||||
|
||||
ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1);
|
||||
ay2 = GGML_F16x_VEC_MUL(ay2, vx);
|
||||
GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1);
|
||||
ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1);
|
||||
ay2 = GGML_F16x_VEC_MUL(ay2, vx);
|
||||
GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1);
|
||||
}
|
||||
// leftovers
|
||||
// maximum number of leftover elements will be less that ggmlF_16x_epr. Apply predicated svmad on available elements only
|
||||
if (np < n) {
|
||||
svbool_t pg = svwhilelt_b16(np, n);
|
||||
svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));
|
||||
svfloat16_t out = svmul_f16_m(pg, hy, vx);
|
||||
svst1_f16(pg, (__fp16 *)(y + np), out);
|
||||
}
|
||||
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
|
||||
for (int i = 0, vl; i < n; i += vl) {
|
||||
vl = __riscv_vsetvl_e16m2(n - i);
|
||||
vfloat16m2_t vy = __riscv_vle16_v_f16m2((_Float16 *)&y[i], vl);
|
||||
vfloat32m4_t vy32 = __riscv_vfwcvt_f_f_v_f32m4(vy, vl);
|
||||
vy32 = __riscv_vfmul_vf_f32m4(vy32, v, vl);
|
||||
vy = __riscv_vfncvt_f_f_w_f16m2(vy32, vl);
|
||||
__riscv_vse16_v_f16m2((_Float16 *)&y[i], vy, vl);
|
||||
}
|
||||
#elif defined(GGML_SIMD)
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
|
||||
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
|
||||
|
||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
|
||||
|
||||
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
|
||||
}
|
||||
// leftovers
|
||||
// maximum number of leftover elements will be less that ggmlF_16x_epr. Apply predicated svmad on available elements only
|
||||
if (np < n) {
|
||||
svbool_t pg = svwhilelt_b16(np, n);
|
||||
svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));
|
||||
svfloat16_t out = svmul_f16_m(pg, hy, vx);
|
||||
svst1_f16(pg, (__fp16 *)(y + np), out);
|
||||
}
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
// todo: RVV impl
|
||||
// scalar
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
|
||||
}
|
||||
#else
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
}
|
||||
|
||||
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
|
||||
|
||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
|
||||
|
||||
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
|
||||
}
|
||||
}
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
|
||||
}
|
||||
#endif
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
|
||||
}
|
||||
#else
|
||||
// scalar
|
||||
for (int i = 0; i < n; ++i) {
|
||||
@@ -1416,6 +1416,16 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
|
||||
#endif
|
||||
}
|
||||
|
||||
inline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (i == 0) {
|
||||
y[i] = x[i];
|
||||
} else {
|
||||
y[i] = y[i - 1] + x[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
|
||||
ggml_float sum = 0.0;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
|
||||
@@ -44,7 +44,7 @@ static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
const dim3 offset_grid((nrows + block_size - 1) / block_size);
|
||||
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
|
||||
|
||||
cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream);
|
||||
CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
|
||||
|
||||
size_t temp_storage_bytes = 0;
|
||||
|
||||
|
||||
@@ -21,10 +21,12 @@
|
||||
#include "ggml-common.h"
|
||||
|
||||
#include <array>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cfloat>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
@@ -84,12 +86,12 @@
|
||||
|
||||
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
|
||||
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
|
||||
#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
|
||||
#define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000
|
||||
|
||||
#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
|
||||
#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
|
||||
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
|
||||
#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG)
|
||||
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1)
|
||||
#define GGML_CUDA_CC_IS_PH1(cc) (cc >= GGML_CUDA_CC_PH1)
|
||||
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
|
||||
# define GGML_CUDA_USE_CUB
|
||||
@@ -212,9 +214,9 @@ static const char * cu_get_error_str(CUresult err) {
|
||||
#define GGML_USE_VMM
|
||||
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
|
||||
|
||||
#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
||||
#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
||||
#define FP16_AVAILABLE
|
||||
#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
||||
#endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
||||
|
||||
#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
|
||||
#define FAST_FP16_AVAILABLE
|
||||
@@ -224,6 +226,10 @@ static const char * cu_get_error_str(CUresult err) {
|
||||
#define AMD_MFMA_AVAILABLE
|
||||
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
|
||||
|
||||
#if defined(GGML_USE_HIP) && defined(RDNA4)
|
||||
#define AMD_WMMA_AVAILABLE
|
||||
#endif // defined(GGML_USE_HIP) && defined(RDNA4)
|
||||
|
||||
// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
|
||||
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#define VOLTA_MMA_AVAILABLE
|
||||
@@ -246,12 +252,14 @@ static const char * cu_get_error_str(CUresult err) {
|
||||
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
|
||||
|
||||
static bool fp16_available(const int cc) {
|
||||
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
|
||||
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
|
||||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
|
||||
}
|
||||
|
||||
static bool fast_fp16_available(const int cc) {
|
||||
return GGML_CUDA_CC_IS_AMD(cc) ||
|
||||
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610);
|
||||
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) ||
|
||||
(GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc));
|
||||
}
|
||||
|
||||
// To be used for feature selection of external libraries, e.g. cuBLAS.
|
||||
@@ -268,7 +276,9 @@ static bool fp16_mma_hardware_available(const int cc) {
|
||||
}
|
||||
|
||||
static bool bf16_mma_hardware_available(const int cc) {
|
||||
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
|
||||
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) ||
|
||||
GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 ||
|
||||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
|
||||
}
|
||||
|
||||
static bool fp32_mma_hardware_available(const int cc) {
|
||||
@@ -283,6 +293,10 @@ static bool amd_mfma_available(const int cc) {
|
||||
#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
|
||||
}
|
||||
|
||||
static bool amd_wmma_available(const int cc) {
|
||||
return GGML_CUDA_CC_IS_RDNA4(cc);
|
||||
}
|
||||
|
||||
static bool volta_mma_available(const int cc) {
|
||||
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
|
||||
}
|
||||
@@ -550,8 +564,12 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v
|
||||
acc += v.y*u.y;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
|
||||
#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
|
||||
#define V_DOT2_F32_F16_AVAILABLE
|
||||
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
|
||||
|
||||
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
|
||||
#else
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
@@ -563,7 +581,7 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
|
||||
acc += tmpv.x * tmpu.x;
|
||||
acc += tmpv.y * tmpu.y;
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
|
||||
@@ -586,6 +604,12 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v,
|
||||
// If dst and src point at different address spaces then they are guaranteed to not be aliased.
|
||||
template <int nbytes, int alignment = 0>
|
||||
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
|
||||
static_assert(
|
||||
nbytes <= ggml_cuda_get_max_cpy_bytes() || alignment == 0,
|
||||
"You are misusing the alignment parameter for ggml_cuda_memcpy_1. "
|
||||
"The intent is for the parameter is only as a workaround if either one of the pointers is not properly aligned. "
|
||||
"If you use it to do more bytes per copy than ggml_cuda_max_cpy_bytes() the reads and writes may not be coalesced. "
|
||||
"Call ggml_cuda_memcpy_1 in a loop instead.");
|
||||
if constexpr (alignment != 0) {
|
||||
static_assert(nbytes % alignment == 0, "bad alignment");
|
||||
}
|
||||
@@ -958,6 +982,154 @@ struct ggml_cuda_graph {
|
||||
#endif
|
||||
};
|
||||
|
||||
struct ggml_cuda_concurrent_event {
|
||||
std::vector<cudaEvent_t> join_events;
|
||||
cudaEvent_t fork_event = nullptr;
|
||||
|
||||
int n_streams = 0;
|
||||
std::unordered_map<const ggml_tensor *, int> stream_mapping;
|
||||
|
||||
const ggml_tensor * join_node;
|
||||
|
||||
ggml_cuda_concurrent_event() = default;
|
||||
|
||||
ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete;
|
||||
ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete;
|
||||
|
||||
explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
|
||||
join_events.resize(n_streams);
|
||||
|
||||
for (size_t i = 0; i < join_events.size(); ++i) {
|
||||
CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
|
||||
}
|
||||
|
||||
ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept
|
||||
: join_events(std::move(other.join_events))
|
||||
, fork_event(other.fork_event)
|
||||
, n_streams(other.n_streams)
|
||||
, stream_mapping(std::move(other.stream_mapping))
|
||||
, join_node(other.join_node) {
|
||||
other.fork_event = nullptr;
|
||||
}
|
||||
|
||||
// 1. check if any branches write to overlapping memory ranges (except the join node)
|
||||
// 2. check whether all srcs are either within the branch or outside the nodes covered by ggml_cuda_concurrent_event
|
||||
// we assume all nodes have the same buffer
|
||||
bool is_valid() const {
|
||||
std::vector<std::vector<std::pair<int64_t, int64_t>>> write_ranges;
|
||||
write_ranges.resize(n_streams);
|
||||
|
||||
// get join_node's memory range to exclude from overlap checking.
|
||||
// multiple nodes can use join_node's buffer; we synchronize on the join node.
|
||||
const ggml_tensor * join_t = join_node->view_src ? join_node->view_src : join_node;
|
||||
const int64_t join_start = (int64_t) join_t->data;
|
||||
const int64_t join_end = join_start + ggml_nbytes(join_t);
|
||||
|
||||
for (const auto & [tensor, stream] : stream_mapping) {
|
||||
const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
|
||||
const int64_t t_start = (int64_t) t->data;
|
||||
const int64_t t_end = t_start + ggml_nbytes(t);
|
||||
|
||||
// skip tensors that overlap with join_node's buffer.
|
||||
if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// concurrent streams begin from 1
|
||||
write_ranges[stream - 1].emplace_back(t_start, t_end);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_streams; ++i) {
|
||||
// sorts first by start then by end of write range
|
||||
std::sort(write_ranges[i].begin(), write_ranges[i].end());
|
||||
}
|
||||
|
||||
bool writes_overlap = false;
|
||||
bool dependent_srcs = false;
|
||||
for (const auto & [tensor, stream] : stream_mapping) {
|
||||
const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
|
||||
const int64_t t_start = (int64_t) t->data;
|
||||
const int64_t t_end = t_start + ggml_nbytes(t);
|
||||
|
||||
// skip tensors that overlap with join_node's buffer
|
||||
if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// check if this buffer's write data overlaps with another stream's
|
||||
std::pair<int64_t, int64_t> data_range = std::make_pair(t_start, t_end);
|
||||
for (int i = 0; i < n_streams; ++i) {
|
||||
if (i == stream - 1) {
|
||||
continue;
|
||||
}
|
||||
auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range);
|
||||
|
||||
if (it != write_ranges[i].end()) {
|
||||
const std::pair<int64_t, int64_t> & other = *it;
|
||||
|
||||
// std::lower_bound returns the first element where other >= data_range (lexicographically).
|
||||
// This guarantees other.first >= data_range.first.
|
||||
// Therefore, overlap occurs iff other.first < data_range.second
|
||||
// (i.e., the other range starts before this range ends).
|
||||
if (other.first < data_range.second) {
|
||||
GGML_LOG_DEBUG("Writes overlap for %s", tensor->name);
|
||||
writes_overlap = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//check if all srcs are either in branch or don't have a branch
|
||||
for (int i = 0; i < GGML_MAX_SRC; ++i) {
|
||||
if (!tensor->src[i]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto it = stream_mapping.find(tensor->src[i]);
|
||||
|
||||
if (it == stream_mapping.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (it->second != stream) {
|
||||
dependent_srcs = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (dependent_srcs || writes_overlap) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return !writes_overlap && !dependent_srcs;
|
||||
}
|
||||
|
||||
~ggml_cuda_concurrent_event() {
|
||||
if (fork_event != nullptr) {
|
||||
CUDA_CHECK(cudaEventDestroy(fork_event));
|
||||
}
|
||||
for (cudaEvent_t e : join_events) {
|
||||
if (e != nullptr) {
|
||||
CUDA_CHECK(cudaEventDestroy(e));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_cuda_stream_context {
|
||||
std::vector<const ggml_tensor *> original_nodes;
|
||||
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
|
||||
|
||||
void reset() {
|
||||
original_nodes.clear();
|
||||
concurrent_events.clear();
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_backend_cuda_context {
|
||||
int device;
|
||||
std::string name;
|
||||
@@ -968,11 +1140,15 @@ struct ggml_backend_cuda_context {
|
||||
|
||||
std::unique_ptr<ggml_cuda_graph> cuda_graph;
|
||||
|
||||
int curr_stream_no = 0;
|
||||
|
||||
explicit ggml_backend_cuda_context(int device) :
|
||||
device(device),
|
||||
name(GGML_CUDA_NAME + std::to_string(device)) {
|
||||
}
|
||||
|
||||
ggml_cuda_stream_context concurrent_stream_context;
|
||||
|
||||
~ggml_backend_cuda_context();
|
||||
|
||||
cudaStream_t stream(int device, int stream) {
|
||||
@@ -983,9 +1159,9 @@ struct ggml_backend_cuda_context {
|
||||
return streams[device][stream];
|
||||
}
|
||||
|
||||
cudaStream_t stream() {
|
||||
return stream(device, 0);
|
||||
}
|
||||
cudaStream_t stream() { return stream(device, curr_stream_no); }
|
||||
|
||||
ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }
|
||||
|
||||
cublasHandle_t cublas_handle(int device) {
|
||||
if (cublas_handles[device] == nullptr) {
|
||||
@@ -1001,15 +1177,15 @@ struct ggml_backend_cuda_context {
|
||||
}
|
||||
|
||||
// pool
|
||||
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
|
||||
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
|
||||
|
||||
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
|
||||
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
|
||||
|
||||
ggml_cuda_pool & pool(int device) {
|
||||
if (pools[device] == nullptr) {
|
||||
pools[device] = new_pool_for_device(device);
|
||||
if (pools[device][curr_stream_no] == nullptr) {
|
||||
pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
|
||||
}
|
||||
return *pools[device];
|
||||
return *pools[device][curr_stream_no];
|
||||
}
|
||||
|
||||
ggml_cuda_pool & pool() {
|
||||
|
||||
@@ -39,6 +39,15 @@ template<typename dst_t, typename src_t>
|
||||
return __float2bfloat16(float(x));
|
||||
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
|
||||
return __bfloat162float(x);
|
||||
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
|
||||
return __float22half2_rn(x);
|
||||
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
|
||||
// bypass compile error on cuda 12.0.1
|
||||
#ifdef GGML_USE_HIP
|
||||
return __float22bfloat162_rn(x);
|
||||
#else
|
||||
return {x.x, x.y};
|
||||
#endif // GGML_USE_HIP
|
||||
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
|
||||
return int32_t(x);
|
||||
} else {
|
||||
|
||||
@@ -212,6 +212,6 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
|
||||
static __device__ void cpy_1_scalar(const char * cxi, char * cdsti) {
|
||||
*(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
|
||||
}
|
||||
|
||||
@@ -12,10 +12,10 @@ const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available
|
||||
const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows
|
||||
|
||||
template <cpy_kernel_t cpy_1>
|
||||
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13) {
|
||||
static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13) {
|
||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= ne) {
|
||||
@@ -40,7 +40,7 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
|
||||
static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13) {
|
||||
@@ -86,6 +86,9 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11,
|
||||
nb12, nb13);
|
||||
}
|
||||
|
||||
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
|
||||
@@ -166,7 +169,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) {
|
||||
static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) {
|
||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= ne) {
|
||||
@@ -180,17 +183,17 @@ static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const in
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static void ggml_cpy_flt_contiguous_cuda(
|
||||
static void ggml_cpy_scalar_contiguous_cuda(
|
||||
const char * cx, char * cdst, const int64_t ne,
|
||||
cudaStream_t stream) {
|
||||
|
||||
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne);
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t, bool transposed = false>
|
||||
static void ggml_cpy_flt_cuda(
|
||||
static void ggml_cpy_scalar_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
@@ -202,7 +205,7 @@ static void ggml_cpy_flt_cuda(
|
||||
ne00n = ne00;
|
||||
ne01n = ne01;
|
||||
ne02n = ne02;
|
||||
} else if (nb00 > nb02) {
|
||||
} else {
|
||||
ne00n = ne00;
|
||||
ne01n = ne01*ne02;
|
||||
ne02n = 1;
|
||||
@@ -212,11 +215,11 @@ static void ggml_cpy_flt_cuda(
|
||||
(ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
|
||||
(ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
|
||||
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
|
||||
cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
|
||||
cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
|
||||
(cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
} else {
|
||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
}
|
||||
@@ -384,7 +387,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
char * src1_ddc = (char *) src1->data;
|
||||
|
||||
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
|
||||
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && src0->ne[3] == 1;
|
||||
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
|
||||
src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);
|
||||
|
||||
if (src0->type == src1->type && contiguous_srcs) {
|
||||
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
|
||||
@@ -398,94 +402,132 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
if (can_be_transposed) {
|
||||
ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<float, float, true>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<float, float>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_scalar_contiguous_cuda<float, nv_bfloat16>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<float, nv_bfloat16>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<float, half> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_scalar_contiguous_cuda<float, half>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<float, half>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q8_0_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_q8_0_f32_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q4_0_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_q4_0_f32_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q4_1_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_q4_1_f32_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q5_0_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_q5_0_f32_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_iq4_nl_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q5_1_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_q5_1_f32_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
if (can_be_transposed) {
|
||||
ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<half, half, true>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<half, half>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_scalar_contiguous_cuda<half, nv_bfloat16>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<half, nv_bfloat16>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<half, float> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_scalar_contiguous_cuda<half, float>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<half, float>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
||||
if (can_be_transposed) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16, true>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, half>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<nv_bfloat16, half>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, float>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<nv_bfloat16, float>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
|
||||
if (can_be_transposed) {
|
||||
ggml_cpy_scalar_cuda<int32_t, int32_t, true>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else {
|
||||
ggml_cpy_scalar_cuda<int32_t, int32_t>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_scalar_contiguous_cuda<float, int32_t>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<float, int32_t>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_scalar_contiguous_cuda<int32_t, float>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<int32_t, float>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
|
||||
@@ -55,11 +55,11 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
|
||||
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
||||
#else
|
||||
ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
||||
#endif // FP16_AVAILABLE
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -609,7 +609,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
|
||||
float KQ_sum_add = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
|
||||
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
|
||||
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < static_cast<uint32_t>(k_VKQ_sup) ?
|
||||
expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
|
||||
KQ_sum_add += val;
|
||||
tmp[i0/(np*warp_size)][jc1] = val;
|
||||
|
||||
@@ -86,11 +86,11 @@ static __global__ void flash_attn_ext_vec(
|
||||
|
||||
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
|
||||
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
|
||||
#else
|
||||
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
@@ -112,13 +112,13 @@ static __global__ void flash_attn_ext_vec(
|
||||
|
||||
constexpr int ne_KQ = ncols*D;
|
||||
constexpr int ne_combine = nwarps*V_cols_per_iter*D;
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
|
||||
__shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
|
||||
#else
|
||||
float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
|
||||
__shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
|
||||
float KQ_max[ncols];
|
||||
float KQ_sum[ncols];
|
||||
@@ -129,11 +129,11 @@ static __global__ void flash_attn_ext_vec(
|
||||
}
|
||||
|
||||
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
|
||||
#else
|
||||
float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
|
||||
float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
|
||||
if constexpr (Q_q8_1) {
|
||||
@@ -155,7 +155,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) {
|
||||
if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {
|
||||
tmp_q_i32[i] = 0;
|
||||
}
|
||||
}
|
||||
@@ -191,7 +191,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
|
||||
__syncthreads();
|
||||
} else {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
const half2 scale_h2 = make_half2(scale, scale);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
@@ -233,7 +233,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
Q_reg[j][k].y *= scale;
|
||||
}
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
}
|
||||
|
||||
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
|
||||
@@ -272,7 +272,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
|
||||
KQ_max_new[j] = fmaxf(KQ_max_new[j], sum);
|
||||
|
||||
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) {
|
||||
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {
|
||||
KQ_reg[j] = sum;
|
||||
}
|
||||
}
|
||||
@@ -291,7 +291,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
|
||||
KQ[j*nthreads + tid] = KQ_reg[j];
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||
@@ -303,7 +303,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
|
||||
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
}
|
||||
|
||||
#ifndef GGML_USE_HIP
|
||||
@@ -314,7 +314,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
|
||||
const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
half2 KQ_k[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
@@ -353,7 +353,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
@@ -374,7 +374,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
|
||||
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||
@@ -386,7 +386,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
|
||||
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
@@ -421,7 +421,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
|
||||
KQ_max[j_VKQ] = kqmax_new;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
|
||||
+ (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
|
||||
|
||||
@@ -452,7 +452,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
|
||||
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
|
||||
KQ_sum[j_VKQ] *= kqmax_scale;
|
||||
KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
|
||||
|
||||
@@ -53,6 +53,7 @@
|
||||
#include "ggml-cuda/set.cuh"
|
||||
#include "ggml-cuda/set-rows.cuh"
|
||||
#include "ggml-cuda/pad_reflect_1d.cuh"
|
||||
#include "ggml-cuda/solve_tri.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <algorithm>
|
||||
@@ -521,7 +522,8 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
};
|
||||
#endif // defined(GGML_USE_VMM)
|
||||
|
||||
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
|
||||
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device,
|
||||
[[maybe_unused]] int stream_no) {
|
||||
#if defined(GGML_USE_VMM)
|
||||
if (ggml_cuda_info().devices[device].vmm) {
|
||||
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
|
||||
@@ -2527,6 +2529,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
ggml_cuda_op_trunc(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
ggml_cuda_op_expm1(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
ggml_cuda_op_softplus(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -2711,6 +2719,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
ggml_cuda_opt_step_sgd(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
ggml_cuda_op_solve_tri(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -2992,6 +3003,40 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
||||
}
|
||||
#endif
|
||||
|
||||
static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
|
||||
const ggml_tensor * view,
|
||||
const ggml_tensor * set_rows) {
|
||||
|
||||
if (rope->op != GGML_OP_ROPE || view->op != GGML_OP_VIEW || set_rows->op != GGML_OP_SET_ROWS) {
|
||||
return false;
|
||||
}
|
||||
// ne3 not tested
|
||||
if (rope->src[0]->ne[3] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (set_rows->src[1]->type != GGML_TYPE_I64) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// The view should flatten two dims of rope into one dim
|
||||
if (!ggml_is_contiguous(view) || view->ne[0] != rope->ne[0] * rope->ne[1]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Only norm/neox shaders have the fusion code
|
||||
const int mode = ((const int32_t *) rope->op_params)[2];
|
||||
if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
|
||||
#ifndef NDEBUG
|
||||
const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
|
||||
@@ -3006,7 +3051,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
|
||||
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
|
||||
|
||||
if (ops.size() == topk_moe_ops_with_norm.size() &&
|
||||
const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
|
||||
const std::initializer_list<enum ggml_op> & list2) {
|
||||
return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
|
||||
};
|
||||
|
||||
if (is_equal(topk_moe_ops_with_norm, ops) &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
|
||||
@@ -3016,8 +3066,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
}
|
||||
}
|
||||
|
||||
if (ops.size() == topk_moe_ops.size() &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
|
||||
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
|
||||
@@ -3025,7 +3074,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
}
|
||||
}
|
||||
|
||||
if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
|
||||
if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
|
||||
@@ -3041,9 +3090,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
|
||||
std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
|
||||
|
||||
if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) ||
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) {
|
||||
|
||||
if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) {
|
||||
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
|
||||
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2];
|
||||
@@ -3055,9 +3103,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
}
|
||||
}
|
||||
|
||||
if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) ||
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) {
|
||||
|
||||
if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
|
||||
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1];
|
||||
const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
|
||||
@@ -3067,6 +3114,18 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
}
|
||||
}
|
||||
|
||||
std::initializer_list<enum ggml_op> rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS };
|
||||
|
||||
if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
|
||||
const ggml_tensor * rope = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * view = cgraph->nodes[node_idx + 1];
|
||||
const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];
|
||||
|
||||
if (ggml_cuda_should_fuse_rope_set_rows(rope, view, set_rows)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
|
||||
return false;
|
||||
}
|
||||
@@ -3142,27 +3201,94 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
// flag used to determine whether it is an integrated_gpu
|
||||
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
|
||||
|
||||
ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
|
||||
bool is_concurrent_event_active = false;
|
||||
ggml_cuda_concurrent_event * concurrent_event = nullptr;
|
||||
bool should_launch_concurrent_events = false;
|
||||
|
||||
const auto try_launch_concurrent_event = [&](const ggml_tensor * node) {
|
||||
if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) {
|
||||
concurrent_event = &stream_ctx.concurrent_events[node];
|
||||
|
||||
is_concurrent_event_active = true;
|
||||
|
||||
GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
|
||||
|
||||
cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
|
||||
GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
|
||||
CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
|
||||
|
||||
for (int i = 1; i <= concurrent_event->n_streams; ++i) {
|
||||
cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
|
||||
CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
while (!graph_evaluated_or_captured) {
|
||||
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
|
||||
// With the use of CUDA graphs, the execution will be performed by the graph launch.
|
||||
if (!use_cuda_graph || cuda_graph_update_required) {
|
||||
|
||||
[[maybe_unused]] int prev_i = 0;
|
||||
|
||||
if (stream_ctx.concurrent_events.size() > 0) {
|
||||
should_launch_concurrent_events = true;
|
||||
for (const auto & [tensor, event] : stream_ctx.concurrent_events) {
|
||||
should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid();
|
||||
}
|
||||
}
|
||||
if (should_launch_concurrent_events) {
|
||||
//Restore the original graph to enable fusion within the streams
|
||||
cgraph->nodes = const_cast<ggml_tensor **>(stream_ctx.original_nodes.data());
|
||||
cgraph->n_nodes = (int) stream_ctx.original_nodes.size();
|
||||
}
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
if (is_concurrent_event_active) {
|
||||
GGML_ASSERT(concurrent_event);
|
||||
|
||||
if (node == concurrent_event->join_node) {
|
||||
cuda_ctx->curr_stream_no = 0;
|
||||
for (int i = 1; i <= concurrent_event->n_streams; ++i) {
|
||||
// Wait on join events of forked streams in the main stream
|
||||
CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[i - 1],
|
||||
cuda_ctx->stream(cuda_ctx->device, i)));
|
||||
CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[i - 1]));
|
||||
}
|
||||
|
||||
is_concurrent_event_active = false;
|
||||
concurrent_event = nullptr;
|
||||
} else {
|
||||
GGML_ASSERT (concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end());
|
||||
cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
|
||||
GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
|
||||
}
|
||||
} else if (i - prev_i > 1) {
|
||||
//the previous node was fused
|
||||
const ggml_tensor * prev_node = cgraph->nodes[i - 1];
|
||||
try_launch_concurrent_event(prev_node);
|
||||
|
||||
if (is_concurrent_event_active) {
|
||||
cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
|
||||
GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef GGML_CUDA_DEBUG
|
||||
const int nodes_fused = i - prev_i - 1;
|
||||
prev_i = i;
|
||||
if (nodes_fused > 0) {
|
||||
GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused);
|
||||
}
|
||||
#endif
|
||||
prev_i = i;
|
||||
|
||||
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
// start of fusion operations
|
||||
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
|
||||
if (!disable_fusion) {
|
||||
|
||||
@@ -3196,6 +3322,15 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
|
||||
ggml_tensor * rope = cgraph->nodes[i];
|
||||
ggml_tensor * set_rows = cgraph->nodes[i + 2];
|
||||
|
||||
ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_ADD) {
|
||||
int n_fuse = 0;
|
||||
ggml_op ops[8];
|
||||
@@ -3446,13 +3581,17 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(integrated);
|
||||
#endif // NDEBUG
|
||||
#endif // NDEBUG
|
||||
|
||||
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
|
||||
if (!ok) {
|
||||
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
|
||||
}
|
||||
GGML_ASSERT(ok);
|
||||
|
||||
if (!is_concurrent_event_active) {
|
||||
try_launch_concurrent_event(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3592,6 +3731,235 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
|
||||
|
||||
static bool enable_graph_optimization = [] {
|
||||
const char * env = getenv("GGML_CUDA_GRAPH_OPT");
|
||||
return env != nullptr && atoi(env) == 1;
|
||||
}();
|
||||
|
||||
if (!enable_graph_optimization) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "compute graph optimization is only supported on single GPU in the CUDA backend");
|
||||
GGML_LOG_DEBUG("Optimizing CUDA graph %p with %d nodes\n", cgraph->nodes, cgraph->n_nodes);
|
||||
|
||||
ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
|
||||
stream_context.reset();
|
||||
|
||||
// number of out-degrees for a particular node
|
||||
std::unordered_map<const ggml_tensor *, int> fan_out;
|
||||
// reverse mapping of node to index in the cgraph
|
||||
std::unordered_map<const ggml_tensor *, int> node_indices;
|
||||
|
||||
const auto & is_noop = [](const ggml_tensor * node) -> bool {
|
||||
return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE ||
|
||||
node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
|
||||
};
|
||||
|
||||
const auto & depends_on = [](const ggml_tensor * dst, const ggml_tensor * src) -> bool {
|
||||
for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
|
||||
if (dst->src[s] == src) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
// implicit dependency if they view the same tensor
|
||||
const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst;
|
||||
const ggml_tensor * src2 = src->view_src ? src->view_src : src;
|
||||
if (dst2 == src2) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
|
||||
const ggml_tensor * node = cgraph->nodes[node_idx];
|
||||
node_indices[node] = node_idx;
|
||||
|
||||
if (is_noop(node)) {
|
||||
continue;
|
||||
}
|
||||
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
|
||||
const ggml_tensor * src = cgraph->nodes[node_idx]->src[src_idx];
|
||||
//TODO: check why nrows > 1 fails
|
||||
if (node && !is_noop(node) && ggml_nrows(node) <= 1) {
|
||||
fan_out[src] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Target Q, K, V for concurrency
|
||||
// this is a more general way to find nodes which can be candidates for concurrency (although it has not been tested for anything else):
|
||||
// 1. find fan-out (fork) nodes where the same input is used at least N times (in QKV, it would be "attn-norm")
|
||||
// 2. find the join node, where 2 or more of the outputs are required (in QKV, this would "KQ" or "flash-attn")
|
||||
// 3. account for all branches from the fork to the join
|
||||
// 4. To extend lifetimes of the tensors, we interleave the branches (see below for more details)
|
||||
// 5. save the original cgraph and restore it in graph_compute, to enable fusion within streams
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/16991#issuecomment-3522620030
|
||||
|
||||
const int min_fan_out = 3;
|
||||
const int max_fan_out = 3;
|
||||
|
||||
// store {fork_idx, join_idx}
|
||||
std::vector<std::pair<int, int>> concurrent_node_ranges;
|
||||
|
||||
// save the original nodes
|
||||
std::vector<const ggml_tensor *> original_nodes;
|
||||
original_nodes.reserve(cgraph->n_nodes);
|
||||
for (int i = 0; i < cgraph->n_nodes; ++i) {
|
||||
original_nodes.push_back(cgraph->nodes[i]);
|
||||
}
|
||||
cuda_ctx->stream_context().original_nodes = std::move(original_nodes);
|
||||
|
||||
for (const auto & [root_node, count] : fan_out) {
|
||||
if (count >= min_fan_out && count <= max_fan_out) {
|
||||
const int root_node_idx = node_indices[root_node];
|
||||
|
||||
bool is_part_of_event = false;
|
||||
for (const auto & [start, end] : concurrent_node_ranges) {
|
||||
if (root_node_idx >= start && root_node_idx <= end) {
|
||||
is_part_of_event = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_part_of_event) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<std::vector<const ggml_tensor *>> nodes_per_branch;
|
||||
for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
|
||||
const ggml_tensor * node = cgraph->nodes[i];
|
||||
if (!is_noop(node) && depends_on(node, root_node)) {
|
||||
nodes_per_branch.push_back({ node });
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(nodes_per_branch.size() == (size_t) count);
|
||||
|
||||
//find the join point
|
||||
const ggml_tensor * join_node = nullptr;
|
||||
|
||||
const auto & belongs_to_branch = [&](const ggml_tensor * node,
|
||||
const std::vector<const ggml_tensor *> & branch) -> bool {
|
||||
for (const ggml_tensor * n : branch) {
|
||||
if (depends_on(node, n)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
|
||||
const ggml_tensor * curr_node = cgraph->nodes[i];
|
||||
|
||||
int num_joins = 0;
|
||||
for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
|
||||
if (belongs_to_branch(curr_node, nodes_per_branch[branch_idx])) {
|
||||
num_joins++;
|
||||
}
|
||||
}
|
||||
|
||||
if (num_joins >= 2) {
|
||||
join_node = curr_node;
|
||||
break;
|
||||
}
|
||||
|
||||
bool found_branch = false;
|
||||
for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
|
||||
std::vector<const ggml_tensor *> & branch_vec = nodes_per_branch[branch_idx];
|
||||
if (belongs_to_branch(curr_node, branch_vec)) {
|
||||
//continue accumulating
|
||||
if (std::find(branch_vec.begin(), branch_vec.end(), curr_node) == branch_vec.end()) {
|
||||
branch_vec.push_back(curr_node);
|
||||
}
|
||||
found_branch = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!found_branch && is_noop(curr_node)) {
|
||||
// we can put it in any branch because it will be ignored
|
||||
nodes_per_branch[0].push_back({ curr_node });
|
||||
}
|
||||
}
|
||||
|
||||
if (join_node) {
|
||||
//Create ggml_cuda_concurrent_event
|
||||
ggml_cuda_concurrent_event concurrent_event(nodes_per_branch.size());
|
||||
concurrent_event.join_node = join_node;
|
||||
|
||||
for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
|
||||
for (const ggml_tensor * n : nodes_per_branch[branch_idx]) {
|
||||
concurrent_event.stream_mapping[n] = branch_idx + 1;
|
||||
}
|
||||
}
|
||||
|
||||
int fork_node_idx = node_indices[root_node];
|
||||
int join_node_idx = node_indices[join_node];
|
||||
|
||||
int current_branch_idx = 0;
|
||||
int current_node_idx = fork_node_idx + 1;
|
||||
const int n_branches = nodes_per_branch.size();
|
||||
|
||||
int total_branch_nodes = 0;
|
||||
for (std::vector<const ggml_tensor *> branch_nodes : nodes_per_branch) {
|
||||
total_branch_nodes += branch_nodes.size();
|
||||
}
|
||||
|
||||
// there are other nodes in the middle which are unaccounted for
|
||||
// usually (cpy) nodes, then ignore this fork
|
||||
if (join_node_idx - fork_node_idx - 1 != total_branch_nodes) {
|
||||
GGML_LOG_DEBUG(
|
||||
"Skipping %s because the number of nodes in the middle is not equal to the total number of "
|
||||
"branch nodes %d != %d\n",
|
||||
root_node->name, join_node_idx - fork_node_idx - 1, total_branch_nodes);
|
||||
continue;
|
||||
}
|
||||
|
||||
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
|
||||
GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
|
||||
concurrent_events.emplace(root_node, std::move(concurrent_event));
|
||||
GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node);
|
||||
concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx);
|
||||
|
||||
// interleave tensors to extend lifetimes so that ggml graph doesn't recycle them
|
||||
// example transformation:
|
||||
// [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] ->
|
||||
// [attn-norm, QMul, KMul, VMul, QNorm, VNorm, QRope, KRope, attn]
|
||||
while (current_node_idx < join_node_idx) {
|
||||
std::vector<const ggml_tensor *> & branch_nodes = nodes_per_branch[current_branch_idx];
|
||||
|
||||
bool has_node = false;
|
||||
for (std::vector<const ggml_tensor *> branch_node : nodes_per_branch) {
|
||||
has_node |= branch_node.size() > 0;
|
||||
}
|
||||
|
||||
GGML_ASSERT(has_node);
|
||||
|
||||
if (branch_nodes.empty()) {
|
||||
current_branch_idx = (current_branch_idx + 1) % n_branches;
|
||||
continue;
|
||||
}
|
||||
|
||||
cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
|
||||
current_node_idx++;
|
||||
branch_nodes.erase(branch_nodes.begin());
|
||||
|
||||
// append all empty nodes
|
||||
while (!branch_nodes.empty() && is_noop(branch_nodes.front())) {
|
||||
cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
|
||||
current_node_idx++;
|
||||
branch_nodes.erase(branch_nodes.begin());
|
||||
}
|
||||
|
||||
current_branch_idx = (current_branch_idx + 1) % n_branches;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static const ggml_backend_i ggml_backend_cuda_interface = {
|
||||
/* .get_name = */ ggml_backend_cuda_get_name,
|
||||
/* .free = */ ggml_backend_cuda_free,
|
||||
@@ -3606,7 +3974,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
|
||||
/* .graph_compute = */ ggml_backend_cuda_graph_compute,
|
||||
/* .event_record = */ ggml_backend_cuda_event_record,
|
||||
/* .event_wait = */ ggml_backend_cuda_event_wait,
|
||||
/* .graph_optimize = */ NULL,
|
||||
/* .graph_optimize = */ ggml_backend_cuda_graph_optimize,
|
||||
};
|
||||
|
||||
static ggml_guid_t ggml_backend_cuda_guid() {
|
||||
@@ -3689,10 +4057,110 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
|
||||
return ctx->description.c_str();
|
||||
}
|
||||
|
||||
#if defined(__linux__)
|
||||
// Helper function to get available memory from /proc/meminfo for UMA systems
|
||||
static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_kb, long * free_swap_kb) {
|
||||
FILE * meminfo_file = nullptr;
|
||||
// 2KB buffer for reading /proc/meminfo since it does not report size info, should be enough
|
||||
const size_t BUFFER_SIZE = 2048;
|
||||
auto file_buffer = std::make_unique<char[]>(BUFFER_SIZE);
|
||||
size_t bytes_read = 0;
|
||||
long huge_tlb_total_pages = -1;
|
||||
long huge_tlb_free_pages = -1;
|
||||
long huge_tlb_page_size = -1;
|
||||
|
||||
if (available_memory_kb == nullptr || free_swap_kb == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
meminfo_file = fopen("/proc/meminfo", "r");
|
||||
if (meminfo_file == nullptr) {
|
||||
GGML_LOG_ERROR("%s: failed to open /proc/meminfo\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Read file into buffer
|
||||
bytes_read = fread(file_buffer.get(), 1, BUFFER_SIZE - 1, meminfo_file);
|
||||
fclose(meminfo_file);
|
||||
|
||||
if (bytes_read == 0) {
|
||||
GGML_LOG_ERROR("%s: failed to read from /proc/meminfo\n", __func__);
|
||||
return false;
|
||||
}
|
||||
file_buffer[bytes_read] = '\0';
|
||||
|
||||
*available_memory_kb = -1;
|
||||
*free_swap_kb = -1;
|
||||
|
||||
// Parse the file buffer line by line
|
||||
char * line = file_buffer.get();
|
||||
char * line_next;
|
||||
while (line < file_buffer.get() + bytes_read) {
|
||||
// Find the end of the current line
|
||||
line_next = strchr(line, '\n');
|
||||
if (line_next != nullptr) {
|
||||
*line_next = '\0';
|
||||
line_next++;
|
||||
} else {
|
||||
line_next = file_buffer.get() + bytes_read;
|
||||
}
|
||||
|
||||
long value;
|
||||
if (sscanf(line, "MemAvailable: %ld kB", &value) == 1) {
|
||||
*available_memory_kb = value;
|
||||
} else if (sscanf(line, "SwapFree: %ld kB", &value) == 1) {
|
||||
*free_swap_kb = value;
|
||||
} else if (sscanf(line, "HugePages_Total: %ld", &value) == 1) {
|
||||
huge_tlb_total_pages = value;
|
||||
} else if (sscanf(line, "HugePages_Free: %ld", &value) == 1) {
|
||||
huge_tlb_free_pages = value;
|
||||
} else if (sscanf(line, "Hugepagesize: %ld kB", &value) == 1) {
|
||||
huge_tlb_page_size = value;
|
||||
}
|
||||
|
||||
line = line_next;
|
||||
}
|
||||
|
||||
if (huge_tlb_total_pages != 0 && huge_tlb_total_pages != -1) {
|
||||
*available_memory_kb = huge_tlb_free_pages * huge_tlb_page_size;
|
||||
|
||||
// Hugetlbfs pages are not swappable.
|
||||
*free_swap_kb = 0;
|
||||
}
|
||||
|
||||
GGML_LOG_DEBUG("%s: final available_memory_kb: %ld\n", __func__, *available_memory_kb);
|
||||
return true;
|
||||
}
|
||||
#endif // defined(__linux__)
|
||||
|
||||
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||
ggml_cuda_set_device(ctx->device);
|
||||
CUDA_CHECK(cudaMemGetInfo(free, total));
|
||||
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/17368
|
||||
#if defined(__linux__)
|
||||
// Check if this is a UMA (Unified Memory Architecture) system
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device));
|
||||
|
||||
// Check if UMA is explicitly enabled via environment variable
|
||||
bool uma_env = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr;
|
||||
bool is_uma = prop.integrated > 0 || uma_env;
|
||||
|
||||
if (is_uma) {
|
||||
// For UMA systems (like DGX Spark), use system memory info
|
||||
long available_memory_kb = 0;
|
||||
long free_swap_kb = 0;
|
||||
|
||||
if (ggml_backend_cuda_get_available_uma_memory(&available_memory_kb, &free_swap_kb) && available_memory_kb > 0) {
|
||||
*free = (size_t)available_memory_kb * 1024;
|
||||
} else {
|
||||
GGML_LOG_ERROR("%s: /proc/meminfo reading failed, using cudaMemGetInfo\n", __func__);
|
||||
}
|
||||
}
|
||||
#endif // defined(__linux__)
|
||||
|
||||
}
|
||||
|
||||
static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
|
||||
@@ -3780,6 +4248,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
@@ -3954,6 +4424,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_I32) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
|
||||
return true;
|
||||
}
|
||||
@@ -4091,6 +4564,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
return true;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ namespace ggml_cuda_mma {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
#if defined(AMD_MFMA_AVAILABLE)
|
||||
static constexpr int ne = I * J / 64;
|
||||
T x[ne] = {0};
|
||||
|
||||
@@ -149,6 +149,34 @@ namespace ggml_cuda_mma {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(RDNA4)
|
||||
static constexpr int ne = I * J / 32;
|
||||
T x[ne] = {0};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 16 && J == 16) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 16 && J == 16) {
|
||||
return 8 * (threadIdx.x / 16) + l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 16 && J == 16) {
|
||||
return threadIdx.x % 16;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
static constexpr int ne = I * J / 32;
|
||||
T x[ne] = {0};
|
||||
@@ -236,6 +264,32 @@ namespace ggml_cuda_mma {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
static constexpr int ne = I * J / 32;
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 16 && J == 8) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return threadIdx.x % 16;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return 4 * (threadIdx.x / 16) + l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
@@ -285,6 +339,34 @@ namespace ggml_cuda_mma {
|
||||
struct tile<I_, J_, nv_bfloat162> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
static constexpr int ne = I * J / 32;
|
||||
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 16 && J == 8) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return threadIdx.x % 16;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return 4 * (threadIdx.x / 16) + l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
@@ -320,6 +402,7 @@ namespace ggml_cuda_mma {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
};
|
||||
|
||||
template <int I, int J>
|
||||
@@ -353,6 +436,30 @@ namespace ggml_cuda_mma {
|
||||
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
|
||||
xi[0] = xs[0];
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
|
||||
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
|
||||
|
||||
} else if constexpr (std::is_same_v<T, int>) {
|
||||
if constexpr (I == 16 && J == 4) {
|
||||
int64_t * xi = (int64_t *) t.x;
|
||||
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
|
||||
xi[0] = xs[0];
|
||||
|
||||
}else if constexpr (I == 16 && J == 8) {
|
||||
int64_t * xi = (int64_t *) t.x;
|
||||
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
|
||||
xi[0] = xs[0];
|
||||
|
||||
const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
|
||||
xi[1] = xs1[0];
|
||||
|
||||
}else{
|
||||
NO_DEVICE_CODE;
|
||||
}
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int l = 0; l < t.ne; ++l) {
|
||||
@@ -639,12 +746,34 @@ namespace ggml_cuda_mma {
|
||||
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
|
||||
using floatx8_t = __attribute__((ext_vector_type(8))) float;
|
||||
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
|
||||
const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
|
||||
const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
|
||||
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // TURING_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
|
||||
using floatx8_t = __attribute__((ext_vector_type(8))) float;
|
||||
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
|
||||
const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
|
||||
const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
|
||||
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // AMPERE_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
|
||||
#if defined(AMD_MFMA_AVAILABLE)
|
||||
@@ -665,6 +794,36 @@ namespace ggml_cuda_mma {
|
||||
acc[0],
|
||||
0, 0, 0);
|
||||
#endif // defined(CDNA3)
|
||||
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
||||
int32x2_t * a_vec = (int32x2_t *) A.x;
|
||||
int32x2_t * b_vec = (int32x2_t *) B.x;
|
||||
|
||||
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
|
||||
int32x8_t * acc = (int32x8_t *) D.x;
|
||||
|
||||
#if defined(RDNA4)
|
||||
|
||||
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
|
||||
true,
|
||||
a_vec[0],
|
||||
true,
|
||||
b_vec[0],
|
||||
acc[0],
|
||||
true
|
||||
);
|
||||
|
||||
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
|
||||
true,
|
||||
a_vec[1],
|
||||
true,
|
||||
b_vec[1],
|
||||
acc[0],
|
||||
true
|
||||
);
|
||||
#endif // defined(RDNA4)
|
||||
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
@@ -691,6 +850,7 @@ namespace ggml_cuda_mma {
|
||||
acc[0],
|
||||
0, 0, 0);
|
||||
#endif // defined(CDNA3)
|
||||
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
@@ -729,10 +889,37 @@ namespace ggml_cuda_mma {
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
|
||||
#else
|
||||
tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D;
|
||||
tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A;
|
||||
tile <16, 8, float> * D16 = reinterpret_cast<tile <16, 8, float> *>(&D);
|
||||
const tile<16, 8, half2> * A16 = reinterpret_cast<const tile<16, 8, half2> *>(&A);
|
||||
mma(D16[0], A16[0], B);
|
||||
mma(D16[1], A16[1], B);
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
||||
int32x2_t * a_vec = (int32x2_t *) A.x;
|
||||
int32x2_t * b_vec = (int32x2_t *) B.x;
|
||||
|
||||
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
|
||||
int32x8_t * acc = (int32x8_t *) D.x;
|
||||
|
||||
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
|
||||
true,
|
||||
a_vec[0],
|
||||
true,
|
||||
b_vec[0],
|
||||
acc[0],
|
||||
false
|
||||
);
|
||||
#else
|
||||
GGML_UNUSED(D);
|
||||
GGML_UNUSED(A);
|
||||
GGML_UNUSED(B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -160,9 +160,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||
case GGML_TYPE_F32:
|
||||
return ampere_mma_available(cc);
|
||||
case GGML_TYPE_F16:
|
||||
return volta_mma_available(cc) || turing_mma_available(cc);
|
||||
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
|
||||
case GGML_TYPE_BF16:
|
||||
return ampere_mma_available(cc);
|
||||
return ampere_mma_available(cc) || amd_wmma_available(cc);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include "mma.cuh"
|
||||
#include "common.cuh"
|
||||
#include "convert.cuh"
|
||||
|
||||
using namespace ggml_cuda_mma;
|
||||
|
||||
@@ -27,20 +28,35 @@ static __global__ void mul_mat_f(
|
||||
const int stride_col_id, const int stride_row_id,
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
|
||||
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
|
||||
constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
|
||||
constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<tile_B_I, 8, T> tile_B;
|
||||
typedef tile<16, tile_C_J, float> tile_C;
|
||||
|
||||
constexpr bool a_supported = tile_A::supported();
|
||||
constexpr bool b_supported = tile_B::supported();
|
||||
constexpr bool c_supported = tile_C::supported();
|
||||
constexpr bool supported = a_supported && b_supported && c_supported;
|
||||
#else
|
||||
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
||||
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
||||
|
||||
if (!I_16_supported && !I_32_supported) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
constexpr bool supported = I_16_supported || I_32_supported;
|
||||
|
||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
|
||||
|
||||
typedef tile<I_preferred, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<I_preferred, 8, float> tile_C;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
if constexpr (!supported) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
@@ -161,11 +177,11 @@ static __global__ void mul_mat_f(
|
||||
|
||||
if constexpr (!has_ids) {
|
||||
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
|
||||
} else {
|
||||
const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
|
||||
float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -239,7 +255,7 @@ static __global__ void mul_mat_f(
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
|
||||
//This kernel is for larger batch sizes of mul_mat_id
|
||||
@@ -253,20 +269,35 @@ static __global__ void mul_mat_f_ids(
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const uint3 sis1_fd, const uint3 nch_fd) {
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
|
||||
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
|
||||
constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
|
||||
constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<tile_B_I, 8, T> tile_B;
|
||||
typedef tile<16, tile_C_J, float> tile_C;
|
||||
|
||||
constexpr bool a_supported = tile_A::supported();
|
||||
constexpr bool b_supported = tile_B::supported();
|
||||
constexpr bool c_supported = tile_C::supported();
|
||||
constexpr bool supported = a_supported && b_supported && c_supported;
|
||||
#else
|
||||
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
||||
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
||||
constexpr bool supported = I_16_supported || I_32_supported;
|
||||
|
||||
if (!I_16_supported && !I_32_supported) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster.
|
||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
|
||||
|
||||
typedef tile<I_preferred, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<I_preferred, 8, float> tile_C;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
if constexpr (!supported) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
@@ -408,7 +439,7 @@ static __global__ void mul_mat_f_ids(
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
const float2 tmp = vals_buf[curr_buf][j0];
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
|
||||
}
|
||||
|
||||
if (itB + 1 < ntB) {
|
||||
@@ -492,7 +523,7 @@ static __global__ void mul_mat_f_ids(
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
|
||||
template<typename T, int cols_per_block, int nwarps>
|
||||
@@ -554,7 +585,8 @@ void mul_mat_f_cuda(
|
||||
cudaStream_t stream, const mmf_ids_data * ids_data) {
|
||||
typedef tile<16, 8, T> tile_A_16;
|
||||
typedef tile<32, 8, T> tile_A_32;
|
||||
typedef tile< 8, 8, T> tile_B;
|
||||
typedef tile<16, 8, T> tile_B_16;
|
||||
typedef tile< 8, 8, T> tile_B_8;
|
||||
|
||||
GGML_ASSERT(ncols_x % 2 == 0);
|
||||
GGML_ASSERT(stride_row % 2 == 0);
|
||||
@@ -581,7 +613,8 @@ void mul_mat_f_cuda(
|
||||
|
||||
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
|
||||
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
|
||||
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
|
||||
const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
|
||||
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
|
||||
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
|
||||
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
|
||||
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user