mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-02-05 13:53:23 +02:00
Compare commits
115 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d774ab3acc | ||
|
|
fa62da9b2d | ||
|
|
1ec208083c | ||
|
|
9f4cc8f8d3 | ||
|
|
fd08255d0d | ||
|
|
3ec9fd4b77 | ||
|
|
3962fc1a79 | ||
|
|
1bef571f6a | ||
|
|
db288b60cb | ||
|
|
106045e7bb | ||
|
|
f117d84b48 | ||
|
|
534c46b53c | ||
|
|
387a1598ca | ||
|
|
7c9e0ca520 | ||
|
|
8f8290ada9 | ||
|
|
b34aedd558 | ||
|
|
cde3833239 | ||
|
|
b3451785ac | ||
|
|
1d1e6a90bc | ||
|
|
5598f475be | ||
|
|
8ec05832fa | ||
|
|
21c84b5d2d | ||
|
|
d92cb67e37 | ||
|
|
6eecde3cc8 | ||
|
|
396856b400 | ||
|
|
4d0598e144 | ||
|
|
90f9b88afb | ||
|
|
864a0b67a6 | ||
|
|
84ec8a58f7 | ||
|
|
bfcce4d693 | ||
|
|
69804487e0 | ||
|
|
ff227703d6 | ||
|
|
0cec062a63 | ||
|
|
53debe6f3c | ||
|
|
cfd74c86db | ||
|
|
ecef206ccb | ||
|
|
5bbc7362cb | ||
|
|
aa6fb13213 | ||
|
|
a83f528688 | ||
|
|
b1bcd309fc | ||
|
|
5783575c9d | ||
|
|
4a2b196d03 | ||
|
|
1bd3047a93 | ||
|
|
a2df2787b3 | ||
|
|
553f1e46e9 | ||
|
|
8b576b6c55 | ||
|
|
27d135c970 | ||
|
|
6af1ca48cb | ||
|
|
c300e68ef4 | ||
|
|
3d804dec76 | ||
|
|
ffd0821c57 | ||
|
|
4314e56c4f | ||
|
|
496e5bf46b | ||
|
|
7919256c57 | ||
|
|
e0449763a4 | ||
|
|
eb7cf15a80 | ||
|
|
66ee4f297c | ||
|
|
e51c47b401 | ||
|
|
2711d0215f | ||
|
|
f0d4b29edf | ||
|
|
815857791d | ||
|
|
1a0e87d291 | ||
|
|
d2e518e9b4 | ||
|
|
b636228c0a | ||
|
|
325afb370a | ||
|
|
794fe23f29 | ||
|
|
cf8cc856d7 | ||
|
|
d0c08040b6 | ||
|
|
be5ef7963f | ||
|
|
cae9fb4361 | ||
|
|
7fee2889e6 | ||
|
|
d7d1eccacc | ||
|
|
4bf3119d61 | ||
|
|
f643120bad | ||
|
|
6e84b0ab8e | ||
|
|
2b8525d5c8 | ||
|
|
a4417ddda9 | ||
|
|
d6d24cd9ed | ||
|
|
a5203b4465 | ||
|
|
df984e0147 | ||
|
|
acd38efee3 | ||
|
|
caf773f249 | ||
|
|
178a7eb952 | ||
|
|
6f53d8a6b4 | ||
|
|
19f65187cb | ||
|
|
1d8ee06000 | ||
|
|
2cc9b8c32c | ||
|
|
f35726c2fb | ||
|
|
4a75d19376 | ||
|
|
26771a1491 | ||
|
|
ca6baf76c1 | ||
|
|
6e264a905b | ||
|
|
49b0e3cec4 | ||
|
|
20a758155b | ||
|
|
00c24acb2a | ||
|
|
466ea66f33 | ||
|
|
5f0db9522f | ||
|
|
c5d9effb49 | ||
|
|
9fbadaef4f | ||
|
|
9755129c27 | ||
|
|
a07c2c8a52 | ||
|
|
8137b4bb2b | ||
|
|
1af6945eb0 | ||
|
|
01f37edf1a | ||
|
|
c07e87f38b | ||
|
|
564804b79b | ||
|
|
05f63cc9ee | ||
|
|
f7fb43cd0b | ||
|
|
5845661640 | ||
|
|
f211d1dc10 | ||
|
|
955a6c2d91 | ||
|
|
1971adf55e | ||
|
|
5245729e33 | ||
|
|
6152129d05 | ||
|
|
16d3df7ab0 |
@@ -2,6 +2,10 @@ ARG UBUNTU_VERSION=22.04
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
ARG TARGETARCH
|
||||
|
||||
ARG GGML_CPU_ARM_ARCH=armv8-a
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential git cmake libcurl4-openssl-dev
|
||||
|
||||
@@ -9,7 +13,14 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN cmake -S . -B build -DGGML_BACKEND_DL=ON -DGGML_NATIVE=OFF -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_CURL=ON -DCMAKE_BUILD_TYPE=Release && \
|
||||
RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DLLAMA_CURL=ON -DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON; \
|
||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DLLAMA_CURL=ON -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=${GGML_CPU_ARM_ARCH}; \
|
||||
else \
|
||||
echo "Unsupported architecture"; \
|
||||
exit 1; \
|
||||
fi && \
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
RUN mkdir -p /app/lib && \
|
||||
|
||||
@@ -13,9 +13,13 @@ elif [[ "$arg1" == '--quantize' || "$arg1" == '-q' ]]; then
|
||||
exec ./llama-quantize "$@"
|
||||
elif [[ "$arg1" == '--run' || "$arg1" == '-r' ]]; then
|
||||
exec ./llama-cli "$@"
|
||||
elif [[ "$arg1" == '--bench' || "$arg1" == '-b' ]]; then
|
||||
exec ./llama-bench "$@"
|
||||
elif [[ "$arg1" == '--perplexity' || "$arg1" == '-p' ]]; then
|
||||
exec ./llama-perplexity "$@"
|
||||
elif [[ "$arg1" == '--all-in-one' || "$arg1" == '-a' ]]; then
|
||||
echo "Converting PTH to GGML..."
|
||||
for i in `ls $1/$2/ggml-model-f16.bin*`; do
|
||||
for i in $(ls $1/$2/ggml-model-f16.bin*); do
|
||||
if [ -f "${i/f16/q4_0}" ]; then
|
||||
echo "Skip model quantization, it already exists: ${i/f16/q4_0}"
|
||||
else
|
||||
@@ -30,6 +34,10 @@ else
|
||||
echo "Available commands: "
|
||||
echo " --run (-r): Run a model previously converted into ggml"
|
||||
echo " ex: -m /models/7B/ggml-model-q4_0.bin -p \"Building a website can be done in 10 simple steps:\" -n 512"
|
||||
echo " --bench (-b): Benchmark the performance of the inference for various parameters."
|
||||
echo " ex: -m model.gguf"
|
||||
echo " --perplexity (-p): Measure the perplexity of a model over a given text."
|
||||
echo " ex: -m model.gguf -f file.txt"
|
||||
echo " --convert (-c): Convert a llama model into ggml"
|
||||
echo " ex: --outtype f16 \"/models/7B/\" "
|
||||
echo " --quantize (-q): Optimize with quantization process ggml"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
ARG UBUNTU_VERSION=jammy
|
||||
ARG UBUNTU_VERSION=24.04
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
@@ -7,7 +7,7 @@ RUN apt update && apt install -y git build-essential cmake wget
|
||||
|
||||
# Install Vulkan SDK and cURL
|
||||
RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \
|
||||
wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list && \
|
||||
wget -qO /etc/apt/sources.list.d/lunarg-vulkan-noble.list https://packages.lunarg.com/vulkan/lunarg-vulkan-noble.list && \
|
||||
apt update -y && \
|
||||
apt-get install -y vulkan-sdk libcurl4-openssl-dev curl
|
||||
|
||||
@@ -34,7 +34,7 @@ RUN mkdir -p /app/full \
|
||||
FROM ubuntu:$UBUNTU_VERSION AS base
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl\
|
||||
&& apt-get install -y libgomp1 curl libvulkan-dev \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
@@ -55,8 +55,9 @@ RUN apt-get update \
|
||||
git \
|
||||
python3 \
|
||||
python3-pip \
|
||||
&& pip install --upgrade pip setuptools wheel \
|
||||
&& pip install -r requirements.txt \
|
||||
python3-wheel \
|
||||
&& pip install --break-system-packages --upgrade setuptools \
|
||||
&& pip install --break-system-packages -r requirements.txt \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
|
||||
@@ -40,3 +40,11 @@ indent_style = tab
|
||||
[examples/cvector-generator/*.txt]
|
||||
trim_trailing_whitespace = unset
|
||||
insert_final_newline = unset
|
||||
|
||||
[models/templates/*.jinja]
|
||||
indent_style = unset
|
||||
indent_size = unset
|
||||
end_of_line = unset
|
||||
charset = unset
|
||||
trim_trailing_whitespace = unset
|
||||
insert_final_newline = unset
|
||||
|
||||
392
.github/workflows/build.yml
vendored
392
.github/workflows/build.yml
vendored
@@ -43,6 +43,12 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: macOS-latest-cmake-arm64
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
@@ -53,15 +59,14 @@ jobs:
|
||||
id: cmake_build
|
||||
run: |
|
||||
sysctl -a
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. \
|
||||
cmake -B build \
|
||||
-DCMAKE_BUILD_RPATH="@loader_path" \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_CURL=ON \
|
||||
-DGGML_METAL_USE_BF16=ON \
|
||||
-DGGML_METAL_EMBED_LIBRARY=ON \
|
||||
-DGGML_RPC=ON
|
||||
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
|
||||
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
|
||||
|
||||
- name: Test
|
||||
id: cmake_test
|
||||
@@ -107,6 +112,12 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: macOS-latest-cmake-x64
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
@@ -120,6 +131,7 @@ jobs:
|
||||
# Metal is disabled due to intermittent failures with Github runners not having a GPU:
|
||||
# https://github.com/ggerganov/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313
|
||||
cmake -B build \
|
||||
-DCMAKE_BUILD_RPATH="@loader_path" \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_CURL=ON \
|
||||
-DGGML_METAL=OFF \
|
||||
@@ -160,8 +172,8 @@ jobs:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip
|
||||
name: llama-bin-macos-x64.zip
|
||||
|
||||
ubuntu-latest-cmake:
|
||||
runs-on: ubuntu-latest
|
||||
ubuntu-cpu-cmake:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -170,6 +182,12 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-cpu-cmake
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
@@ -179,10 +197,11 @@ jobs:
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_RPC=ON
|
||||
cmake --build . --config Release -j $(nproc)
|
||||
cmake -B build \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_CURL=ON \
|
||||
-DGGML_RPC=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
- name: Test
|
||||
id: cmake_test
|
||||
@@ -244,6 +263,12 @@ jobs:
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-latest-cmake-sanitizer-${{ matrix.sanitizer }}
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
@@ -254,19 +279,52 @@ jobs:
|
||||
id: cmake_build
|
||||
if: ${{ matrix.sanitizer != 'THREAD' }}
|
||||
run: |
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
|
||||
cmake --build . --config ${{ matrix.build_type }} -j $(nproc)
|
||||
cmake -B build \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
|
||||
cmake --build build --config ${{ matrix.build_type }} -j $(nproc)
|
||||
|
||||
- name: Build (no OpenMP)
|
||||
id: cmake_build_no_openmp
|
||||
if: ${{ matrix.sanitizer == 'THREAD' }}
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
||||
-DGGML_OPENMP=OFF
|
||||
cmake --build build --config ${{ matrix.build_type }} -j $(nproc)
|
||||
|
||||
- name: Test
|
||||
id: cmake_test
|
||||
run: |
|
||||
cd build
|
||||
ctest -L main --verbose --timeout 900
|
||||
|
||||
ubuntu-latest-llguidance:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} -DGGML_OPENMP=OFF
|
||||
cmake --build . --config ${{ matrix.build_type }} -j $(nproc)
|
||||
cmake .. \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_LLGUIDANCE=ON
|
||||
cmake --build . --config Release -j $(nproc)
|
||||
|
||||
- name: Test
|
||||
id: cmake_test
|
||||
@@ -284,6 +342,12 @@ jobs:
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-latest-cmake-rpc
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
@@ -293,10 +357,9 @@ jobs:
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -DGGML_RPC=ON ..
|
||||
cmake --build . --config Release -j $(nproc)
|
||||
cmake -B build \
|
||||
-DGGML_RPC=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
- name: Test
|
||||
id: cmake_test
|
||||
@@ -312,6 +375,12 @@ jobs:
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-22-cmake-vulkan
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
@@ -323,16 +392,16 @@ jobs:
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -DGGML_VULKAN=ON ..
|
||||
cmake --build . --config Release -j $(nproc)
|
||||
cmake -B build \
|
||||
-DGGML_VULKAN=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
- name: Test
|
||||
id: cmake_test
|
||||
run: |
|
||||
cd build
|
||||
ctest -L main --verbose --timeout 900
|
||||
# This is using llvmpipe and runs slower than other backends
|
||||
ctest -L main --verbose --timeout 1800
|
||||
|
||||
ubuntu-22-cmake-hip:
|
||||
runs-on: ubuntu-22.04
|
||||
@@ -349,16 +418,27 @@ jobs:
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-22-cmake-hip
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Build with native CMake HIP support
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -S . -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" -DGGML_HIP=ON
|
||||
cmake -B build -S . \
|
||||
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
|
||||
-DGGML_HIP=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
- name: Build with legacy HIP support
|
||||
id: cmake_build_legacy_hip
|
||||
run: |
|
||||
cmake -B build2 -S . -DCMAKE_C_COMPILER=hipcc -DCMAKE_CXX_COMPILER=hipcc -DGGML_HIP=ON
|
||||
cmake -B build2 -S . \
|
||||
-DCMAKE_C_COMPILER=hipcc \
|
||||
-DCMAKE_CXX_COMPILER=hipcc \
|
||||
-DGGML_HIP=ON
|
||||
cmake --build build2 --config Release -j $(nproc)
|
||||
|
||||
ubuntu-22-cmake-musa:
|
||||
@@ -376,10 +456,17 @@ jobs:
|
||||
apt-get update
|
||||
apt-get install -y build-essential git cmake libcurl4-openssl-dev
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-22-cmake-musa
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Build with native CMake MUSA support
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -S . -DGGML_MUSA=ON
|
||||
cmake -B build -S . \
|
||||
-DGGML_MUSA=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
ubuntu-22-cmake-sycl:
|
||||
@@ -414,14 +501,21 @@ jobs:
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-22-cmake-sycl
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx ..
|
||||
cmake --build . --config Release -j $(nproc)
|
||||
cmake -B build \
|
||||
-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
ubuntu-22-cmake-sycl-fp16:
|
||||
runs-on: ubuntu-22.04
|
||||
@@ -455,47 +549,22 @@ jobs:
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-22-cmake-sycl-fp16
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON ..
|
||||
cmake --build . --config Release -j $(nproc)
|
||||
|
||||
# TODO: build with GGML_METAL=OFF because test-backend-ops fail on "Apple Paravirtual device" and I don't know
|
||||
# how to debug it.
|
||||
# ref: https://github.com/ggerganov/llama.cpp/actions/runs/7132125951/job/19422043567?pr=4359#step:5:6584
|
||||
# would be great if we fix these
|
||||
macOS-latest-cmake:
|
||||
runs-on: macos-latest
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
sysctl -a
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL=OFF ..
|
||||
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
|
||||
|
||||
- name: Test
|
||||
id: cmake_test
|
||||
run: |
|
||||
cd build
|
||||
ctest -L main --verbose --timeout 900
|
||||
cmake -B build \
|
||||
-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx \
|
||||
-DGGML_SYCL_F16=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
macOS-latest-cmake-ios:
|
||||
runs-on: macos-latest
|
||||
@@ -505,6 +574,12 @@ jobs:
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: macOS-latest-cmake-ios
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
@@ -515,9 +590,7 @@ jobs:
|
||||
id: cmake_build
|
||||
run: |
|
||||
sysctl -a
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -G Xcode .. \
|
||||
cmake -B build -G Xcode \
|
||||
-DGGML_METAL_USE_BF16=ON \
|
||||
-DGGML_METAL_EMBED_LIBRARY=ON \
|
||||
-DLLAMA_BUILD_EXAMPLES=OFF \
|
||||
@@ -526,7 +599,7 @@ jobs:
|
||||
-DCMAKE_SYSTEM_NAME=iOS \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
|
||||
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
|
||||
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
|
||||
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
|
||||
|
||||
macOS-latest-cmake-tvos:
|
||||
runs-on: macos-latest
|
||||
@@ -536,6 +609,12 @@ jobs:
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: macOS-latest-cmake-tvos
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
@@ -546,9 +625,7 @@ jobs:
|
||||
id: cmake_build
|
||||
run: |
|
||||
sysctl -a
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -G Xcode .. \
|
||||
cmake -B build -G Xcode \
|
||||
-DGGML_METAL_USE_BF16=ON \
|
||||
-DGGML_METAL_EMBED_LIBRARY=ON \
|
||||
-DLLAMA_BUILD_EXAMPLES=OFF \
|
||||
@@ -557,7 +634,7 @@ jobs:
|
||||
-DCMAKE_SYSTEM_NAME=tvOS \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
|
||||
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
|
||||
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
|
||||
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
|
||||
|
||||
macOS-latest-swift:
|
||||
runs-on: macos-latest
|
||||
@@ -571,6 +648,12 @@ jobs:
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: macOS-latest-swift
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
@@ -581,17 +664,15 @@ jobs:
|
||||
id: cmake_build
|
||||
run: |
|
||||
sysctl -a
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -G Xcode .. \
|
||||
cmake -B build -G Xcode \
|
||||
-DGGML_METAL_USE_BF16=ON \
|
||||
-DGGML_METAL_EMBED_LIBRARY=ON \
|
||||
-DLLAMA_BUILD_EXAMPLES=OFF \
|
||||
-DLLAMA_BUILD_TESTS=OFF \
|
||||
-DLLAMA_BUILD_SERVER=OFF \
|
||||
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64"
|
||||
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
|
||||
sudo cmake --install . --config Release
|
||||
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
|
||||
sudo cmake --install build --config Release
|
||||
|
||||
- name: xcodebuild for swift package
|
||||
id: xcodebuild
|
||||
@@ -612,6 +693,13 @@ jobs:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: windows-msys2
|
||||
variant: sccache
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Setup ${{ matrix.sys }}
|
||||
uses: msys2/setup-msys2@v2
|
||||
with:
|
||||
@@ -619,6 +707,7 @@ jobs:
|
||||
msystem: ${{matrix.sys}}
|
||||
install: >-
|
||||
base-devel
|
||||
git
|
||||
mingw-w64-${{matrix.env}}-toolchain
|
||||
mingw-w64-${{matrix.env}}-cmake
|
||||
mingw-w64-${{matrix.env}}-openblas
|
||||
@@ -679,6 +768,13 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: windows-latest-cmake-${{ matrix.build }}
|
||||
variant: sccache
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Clone Kompute submodule
|
||||
id: clone_kompute
|
||||
if: ${{ matrix.build == 'kompute-x64' }}
|
||||
@@ -718,21 +814,19 @@ jobs:
|
||||
run: |
|
||||
git clone https://github.com/KhronosGroup/OpenCL-Headers
|
||||
cd OpenCL-Headers
|
||||
mkdir build && cd build
|
||||
cmake .. `
|
||||
cmake -B build `
|
||||
-DBUILD_TESTING=OFF `
|
||||
-DOPENCL_HEADERS_BUILD_TESTING=OFF `
|
||||
-DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF `
|
||||
-DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release"
|
||||
cmake --build . --target install
|
||||
cmake --build build --target install
|
||||
git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader
|
||||
cd OpenCL-ICD-Loader
|
||||
mkdir build-arm64-release && cd build-arm64-release
|
||||
cmake .. `
|
||||
cmake -B build-arm64-release `
|
||||
-A arm64 `
|
||||
-DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" `
|
||||
-DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release"
|
||||
cmake --build . --target install --config release
|
||||
cmake --build build-arm64-release --target install --config release
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
@@ -817,6 +911,8 @@ jobs:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Install dependencies
|
||||
env:
|
||||
@@ -825,9 +921,21 @@ jobs:
|
||||
apt update
|
||||
apt install -y cmake build-essential ninja-build libgomp1 git
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-latest-cmake-cuda
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Build with CMake
|
||||
run: |
|
||||
cmake -S . -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=89-real -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined -DLLAMA_FATAL_WARNINGS=ON
|
||||
cmake -S . -B build -G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DCMAKE_CUDA_ARCHITECTURES=89-real \
|
||||
-DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DGGML_CUDA=ON
|
||||
cmake --build build
|
||||
|
||||
windows-2019-cmake-cuda:
|
||||
@@ -845,6 +953,13 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Install ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ${{ github.job }}-${{ matrix.cuda }}-${{ matrix.build }}
|
||||
variant: sccache
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Install Cuda Toolkit 11.7
|
||||
if: ${{ matrix.cuda == '11.7' }}
|
||||
run: |
|
||||
@@ -901,11 +1016,6 @@ jobs:
|
||||
echo "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
|
||||
echo "CUDA_PATH_V12_4=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
|
||||
|
||||
- name: Install ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2
|
||||
with:
|
||||
key: ${{ github.job }}-${{ matrix.cuda }}-${{ matrix.build }}
|
||||
|
||||
- name: Install Ninja
|
||||
id: install_ninja
|
||||
run: |
|
||||
@@ -916,7 +1026,11 @@ jobs:
|
||||
shell: cmd
|
||||
run: |
|
||||
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
|
||||
cmake -S . -B build -G "Ninja Multi-Config" -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_CUDA=ON -DGGML_RPC=ON
|
||||
cmake -S . -B build -G "Ninja Multi-Config" ^
|
||||
-DLLAMA_BUILD_SERVER=ON ^
|
||||
-DGGML_NATIVE=OFF ^
|
||||
-DGGML_CUDA=ON ^
|
||||
-DGGML_RPC=ON
|
||||
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
|
||||
cmake --build build --config Release -j %NINJA_JOBS% -t ggml
|
||||
cmake --build build --config Release
|
||||
@@ -981,6 +1095,13 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: windows-latest-cmake-sycl
|
||||
variant: sccache
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Install
|
||||
run: |
|
||||
scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL
|
||||
@@ -1060,16 +1181,22 @@ jobs:
|
||||
& 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
|
||||
|
||||
- name: Install ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ${{ github.job }}
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
|
||||
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
|
||||
cmake -G "Unix Makefiles" -B build -S . -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DGGML_HIP=ON -DCMAKE_BUILD_TYPE=Release -DGGML_RPC=ON
|
||||
cmake -G "Unix Makefiles" -B build -S . `
|
||||
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
|
||||
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
|
||||
-DCMAKE_BUILD_TYPE=Release `
|
||||
-DGGML_HIP=ON `
|
||||
-DGGML_RPC=ON
|
||||
cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
|
||||
|
||||
windows-latest-cmake-hip-release:
|
||||
@@ -1087,6 +1214,12 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: windows-latest-cmake-hip-release
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Install
|
||||
id: depends
|
||||
run: |
|
||||
@@ -1107,7 +1240,13 @@ jobs:
|
||||
run: |
|
||||
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
|
||||
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
|
||||
cmake -G "Unix Makefiles" -B build -S . -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DGGML_HIP=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS=${{ matrix.gpu_target }} -DGGML_RPC=ON
|
||||
cmake -G "Unix Makefiles" -B build -S . `
|
||||
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
|
||||
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
|
||||
-DCMAKE_BUILD_TYPE=Release `
|
||||
-DAMDGPU_TARGETS=${{ matrix.gpu_target }} `
|
||||
-DGGML_HIP=ON `
|
||||
-DGGML_RPC=ON
|
||||
cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
|
||||
md "build\bin\rocblas\library\"
|
||||
cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\"
|
||||
@@ -1149,9 +1288,7 @@ jobs:
|
||||
id: cmake_build
|
||||
run: |
|
||||
sysctl -a
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -G Xcode .. \
|
||||
cmake -B build -G Xcode \
|
||||
-DGGML_METAL_USE_BF16=ON \
|
||||
-DGGML_METAL_EMBED_LIBRARY=ON \
|
||||
-DLLAMA_BUILD_EXAMPLES=OFF \
|
||||
@@ -1160,8 +1297,8 @@ jobs:
|
||||
-DCMAKE_SYSTEM_NAME=iOS \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
|
||||
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
|
||||
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
|
||||
sudo cmake --install . --config Release
|
||||
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
|
||||
sudo cmake --install build --config Release
|
||||
|
||||
- name: xcodebuild for swift package
|
||||
id: xcodebuild
|
||||
@@ -1178,6 +1315,12 @@ jobs:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: android-build
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Set up JDK
|
||||
uses: actions/setup-java@v3
|
||||
with:
|
||||
@@ -1201,8 +1344,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
needs:
|
||||
- ubuntu-latest-cmake
|
||||
- macOS-latest-cmake
|
||||
- ubuntu-cpu-cmake
|
||||
- windows-latest-cmake
|
||||
- windows-2019-cmake-cuda
|
||||
- windows-latest-cmake-hip-release
|
||||
@@ -1216,6 +1358,12 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: release
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Determine tag name
|
||||
id: tag
|
||||
shell: bash
|
||||
@@ -1461,3 +1609,37 @@ jobs:
|
||||
# popd
|
||||
# emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
# make
|
||||
|
||||
openEuler-latest-cmake-cann:
|
||||
if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -el {0}
|
||||
runs-on: ubuntu-24.04-arm
|
||||
strategy:
|
||||
matrix:
|
||||
cann:
|
||||
- '8.0.rc3.beta1-910b-openeuler22.03-py3.10'
|
||||
device:
|
||||
- 'ascend910b3'
|
||||
build:
|
||||
- 'Release'
|
||||
container: ascendai/cann:${{ matrix.cann }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
yum update -y
|
||||
yum install -y git gcc gcc-c++ make cmake
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH}
|
||||
|
||||
cmake -S . -B build \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }} \
|
||||
-DGGML_CANN=on \
|
||||
-DSOC_TYPE=${{ matrix.device }}
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
2
.github/workflows/close-issue.yml
vendored
2
.github/workflows/close-issue.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/stale@v5
|
||||
with:
|
||||
exempt-issue-labels: "refactor,help wanted,good first issue,research,bug"
|
||||
exempt-issue-labels: "refactor,help wanted,good first issue,research,bug,roadmap"
|
||||
days-before-issue-stale: 30
|
||||
days-before-issue-close: 14
|
||||
stale-issue-label: "stale"
|
||||
|
||||
3
.github/workflows/docker.yml
vendored
3
.github/workflows/docker.yml
vendored
@@ -28,10 +28,11 @@ jobs:
|
||||
push_to_registry:
|
||||
name: Push Docker image to Docker Hub
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
COMMIT_SHA: ${{ github.sha }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
# Multi-stage build
|
||||
|
||||
2
.github/workflows/server.yml
vendored
2
.github/workflows/server.yml
vendored
@@ -205,7 +205,7 @@ jobs:
|
||||
run: |
|
||||
cd examples/server/tests
|
||||
$env:PYTHONIOENCODING = ":replace"
|
||||
pytest -v -x
|
||||
pytest -v -x -m "not slow"
|
||||
|
||||
- name: Slow tests
|
||||
id: server_integration_tests_slow
|
||||
|
||||
83
AUTHORS
83
AUTHORS
@@ -1,4 +1,4 @@
|
||||
# date: Thu Nov 28 20:46:15 EET 2024
|
||||
# date: Tue Feb 4 13:04:05 EET 2025
|
||||
# this file is auto-generated by scripts/gen-authors.sh
|
||||
|
||||
0cc4m <picard12@live.de>
|
||||
@@ -20,6 +20,8 @@ Adithya Balaji <adithya.b94@gmail.com>
|
||||
AdithyanI <adithyan.i4internet@gmail.com>
|
||||
Adrian <smith.adriane@gmail.com>
|
||||
Adrian Hesketh <a-h@users.noreply.github.com>
|
||||
Adrien Gallouët <adrien@gallouet.fr>
|
||||
Adrien Gallouët <angt@huggingface.co>
|
||||
Ahmad Tameem <113388789+Tameem-10xE@users.noreply.github.com>
|
||||
Ahmet Zeer <ahmed.zeer@std.yildiz.edu.tr>
|
||||
AidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com>
|
||||
@@ -55,6 +57,7 @@ Ananta Bastola <anantarajbastola@gmail.com>
|
||||
Anas Ahouzi <112881240+aahouzi@users.noreply.github.com>
|
||||
András Salamon <ott2@users.noreply.github.com>
|
||||
Andreas (Andi) Kunar <andreask@msn.com>
|
||||
Andreas Kieslinger <47689530+aendk@users.noreply.github.com>
|
||||
Andrei <abetlen@gmail.com>
|
||||
Andrew Canis <andrew.canis@gmail.com>
|
||||
Andrew Downing <andrew2085@gmail.com>
|
||||
@@ -91,13 +94,17 @@ Ben Siraphob <bensiraphob@gmail.com>
|
||||
Ben Williams <ben@719ben.com>
|
||||
Benjamin Findley <39356821+Kartoffelsaft@users.noreply.github.com>
|
||||
Benjamin Lecaillon <84293038+blecaillon@users.noreply.github.com>
|
||||
Benson Wong <mostlygeek@gmail.com>
|
||||
Bernat Vadell <hounter.caza@gmail.com>
|
||||
Bernhard M. Wiedemann <githubbmwprimary@lsmod.de>
|
||||
Bert Wagner <github@bertwagner.com>
|
||||
Billel Mokeddem <billel.mokeddem.ml@gmail.com>
|
||||
Bingan <70050083+binganao@users.noreply.github.com>
|
||||
Bjarke Viksøe <164612031+bviksoe@users.noreply.github.com>
|
||||
Bodo Graumann <mail@bodograumann.de>
|
||||
Bono Lv <lvscar@users.noreply.github.com>
|
||||
Borislav Stanimirov <b.stanimirov@abv.bg>
|
||||
Borislav Stanimirov <b@ibob.bg>
|
||||
Branden Butler <bwtbutler@hotmail.com>
|
||||
Brandon Squizzato <35474886+bsquizz@users.noreply.github.com>
|
||||
Brian <mofosyne@gmail.com>
|
||||
@@ -117,6 +124,7 @@ Casey Primozic <casey@cprimozic.net>
|
||||
Casey Primozic <me@ameo.link>
|
||||
CausalLM <148736309+CausalLM@users.noreply.github.com>
|
||||
Cebtenzzre <cebtenzzre@gmail.com>
|
||||
CentricStorm <CentricStorm@users.noreply.github.com>
|
||||
Chad Brewbaker <crb002@gmail.com>
|
||||
Changyeon Kim <cyzero.kim@samsung.com>
|
||||
Chao Jiang <jc19chaoj@zoho.com>
|
||||
@@ -131,12 +139,15 @@ Chris Kuehl <ckuehl@ckuehl.me>
|
||||
Christian Demsar <christian@github.email.demsar.us>
|
||||
Christian Demsar <crasm@git.vczf.us>
|
||||
Christian Falch <875252+chrfalch@users.noreply.github.com>
|
||||
Christian Kastner <ckk@kvr.at>
|
||||
Christian Kögler <ck3d@gmx.de>
|
||||
Christian Köhnenkamp <cvk5@me.com>
|
||||
Christian Zhou-Zheng <59622928+christianazinn@users.noreply.github.com>
|
||||
Christopher Nielsen <62156882+mascguy@users.noreply.github.com>
|
||||
Clark Saben <76020733+csaben@users.noreply.github.com>
|
||||
Clint Herron <hanclinto@gmail.com>
|
||||
Conrad Kramer <conrad@conradkramer.com>
|
||||
Corentin REGAL <corentin.regal@gmail.com>
|
||||
CrispStrobe <154636388+CrispStrobe@users.noreply.github.com>
|
||||
Csaba Kecskemeti <csaba.kecskemeti@gmail.com>
|
||||
Cuong Trinh Manh <nguoithichkhampha@gmail.com>
|
||||
@@ -176,6 +187,7 @@ Dibakar Gope <dibakar.gope@arm.com>
|
||||
Didzis Gosko <didzis@users.noreply.github.com>
|
||||
Diego Devesa <slarengh@gmail.com>
|
||||
Diogo Teles Sant'Anna <diogoteles@google.com>
|
||||
Djip007 <3705339+Djip007@users.noreply.github.com>
|
||||
Djip007 <djip.perois@free.fr>
|
||||
Don Mahurin <dmahurin@users.noreply.github.com>
|
||||
DooWoong Lee (David) <manics99@naver.com>
|
||||
@@ -193,6 +205,7 @@ Edward Taylor <edeetee@gmail.com>
|
||||
Elaine <elaine.zosa@gmail.com>
|
||||
Elbios <141279586+Elbios@users.noreply.github.com>
|
||||
Elton Kola <eltonkola@gmail.com>
|
||||
Emreerdog <34742675+Emreerdog@users.noreply.github.com>
|
||||
Engininja2 <139037756+Engininja2@users.noreply.github.com>
|
||||
Equim <sayaka@ekyu.moe>
|
||||
Eric Curtin <ecurtin@redhat.com>
|
||||
@@ -233,6 +246,7 @@ Fred Douglas <43351173+fredlas@users.noreply.github.com>
|
||||
Frederik Vogel <Schaltfehler@users.noreply.github.com>
|
||||
Gabe Goodhart <gabe.l.hart@gmail.com>
|
||||
Gabe Goodhart <ghart@us.ibm.com>
|
||||
Gaetan Bisson <gaetan@fenua.org>
|
||||
GainLee <perfecter.gen@gmail.com>
|
||||
Galunid <karolek1231456@gmail.com>
|
||||
Gary Linscott <glinscott@gmail.com>
|
||||
@@ -249,6 +263,7 @@ Guillaume "Vermeille" Sanchez <Guillaume.V.Sanchez@gmail.com>
|
||||
Guillaume Wenzek <gwenzek@users.noreply.github.com>
|
||||
Guoliang Hua <32868157+nbcsm@users.noreply.github.com>
|
||||
Guoteng <32697156+SolenoidWGT@users.noreply.github.com>
|
||||
Guspan Tanadi <36249910+guspan-tanadi@users.noreply.github.com>
|
||||
Gustavo Rocha Dias <91472747+gustrd@users.noreply.github.com>
|
||||
Haggai Nuchi <h.nuchi@gmail.com>
|
||||
Halalaluyafail3 <55773281+Halalaluyafail3@users.noreply.github.com>
|
||||
@@ -259,11 +274,13 @@ Haoxiang Fei <tonyfettes@tonyfettes.com>
|
||||
Harald Fernengel <harald.fernengel@here.com>
|
||||
Hatsune Miku <129688334+at8u@users.noreply.github.com>
|
||||
HatsuneMikuUwU33 <173229399+HatsuneMikuUwU33@users.noreply.github.com>
|
||||
Haus1 <haus.xda@gmail.com>
|
||||
Henk Poley <HenkPoley@gmail.com>
|
||||
Henri Vasserman <henv@hot.ee>
|
||||
Henrik Forstén <henrik.forsten@gmail.com>
|
||||
Herman Semenov <GermanAizek@yandex.ru>
|
||||
Hesen Peng <hesen.peng@gmail.com>
|
||||
HimariO <dsfhe49854@gmail.com>
|
||||
Hoang Nguyen <hugo53@users.noreply.github.com>
|
||||
Hong Bo PENG <penghb@cn.ibm.com>
|
||||
Hongyu Ouyang <96765450+casavaca@users.noreply.github.com>
|
||||
@@ -280,6 +297,7 @@ Icecream95 <the.real.icecream95@gmail.com>
|
||||
Ido S <ido.pluto@gmail.com>
|
||||
IgnacioFDM <ignaciofdm@gmail.com>
|
||||
Igor Okulist <okigan@gmail.com>
|
||||
Ihar Hrachyshka <ihrachys@redhat.com>
|
||||
Ikko Eltociear Ashimine <eltociear@gmail.com>
|
||||
Ilya Kurdyukov <59548320+ilyakurdyukov@users.noreply.github.com>
|
||||
Ionoclast Laboratories <brigham@ionoclast.com>
|
||||
@@ -289,12 +307,14 @@ Ivan <nekotekina@gmail.com>
|
||||
Ivan Filipov <159561759+vanaka11@users.noreply.github.com>
|
||||
Ivan Komarov <Ivan.Komarov@dfyz.info>
|
||||
Ivan Stepanov <ivanstepanovftw@gmail.com>
|
||||
JFLFY2255 <JFLFY2255@163.com>
|
||||
JH23X <165871467+JH23X@users.noreply.github.com>
|
||||
Jack Mousseau <jack@software.inc>
|
||||
Jack Mousseau <jmousseau@users.noreply.github.com>
|
||||
JackJollimore <130917767+JackJollimore@users.noreply.github.com>
|
||||
Jaeden Amero <jaeden@patater.com>
|
||||
Jaemin Son <woalsdnd@gmail.com>
|
||||
Jafar Uruç <jafar.uruc@gmail.com>
|
||||
Jag Chadha <jagtesh@gmail.com>
|
||||
Jakub N <jakubniemczyk97@gmail.com>
|
||||
James A Capozzoli <157492257+jac-jim@users.noreply.github.com>
|
||||
@@ -315,6 +335,7 @@ Jeffrey Morgan <jmorganca@gmail.com>
|
||||
Jeffrey Quesnelle <emozilla@nousresearch.com>
|
||||
Jeroen Mostert <jeroen.mostert@cm.com>
|
||||
Jesse Jojo Johnson <williamsaintgeorge@gmail.com>
|
||||
Jett Janiak <jettjaniak@gmail.com>
|
||||
Jeximo <jeximo@gmail.com>
|
||||
Jhen-Jie Hong <iainst0409@gmail.com>
|
||||
Jiahao Li <liplus17@163.com>
|
||||
@@ -343,6 +364,7 @@ Josh Ramer <josh.ramer@icloud.com>
|
||||
Joyce <joycebrum@google.com>
|
||||
Juan Calderon-Perez <835733+gaby@users.noreply.github.com>
|
||||
Judd <foldl@users.noreply.github.com>
|
||||
Juk Armstrong <69222624+jukofyork@users.noreply.github.com>
|
||||
Julius Arkenberg <arki05@users.noreply.github.com>
|
||||
Jun Hee Yoo <contact.jhyoo@gmail.com>
|
||||
Jun Jie <71215065+junnjiee16@users.noreply.github.com>
|
||||
@@ -357,6 +379,7 @@ Justine Tunney <jtunney@mozilla.com>
|
||||
Juuso Alasuutari <juuso.alasuutari@gmail.com>
|
||||
KASR <karim.asrih@gmail.com>
|
||||
Kamil Tomšík <info@tomsik.cz>
|
||||
Karol Kontny <82021046+kkontny@users.noreply.github.com>
|
||||
Karsten Weiss <knweiss@gmail.com>
|
||||
Karthick <j.karthic2004@gmail.com>
|
||||
Karthik Kumar Viswanathan <195178+guilt@users.noreply.github.com>
|
||||
@@ -376,6 +399,7 @@ Kolen Cheung <ickc@users.noreply.github.com>
|
||||
Konstantin Herud <konstantin.herud@denkbares.com>
|
||||
Konstantin Zhuravlyov <konstantin.zhuravlyov@amd.com>
|
||||
Kunshang Ji <kunshang.ji@intel.com>
|
||||
Kyle Bruene <KyleBruene@users.noreply.github.com>
|
||||
Kyle Liang <liangmanlai@gmail.com>
|
||||
Kyle Mistele <kyle@mistele.com>
|
||||
Kylin <56434533+KyL0N@users.noreply.github.com>
|
||||
@@ -394,6 +418,7 @@ Liu Jia <jia3.liu@intel.com>
|
||||
LoganDark <github@logandark.mozmail.com>
|
||||
Loïc Carrère <loic.carrere@gmail.com>
|
||||
LostRuins <39025047+LostRuins@users.noreply.github.com>
|
||||
LostRuins Concedo <39025047+LostRuins@users.noreply.github.com>
|
||||
Luciano <lucianostrika44@gmail.com>
|
||||
Luo Tian <lt@basecity.com>
|
||||
Lyle Dean <dean@lyle.dev>
|
||||
@@ -423,6 +448,7 @@ MasterYi1024 <39848311+MasterYi1024@users.noreply.github.com>
|
||||
Mateusz Charytoniuk <mateusz.charytoniuk@protonmail.com>
|
||||
Matheus C. França <matheus-catarino@hotmail.com>
|
||||
Matheus Gabriel Alves Silva <matheusgasource@gmail.com>
|
||||
Mathieu Baudier <mbaudier@argeo.org>
|
||||
Mathieu Geli <mathieu.geli@gmail.com>
|
||||
Mathieu Nayrolles <MathieuNls@users.noreply.github.com>
|
||||
Mathijs Henquet <mathijs.henquet@gmail.com>
|
||||
@@ -444,6 +470,7 @@ Meng, Hengyu <hengyu.meng@intel.com>
|
||||
Mengqing Cao <cmq0113@163.com>
|
||||
Merrick Christensen <merrick.christensen@gmail.com>
|
||||
Michael Coppola <m18coppola@gmail.com>
|
||||
Michael Engel <mengel@redhat.com>
|
||||
Michael Francis <edude03@gmail.com>
|
||||
Michael Hueschen <m@mhueschen.dev>
|
||||
Michael Kesper <mkesper@schokokeks.org>
|
||||
@@ -452,7 +479,9 @@ Michael Podvitskiy <podvitskiymichael@gmail.com>
|
||||
Michael Potter <NanoTekGuy@Gmail.com>
|
||||
Michael de Gans <michael.john.degans@gmail.com>
|
||||
Michaël de Vries <vriesdemichael@gmail.com>
|
||||
Michał Moskal <michal@moskal.me>
|
||||
Michał Tuszyński <srgtuszy@gmail.com>
|
||||
Michelle Tan <41475767+MichelleTanPY@users.noreply.github.com>
|
||||
Mihai <mihai.chirculescu@yahoo.com>
|
||||
Mike <ytianhui2004@gmail.com>
|
||||
Mikko Juola <mikjuo@gmail.com>
|
||||
@@ -477,6 +506,7 @@ Neo Zhang <14088817+arthw@users.noreply.github.com>
|
||||
Neo Zhang <zhang.jianyu@outlook.com>
|
||||
Neo Zhang Jianyu <jianyu.zhang@intel.com>
|
||||
Neuman Vong <neuman.vong@gmail.com>
|
||||
NeverLucky <92274250+nvrxq@users.noreply.github.com>
|
||||
Nexes the Old <124105151+Nexesenex@users.noreply.github.com>
|
||||
Nexesenex <124105151+Nexesenex@users.noreply.github.com>
|
||||
Niall Coates <1349685+Niall-@users.noreply.github.com>
|
||||
@@ -484,11 +514,15 @@ Nicholai Tukanov <nicholaitukanov@gmail.com>
|
||||
Nico Bosshard <nico@bosshome.ch>
|
||||
Nicolai Weitkemper <kontakt@nicolaiweitkemper.de>
|
||||
Nicolás Pérez <nicolas_perez@brown.edu>
|
||||
Nicolò Scipione <nicolo.scipione@codeplay.com>
|
||||
Nigel Bosch <pnigelb@gmail.com>
|
||||
Nikita Sarychev <42014488+sARY77@users.noreply.github.com>
|
||||
Niklas Korz <niklas@niklaskorz.de>
|
||||
NikolaiLyssogor <59844691+NikolaiLyssogor@users.noreply.github.com>
|
||||
Nikolaos Pothitos <pothitos@di.uoa.gr>
|
||||
Nikolas <127742645+nneubacher@users.noreply.github.com>
|
||||
Nindaleth <Nindaleth@users.noreply.github.com>
|
||||
Nuno <rare-magma@posteo.eu>
|
||||
OSecret <135510162+OLSecret@users.noreply.github.com>
|
||||
Oleksandr Nikitin <oleksandr@tvori.info>
|
||||
Oleksii Maryshchenko <oleksii.maryshchenko@gmail.com>
|
||||
@@ -504,6 +538,7 @@ Pavel Zloi <github.com@drteam.rocks>
|
||||
Pavol Rusnak <pavol@rusnak.io>
|
||||
Paweł Wodnicki <151604+32bitmicro@users.noreply.github.com>
|
||||
Pedro Cuenca <pedro@huggingface.co>
|
||||
Peter <peter277@users.noreply.github.com>
|
||||
Peter Sugihara <peter@campsh.com>
|
||||
Phil H <5756783+phiharri@users.noreply.github.com>
|
||||
Philip Taron <philip.taron@gmail.com>
|
||||
@@ -529,9 +564,12 @@ Rand Xie <randxiexyy29@gmail.com>
|
||||
Randall Fitzgerald <randall@dasaku.net>
|
||||
Random Fly <renfei8@live.cn>
|
||||
Reinforce-II <fate@eastal.com>
|
||||
Rémy Oudompheng <oudomphe@phare.normalesup.org>
|
||||
Ren Xuancheng <jklj077@users.noreply.github.com>
|
||||
Rene Leonhardt <65483435+reneleonhardt@users.noreply.github.com>
|
||||
Reza Kakhki <rezakakhki.de@gmail.com>
|
||||
RhinoDevel <RhinoDevel@users.noreply.github.com>
|
||||
Riccardo Orlando <Riccorl@users.noreply.github.com>
|
||||
Riceball LEE <snowyu.lee@gmail.com>
|
||||
Rich Dougherty <rich@rd.nz>
|
||||
Richard Kiss <him@richardkiss.com>
|
||||
@@ -544,6 +582,8 @@ Riley Stewart <ristew@users.noreply.github.com>
|
||||
Rinne <AsakusaRinne@gmail.com>
|
||||
Rinne <liu_yaohui1998@126.com>
|
||||
Robert Brisita <986796+rbrisita@users.noreply.github.com>
|
||||
Robert Collins <roberto.tomas.cuentas@gmail.com>
|
||||
Robert Ormandi <52251610+ormandi@users.noreply.github.com>
|
||||
Robert Sung-wook Shin <edp1096@users.noreply.github.com>
|
||||
Robey Holderith <robey@flaminglunchbox.net>
|
||||
Robyn <robyngraf@users.noreply.github.com>
|
||||
@@ -559,7 +599,9 @@ Roni <sulpher@gmx.net>
|
||||
Ronny Brendel <ronnybrendel@gmail.com>
|
||||
Ronsor <ronsor@ronsor.pw>
|
||||
Rowan Hart <rowanbhart@gmail.com>
|
||||
Ruan <47767371+ruanych@users.noreply.github.com>
|
||||
Ruchira Hasaranga <ruchira66@gmail.com>
|
||||
Rudi Servo <rudiservo@gmail.com>
|
||||
Ruixin Huang <18860020911@163.com>
|
||||
Rune <43761327+Rune-AI@users.noreply.github.com>
|
||||
RunningLeon <maningsheng@sensetime.com>
|
||||
@@ -623,12 +665,14 @@ Steven Roussey <sroussey@gmail.com>
|
||||
Steward Garcia <57494570+FSSRepo@users.noreply.github.com>
|
||||
StrangeBytesDev <141275258+StrangeBytesDev@users.noreply.github.com>
|
||||
Suaj Carrot <72162667+SuajCarrot@users.noreply.github.com>
|
||||
Sukriti Sharma <Ssukriti@users.noreply.github.com>
|
||||
SuperUserNameMan <yoann@terminajones.com>
|
||||
Sutou Kouhei <kou@cozmixng.org>
|
||||
Tai Duc Nguyen <taiducnguyen.drexel@gmail.com>
|
||||
Taikono-Himazin <kazu@po.harenet.ne.jp>
|
||||
Tameem <113388789+AhmadTameem@users.noreply.github.com>
|
||||
Tamotsu Takahashi <ttakah+github@gmail.com>
|
||||
Tei Home <taiteitonghome@proton.me>
|
||||
Thái Hoàng Tâm <75922889+RoyalHeart@users.noreply.github.com>
|
||||
Thatcher Chamberlin <j.thatcher.c@gmail.com>
|
||||
Theia Vogel <theia@vgel.me>
|
||||
@@ -640,6 +684,7 @@ Tim Miller <drasticactions@users.noreply.github.com>
|
||||
Tim Wang <overocean@gmail.com>
|
||||
Timmy Knight <r2d2fish@gmail.com>
|
||||
Timothy Cronin <40186632+4imothy@users.noreply.github.com>
|
||||
Ting Lou <louting@189.cn>
|
||||
Ting Lou <ting.lou@gmail.com>
|
||||
Ting Sun <suntcrick@gmail.com>
|
||||
Tobias Lütke <tobi@shopify.com>
|
||||
@@ -661,6 +706,7 @@ Uzo Nweke <uzoechi@gmail.com>
|
||||
Vaibhav Srivastav <vaibhavs10@gmail.com>
|
||||
Val Kharitonov <mail@kharvd.com>
|
||||
Valentin Konovalov <valle.ketsujin@gmail.com>
|
||||
Valentin Mamedov <45292985+Inf1delis@users.noreply.github.com>
|
||||
Valentyn Bezshapkin <61702053+valentynbez@users.noreply.github.com>
|
||||
Vali Malinoiu <0x4139@gmail.com>
|
||||
Victor Nogueira <felladrin@gmail.com>
|
||||
@@ -673,13 +719,17 @@ Vladimir Malyutin <first-leon@yandex.ru>
|
||||
Vladimir Zorin <vladimir@deviant.guru>
|
||||
VoidIsVoid <343750470@qq.com>
|
||||
Volodymyr Vitvitskyi <72226+signalpillar@users.noreply.github.com>
|
||||
Wang Qin <37098874+wangqin0@users.noreply.github.com>
|
||||
Wang Ran (汪然) <wangr@smail.nju.edu.cn>
|
||||
WangHaoranRobin <56047610+WangHaoranRobin@users.noreply.github.com>
|
||||
Weird Constructor <weirdconstructor@gmail.com>
|
||||
Welby Seely <welbyseely@gmail.com>
|
||||
Wentai Zhang <rchardx@gmail.com>
|
||||
WillCorticesAI <150854901+WillCorticesAI@users.noreply.github.com>
|
||||
William Tambellini <william.tambellini@gmail.com>
|
||||
William Tambellini <wtambellini@sdl.com>
|
||||
Willy Tarreau <w@1wt.eu>
|
||||
Woof Dog <197125663+woof-dog@users.noreply.github.com>
|
||||
Wouter <9594229+DifferentialityDevelopment@users.noreply.github.com>
|
||||
Wu Jian Ping <wujjpp@hotmail.com>
|
||||
Wu Jian Ping <wujp@greatld.com>
|
||||
@@ -692,6 +742,7 @@ Xie Yanbo <xieyanbo@gmail.com>
|
||||
Xingchen Song(宋星辰) <xingchensong1996@163.com>
|
||||
Xinpeng Dou <81913537+Dou-Git@users.noreply.github.com>
|
||||
Xuan Son Nguyen <thichthat@gmail.com>
|
||||
Xuan-Son Nguyen <thichthat@gmail.com>
|
||||
Yaiko <elyaiko@hotmail.com>
|
||||
Yann Follet <131855179+YannFollet@users.noreply.github.com>
|
||||
Yaroslav <yaroslav.yashin@me.com>
|
||||
@@ -702,7 +753,9 @@ Yoshi Suhara <y.suhara@gmail.com>
|
||||
Yoshi Suhara <ysuhara@nvidia.com>
|
||||
Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
|
||||
Yueh-Po Peng <94939112+y10ab1@users.noreply.github.com>
|
||||
Yüg <eugeniosegalaweb@gmail.com>
|
||||
Yui <dev@sleepyyui.com>
|
||||
Yun Dou <dixyes@gmail.com>
|
||||
Yuri Khrustalev <ykhrustalev@users.noreply.github.com>
|
||||
Yusuf Kağan Hanoğlu <hanoglu@yahoo.com>
|
||||
Yuval Peled <31162840+Yuval-Peled@users.noreply.github.com>
|
||||
@@ -714,18 +767,23 @@ Zhang Peiyuan <a1286225768@gmail.com>
|
||||
Zheng.Deng <32841220+dengzheng-cloud@users.noreply.github.com>
|
||||
Zhenwei Jin <109658203+kylo5aby@users.noreply.github.com>
|
||||
Zhiyuan Li <lizhiyuan@uniartisan.com>
|
||||
Zhiyuan Li <uniartisan2017@gmail.com>
|
||||
ZhouYuChen <zhouyuchen@naver.com>
|
||||
Ziad Ben Hadj-Alouane <zied.benhadjalouane@gmail.com>
|
||||
Ziang Wu <97337387+ZiangWu-77@users.noreply.github.com>
|
||||
Zsapi <martin1.zsapka@gmail.com>
|
||||
a-n-n-a-l-e-e <150648636+a-n-n-a-l-e-e@users.noreply.github.com>
|
||||
a3sh <38979186+A3shTnT@users.noreply.github.com>
|
||||
adel boussaken <netdur@gmail.com>
|
||||
afrideva <95653597+afrideva@users.noreply.github.com>
|
||||
ag2s20150909 <19373730+ag2s20150909@users.noreply.github.com>
|
||||
agray3 <agray3@users.noreply.github.com>
|
||||
akawrykow <142945436+akawrykow@users.noreply.github.com>
|
||||
alek3y <44779186+alek3y@users.noreply.github.com>
|
||||
alexpinel <93524949+alexpinel@users.noreply.github.com>
|
||||
alonfaraj <alonfaraj@gmail.com>
|
||||
alwqx <kenan3015@gmail.com>
|
||||
amd-dwang <dong.wang@amd.com>
|
||||
amd-lalithnc <lalithnc@amd.com>
|
||||
amritahs-ibm <amritahs@linux.vnet.ibm.com>
|
||||
andrijdavid <david@geek.mg>
|
||||
@@ -737,6 +795,7 @@ arch-btw <57669023+arch-btw@users.noreply.github.com>
|
||||
arcrank <arcrank@gmail.com>
|
||||
ardfork <134447697+ardfork@users.noreply.github.com>
|
||||
arlo-phoenix <140345165+arlo-phoenix@users.noreply.github.com>
|
||||
aryantandon01 <80969509+aryantandon01@users.noreply.github.com>
|
||||
at8u <129688334+at8u@users.noreply.github.com>
|
||||
automaticcat <daogiatuank54@gmail.com>
|
||||
awatuna <23447591+awatuna@users.noreply.github.com>
|
||||
@@ -751,12 +810,14 @@ bryanSwk <93190252+bryanSwk@users.noreply.github.com>
|
||||
bsilvereagle <bsilvereagle@users.noreply.github.com>
|
||||
bssrdf <merlintiger@hotmail.com>
|
||||
byte-6174 <88070277+byte-6174@users.noreply.github.com>
|
||||
cduk <19917266+cduk@users.noreply.github.com>
|
||||
cebtenzzre <cebtenzzre@gmail.com>
|
||||
chaihahaha <chai836275709@gmail.com>
|
||||
chiranko <96988916+chiranko@users.noreply.github.com>
|
||||
clibdev <52199778+clibdev@users.noreply.github.com>
|
||||
clyang <clyang@clyang.net>
|
||||
cocktailpeanut <121128867+cocktailpeanut@users.noreply.github.com>
|
||||
codezjx <code.zjx@gmail.com>
|
||||
coezbek <c.oezbek@gmail.com>
|
||||
comex <comexk@gmail.com>
|
||||
compilade <113953597+compilade@users.noreply.github.com>
|
||||
@@ -780,14 +841,17 @@ drbh <david.richard.holtz@gmail.com>
|
||||
ds5t5 <145942675+ds5t5@users.noreply.github.com>
|
||||
dylan <canardleteer@users.noreply.github.com>
|
||||
eastriver <lee@eastriver.dev>
|
||||
ebraminio <ebrahim@gnu.org>
|
||||
ebraminio <ebraminio@gmail.com>
|
||||
eiery <19350831+eiery@users.noreply.github.com>
|
||||
eric8607242 <e0928021388@gmail.com>
|
||||
fairydreaming <166155368+fairydreaming@users.noreply.github.com>
|
||||
fengerhu1 <2748250768@qq.com>
|
||||
fj-y-saito <85871716+fj-y-saito@users.noreply.github.com>
|
||||
fraxy-v <65565042+fraxy-v@users.noreply.github.com>
|
||||
github-actions[bot] <github-actions[bot]@users.noreply.github.com>
|
||||
gliptic <gliptic@users.noreply.github.com>
|
||||
gn64 <yukikaze.jp@gmail.com>
|
||||
goerch <jhr.walter@t-online.de>
|
||||
grahameth <96447521+grahameth@users.noreply.github.com>
|
||||
gtygo <gtydoit@gmail.com>
|
||||
@@ -812,10 +876,12 @@ icppWorld <124377669+icppWorld@users.noreply.github.com>
|
||||
igarnier <igarnier@protonmail.com>
|
||||
intelmatt <61025942+intelmatt@users.noreply.github.com>
|
||||
iohub <rickyang.pro@gmail.com>
|
||||
issixx <46835150+issixx@users.noreply.github.com>
|
||||
jacobi petrucciani <8117202+jpetrucciani@users.noreply.github.com>
|
||||
jaime-m-p <167997752+jaime-m-p@users.noreply.github.com>
|
||||
jameswu2014 <545426914@qq.com>
|
||||
jdomke <28772296+jdomke@users.noreply.github.com>
|
||||
jiahao su <damow890@gmail.com>
|
||||
jiez <373447296@qq.com>
|
||||
jneem <joeneeman@gmail.com>
|
||||
joecryptotoo <80373433+joecryptotoo@users.noreply.github.com>
|
||||
@@ -828,6 +894,7 @@ junchao-loongson <68935141+junchao-loongson@users.noreply.github.com>
|
||||
jwj7140 <32943891+jwj7140@users.noreply.github.com>
|
||||
k.h.lai <adrian.k.h.lai@outlook.com>
|
||||
kaizau <kaizau@users.noreply.github.com>
|
||||
kallewoof <kalle.alm@gmail.com>
|
||||
kalomaze <66376113+kalomaze@users.noreply.github.com>
|
||||
kang <tpdns9032100@gmail.com>
|
||||
katsu560 <118887472+katsu560@users.noreply.github.com>
|
||||
@@ -835,6 +902,7 @@ kchro3 <62481661+kchro3@users.noreply.github.com>
|
||||
khimaros <me@khimaros.com>
|
||||
kiltyj <kiltyj@gmail.com>
|
||||
klosax <131523366+klosax@users.noreply.github.com>
|
||||
krystiancha <krystian@krystianch.com>
|
||||
kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
|
||||
kunnis <kunnis@users.noreply.github.com>
|
||||
kuronekosaiko <EvanChanJ@163.com>
|
||||
@@ -847,6 +915,8 @@ ldwang <ftgreat@163.com>
|
||||
le.chang <cljs118@126.com>
|
||||
leejet <leejet714@gmail.com>
|
||||
leo-pony <nengjunma@outlook.com>
|
||||
lexasub <lexakopp2212@gmail.com>
|
||||
lhez <quic_lih@quicinc.com>
|
||||
limitedAtonement <limitedAtonement@users.noreply.github.com>
|
||||
liuwei-git <14815172+liuwei-git@users.noreply.github.com>
|
||||
lon <114724657+longregen@users.noreply.github.com>
|
||||
@@ -855,10 +925,13 @@ ltoniazzi <61414566+ltoniazzi@users.noreply.github.com>
|
||||
luoyu-intel <yu.luo@intel.com>
|
||||
m3ndax <adrian.goessl@outlook.com>
|
||||
maddes8cht <55592906+maddes8cht@users.noreply.github.com>
|
||||
mahorozte <41834471+mahorozte@users.noreply.github.com>
|
||||
makomk <makosoft@googlemail.com>
|
||||
manikbhandari <mbbhandarimanik2@gmail.com>
|
||||
maor-ps <154728172+maor-ps@users.noreply.github.com>
|
||||
mashdragon <122402293+mashdragon@users.noreply.github.com>
|
||||
matiaslin <45382001+matiaslin@users.noreply.github.com>
|
||||
matt23654 <matthew.webber@protonmail.com>
|
||||
matteo <matteogeniaccio@yahoo.it>
|
||||
mdrokz <mohammadmunshi@gmail.com>
|
||||
mgroeber9110 <45620825+mgroeber9110@users.noreply.github.com>
|
||||
@@ -868,6 +941,7 @@ mmyjona <jonathan.gonse@gmail.com>
|
||||
momonga <115213907+mmnga@users.noreply.github.com>
|
||||
momonga <146910567+mmngays@users.noreply.github.com>
|
||||
moritzbrantner <31051084+moritzbrantner@users.noreply.github.com>
|
||||
musoles <135031143+musoles@users.noreply.github.com>
|
||||
mzcu <milos.cubrilo@gmail.com>
|
||||
nanahi <130121847+na-na-hi@users.noreply.github.com>
|
||||
ngc92 <7938269+ngc92@users.noreply.github.com>
|
||||
@@ -885,6 +959,7 @@ oobabooga <112222186+oobabooga@users.noreply.github.com>
|
||||
opparco <parco.opaai@gmail.com>
|
||||
ostix360 <55257054+ostix360@users.noreply.github.com>
|
||||
pculliton <phillipculliton@gmail.com>
|
||||
peidaqi <peidaqi@gmail.com>
|
||||
pengxin99 <pengxin.yuan@intel.com>
|
||||
perserk <perserk@gmail.com>
|
||||
piDack <104877312+piDack@users.noreply.github.com>
|
||||
@@ -892,10 +967,12 @@ pmysl <piotr.myslinski@outlook.com>
|
||||
postmasters <namnguyen@google.com>
|
||||
pudepiedj <pudepiedj@gmail.com>
|
||||
qingfengfenga <41416092+qingfengfenga@users.noreply.github.com>
|
||||
qingy1337 <qxli2@students.everettcc.edu>
|
||||
qouoq <qouoq@fastmail.com>
|
||||
qunash <anzoria@gmail.com>
|
||||
rabidcopy <rabidcopy@yahoo.com>
|
||||
rankaiyx <rankaiyx@rankaiyx.com>
|
||||
redbeard <bharrington@alticon.net>
|
||||
rhjdvsgsgks <26178113+rhjdvsgsgks@users.noreply.github.com>
|
||||
rhuddleston <ryan.huddleston@percona.com>
|
||||
rimoliga <53384203+rimoliga@users.noreply.github.com>
|
||||
@@ -912,6 +989,7 @@ sjxx <63994076+ylsdamxssjxxdd@users.noreply.github.com>
|
||||
slaren <2141330+slaren@users.noreply.github.com>
|
||||
slaren <slarengh@gmail.com>
|
||||
snadampal <87143774+snadampal@users.noreply.github.com>
|
||||
someone13574 <81528246+someone13574@users.noreply.github.com>
|
||||
standby24x7 <standby24x7@gmail.com>
|
||||
staviq <staviq@gmail.com>
|
||||
stduhpf <stephduh@live.fr>
|
||||
@@ -931,6 +1009,7 @@ uint256_t <konndennsa@gmail.com>
|
||||
uint256_t <maekawatoshiki1017@gmail.com>
|
||||
unbounded <haakon@likedan.net>
|
||||
uvos <devnull@uvos.xyz>
|
||||
uvos <philipp@uvos.xyz>
|
||||
valiray <133289098+valiray@users.noreply.github.com>
|
||||
vb <vaibhavs10@gmail.com>
|
||||
vik <vikhyatk@gmail.com>
|
||||
@@ -951,6 +1030,7 @@ xaedes <xaedes@googlemail.com>
|
||||
xctan <axunlei@gmail.com>
|
||||
xloem <0xloem@gmail.com>
|
||||
yangli2 <yangli2@gmail.com>
|
||||
ymcki <84055651+ymcki@users.noreply.github.com>
|
||||
yuiseki <yuiseki@gmail.com>
|
||||
yuri@FreeBSD <yurivict@users.noreply.github.com>
|
||||
zakkor <edward.partenie@gmail.com>
|
||||
@@ -963,4 +1043,5 @@ zrm <trustiosity.zrm@gmail.com>
|
||||
杨朱 · Kiki <baofa.fan@daocloud.io>
|
||||
源文雨 <41315874+fumiama@users.noreply.github.com>
|
||||
蕭澧邦 <45505768+shou692199@users.noreply.github.com>
|
||||
谢乃闻 <sienaiwun@users.noreply.github.com>
|
||||
Нияз Гарифзянов <112617865+garrnizon@users.noreply.github.com>
|
||||
|
||||
@@ -16,6 +16,7 @@ endif()
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
|
||||
if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
||||
set(LLAMA_STANDALONE ON)
|
||||
@@ -49,6 +50,8 @@ endif()
|
||||
if (MSVC)
|
||||
add_compile_options("$<$<COMPILE_LANGUAGE:C>:/utf-8>")
|
||||
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/utf-8>")
|
||||
add_compile_options("$<$<COMPILE_LANGUAGE:C>:/bigobj>")
|
||||
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/bigobj>")
|
||||
endif()
|
||||
|
||||
#
|
||||
@@ -77,6 +80,7 @@ option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
|
||||
|
||||
# 3rd party libs
|
||||
option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
|
||||
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
|
||||
|
||||
# Required for relocatable CMake package
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
|
||||
@@ -185,27 +189,14 @@ set(LLAMA_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location o
|
||||
set(LLAMA_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files")
|
||||
set(LLAMA_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files")
|
||||
|
||||
# At the moment some compile definitions are placed within the ggml/src
|
||||
# directory but not exported on the `ggml` target. This could be improved by
|
||||
# determining _precisely_ which defines are necessary for the llama-config
|
||||
# package.
|
||||
#
|
||||
set(GGML_TRANSIENT_DEFINES)
|
||||
get_target_property(GGML_DIRECTORY ggml SOURCE_DIR)
|
||||
get_directory_property(GGML_DIR_DEFINES DIRECTORY ${GGML_DIRECTORY} COMPILE_DEFINITIONS)
|
||||
if (GGML_DIR_DEFINES)
|
||||
list(APPEND GGML_TRANSIENT_DEFINES ${GGML_DIR_DEFINES})
|
||||
endif()
|
||||
get_target_property(GGML_TARGET_DEFINES ggml COMPILE_DEFINITIONS)
|
||||
if (GGML_TARGET_DEFINES)
|
||||
list(APPEND GGML_TRANSIENT_DEFINES ${GGML_TARGET_DEFINES})
|
||||
endif()
|
||||
get_target_property(GGML_LINK_LIBRARIES ggml LINK_LIBRARIES)
|
||||
# all public headers
|
||||
set(LLAMA_PUBLIC_HEADERS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/include/llama.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/include/llama-cpp.h)
|
||||
set_target_properties(llama PROPERTIES PUBLIC_HEADER "${LLAMA_PUBLIC_HEADERS}")
|
||||
|
||||
set_target_properties(llama
|
||||
PROPERTIES
|
||||
PUBLIC_HEADER "${LLAMA_PUBLIC_HEADERS}")
|
||||
|
||||
install(TARGETS llama LIBRARY PUBLIC_HEADER)
|
||||
|
||||
configure_package_config_file(
|
||||
|
||||
11
Makefile
11
Makefile
@@ -52,6 +52,7 @@ TEST_TARGETS = \
|
||||
tests/test-arg-parser \
|
||||
tests/test-autorelease \
|
||||
tests/test-backend-ops \
|
||||
tests/test-chat \
|
||||
tests/test-chat-template \
|
||||
tests/test-double-float \
|
||||
tests/test-grammar-integration \
|
||||
@@ -595,7 +596,7 @@ ifdef GGML_RPC
|
||||
OBJ_GGML_EXT += ggml/src/ggml-rpc.o
|
||||
endif # GGML_RPC
|
||||
|
||||
OBJ_CUDA_TMPL = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-wmma*.cu))
|
||||
OBJ_CUDA_TMPL = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-mma*.cu))
|
||||
OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/mmq*.cu))
|
||||
|
||||
ifdef GGML_CUDA_FA_ALL_QUANTS
|
||||
@@ -983,6 +984,7 @@ OBJ_COMMON = \
|
||||
$(DIR_COMMON)/ngram-cache.o \
|
||||
$(DIR_COMMON)/sampling.o \
|
||||
$(DIR_COMMON)/speculative.o \
|
||||
$(DIR_COMMON)/chat.o \
|
||||
$(DIR_COMMON)/build-info.o \
|
||||
$(DIR_COMMON)/json-schema-to-grammar.o
|
||||
|
||||
@@ -1361,6 +1363,8 @@ llama-server: \
|
||||
examples/server/httplib.h \
|
||||
examples/server/index.html.hpp \
|
||||
examples/server/loading.html.hpp \
|
||||
common/chat.cpp \
|
||||
common/chat.hpp \
|
||||
common/chat-template.hpp \
|
||||
common/json.hpp \
|
||||
common/minja.hpp \
|
||||
@@ -1471,6 +1475,11 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \
|
||||
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
||||
tests/test-chat: tests/test-chat.cpp \
|
||||
$(OBJ_ALL)
|
||||
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
||||
tests/test-opt: tests/test-opt.cpp \
|
||||
$(OBJ_GGML)
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
|
||||
12
README.md
12
README.md
@@ -16,7 +16,11 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
|
||||
|
||||
## Hot topics
|
||||
|
||||
- **Introducing GGUF-my-LoRA** https://github.com/ggerganov/llama.cpp/discussions/10123
|
||||
- **How to use [MTLResidencySet](https://developer.apple.com/documentation/metal/mtlresidencyset?language=objc) to keep the GPU memory active?** https://github.com/ggerganov/llama.cpp/pull/11427
|
||||
- **VS Code extension for FIM completions:** https://github.com/ggml-org/llama.vscode
|
||||
- Universal tool call support in `llama-server`: https://github.com/ggerganov/llama.cpp/pull/9639
|
||||
- Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim
|
||||
- Introducing GGUF-my-LoRA https://github.com/ggerganov/llama.cpp/discussions/10123
|
||||
- Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggerganov/llama.cpp/discussions/9669
|
||||
- Hugging Face GGUF editor: [discussion](https://github.com/ggerganov/llama.cpp/discussions/9268) | [tool](https://huggingface.co/spaces/CISCai/gguf-editor)
|
||||
|
||||
@@ -92,7 +96,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [x] [Bitnet b1.58 models](https://huggingface.co/1bitLLM)
|
||||
- [x] [Flan T5](https://huggingface.co/models?search=flan-t5)
|
||||
- [x] [Open Elm models](https://huggingface.co/collections/apple/openelm-instruct-models-6619ad295d7ae9f868b759ca)
|
||||
- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b)
|
||||
- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b) + [GLMEdge-1.5b](https://huggingface.co/THUDM/glm-edge-1.5b-chat) + [GLMEdge-4b](https://huggingface.co/THUDM/glm-edge-4b-chat)
|
||||
- [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
|
||||
- [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
|
||||
- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)
|
||||
@@ -113,6 +117,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [x] [Mini CPM](https://huggingface.co/models?search=MiniCPM)
|
||||
- [x] [Moondream](https://huggingface.co/vikhyatk/moondream2)
|
||||
- [x] [Bunny](https://github.com/BAAI-DCAI/Bunny)
|
||||
- [x] [GLM-EDGE](https://huggingface.co/models?search=glm-edge)
|
||||
- [x] [Qwen2-VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d)
|
||||
|
||||
</details>
|
||||
@@ -131,6 +136,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- Rust (more features): [edgenai/llama_cpp-rs](https://github.com/edgenai/llama_cpp-rs)
|
||||
- Rust (nicer API): [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp)
|
||||
- Rust (more direct bindings): [utilityai/llama-cpp-rs](https://github.com/utilityai/llama-cpp-rs)
|
||||
- Rust (automated build from crates.io): [ShelbyJenkins/llm_client](https://github.com/ShelbyJenkins/llm_client)
|
||||
- C#/.NET: [SciSharp/LLamaSharp](https://github.com/SciSharp/LLamaSharp)
|
||||
- C#/VB.NET (more features - community license): [LM-Kit.NET](https://docs.lm-kit.com/lm-kit-net/index.html)
|
||||
- Scala 3: [donderom/llm4s](https://github.com/donderom/llm4s)
|
||||
@@ -419,7 +425,7 @@ To learn more about model quantization, [read this documentation](examples/quant
|
||||
|
||||
</details>
|
||||
|
||||
[^1]: [examples/perplexity/README.md](examples/perplexity/README.md)
|
||||
[^1]: [examples/perplexity/README.md](./examples/perplexity/README.md)
|
||||
[^2]: [https://huggingface.co/docs/transformers/perplexity](https://huggingface.co/docs/transformers/perplexity)
|
||||
|
||||
## [`llama-bench`](examples/llama-bench)
|
||||
|
||||
@@ -3,159 +3,13 @@ set(LLAMA_BUILD_COMMIT @LLAMA_BUILD_COMMIT@)
|
||||
set(LLAMA_BUILD_NUMBER @LLAMA_BUILD_NUMBER@)
|
||||
set(LLAMA_SHARED_LIB @BUILD_SHARED_LIBS@)
|
||||
|
||||
set(GGML_STATIC @GGML_STATIC@)
|
||||
set(GGML_NATIVE @GGML_NATIVE@)
|
||||
set(GGML_LTO @GGML_LTO@)
|
||||
set(GGML_CCACHE @GGML_CCACHE@)
|
||||
set(GGML_AVX @GGML_AVX@)
|
||||
set(GGML_AVX2 @GGML_AVX2@)
|
||||
set(GGML_AVX512 @GGML_AVX512@)
|
||||
set(GGML_AVX512_VBMI @GGML_AVX512_VBMI@)
|
||||
set(GGML_AVX512_VNNI @GGML_AVX512_VNNI@)
|
||||
set(GGML_AVX512_BF16 @GGML_AVX512_BF16@)
|
||||
set(GGML_AMX_TILE @GGML_AMX_TILE@)
|
||||
set(GGML_AMX_INT8 @GGML_AMX_INT8@)
|
||||
set(GGML_AMX_BF16 @GGML_AMX_BF16@)
|
||||
set(GGML_FMA @GGML_FMA@)
|
||||
set(GGML_LASX @GGML_LASX@)
|
||||
set(GGML_LSX @GGML_LSX@)
|
||||
set(GGML_RVV @GGML_RVV@)
|
||||
set(GGML_SVE @GGML_SVE@)
|
||||
|
||||
set(GGML_ACCELERATE @GGML_ACCELERATE@)
|
||||
set(GGML_OPENMP @GGML_OPENMP@)
|
||||
set(GGML_CPU_HBM @GGML_CPU_HBM@)
|
||||
set(GGML_BLAS_VENDOR @GGML_BLAS_VENDOR@)
|
||||
|
||||
set(GGML_CUDA_FORCE_MMQ @GGML_CUDA_FORCE_MMQ@)
|
||||
set(GGML_CUDA_FORCE_CUBLAS @GGML_CUDA_FORCE_CUBLAS@)
|
||||
set(GGML_CUDA_F16 @GGML_CUDA_F16@)
|
||||
set(GGML_CUDA_PEER_MAX_BATCH_SIZE @GGML_CUDA_PEER_MAX_BATCH_SIZE@)
|
||||
set(GGML_CUDA_NO_PEER_COPY @GGML_CUDA_NO_PEER_COPY@)
|
||||
set(GGML_CUDA_NO_VMM @GGML_CUDA_NO_VMM@)
|
||||
set(GGML_CUDA_FA_ALL_QUANTS @GGML_CUDA_FA_ALL_QUANTS@)
|
||||
set(GGML_CUDA_GRAPHS @GGML_CUDA_GRAPHS@)
|
||||
|
||||
set(GGML_HIP_UMA @GGML_HIP_UMA@)
|
||||
|
||||
set(GGML_VULKAN_CHECK_RESULTS @GGML_VULKAN_CHECK_RESULTS@)
|
||||
set(GGML_VULKAN_DEBUG @GGML_VULKAN_DEBUG@)
|
||||
set(GGML_VULKAN_MEMORY_DEBUG @GGML_VULKAN_MEMORY_DEBUG@)
|
||||
set(GGML_VULKAN_SHADER_DEBUG_INFO @GGML_VULKAN_SHADER_DEBUG_INFO@)
|
||||
set(GGML_VULKAN_PERF @GGML_VULKAN_PERF@)
|
||||
set(GGML_VULKAN_VALIDATE @GGML_VULKAN_VALIDATE@)
|
||||
set(GGML_VULKAN_RUN_TESTS @GGML_VULKAN_RUN_TESTS@)
|
||||
|
||||
set(GGML_METAL_USE_BF16 @GGML_METAL_USE_BF16@)
|
||||
set(GGML_METAL_NDEBUG @GGML_METAL_NDEBUG@)
|
||||
set(GGML_METAL_SHADER_DEBUG @GGML_METAL_SHADER_DEBUG@)
|
||||
set(GGML_METAL_EMBED_LIBRARY @GGML_METAL_EMBED_LIBRARY@)
|
||||
set(GGML_METAL_MACOSX_VERSION_MIN @GGML_METAL_MACOSX_VERSION_MIN@)
|
||||
set(GGML_METAL_STD @GGML_METAL_STD@)
|
||||
|
||||
set(GGML_SYCL_F16 @GGML_SYCL_F16@)
|
||||
set(GGML_SYCL_TARGET @GGML_SYCL_TARGET@)
|
||||
set(GGML_SYCL_DEVICE_ARCH @GGML_SYCL_DEVICE_ARCH@)
|
||||
|
||||
|
||||
@PACKAGE_INIT@
|
||||
|
||||
set_and_check(LLAMA_INCLUDE_DIR "@PACKAGE_LLAMA_INCLUDE_INSTALL_DIR@")
|
||||
set_and_check(LLAMA_LIB_DIR "@PACKAGE_LLAMA_LIB_INSTALL_DIR@")
|
||||
set_and_check(LLAMA_BIN_DIR "@PACKAGE_LLAMA_BIN_INSTALL_DIR@")
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
set(_llama_transient_defines "@GGML_TRANSIENT_DEFINES@")
|
||||
set(_llama_link_deps "")
|
||||
set(_llama_link_opts "")
|
||||
foreach(_ggml_lib ggml ggml-base)
|
||||
string(REPLACE "-" "_" _ggml_lib_var "${_ggml_lib}_LIBRARY")
|
||||
find_library(${_ggml_lib_var} ${_ggml_lib}
|
||||
REQUIRED
|
||||
HINTS ${LLAMA_LIB_DIR}
|
||||
NO_CMAKE_FIND_ROOT_PATH
|
||||
)
|
||||
list(APPEND _llama_link_deps "${${_ggml_lib_var}}")
|
||||
message(STATUS "Found ${${_ggml_lib_var}}")
|
||||
endforeach()
|
||||
|
||||
foreach(backend amx blas cann cpu cuda hip kompute metal musa rpc sycl vulkan)
|
||||
string(TOUPPER "GGML_${backend}" backend_id)
|
||||
set(_ggml_lib "ggml-${backend}")
|
||||
string(REPLACE "-" "_" _ggml_lib_var "${_ggml_lib}_LIBRARY")
|
||||
|
||||
find_library(${_ggml_lib_var} ${_ggml_lib}
|
||||
HINTS ${LLAMA_LIB_DIR}
|
||||
NO_CMAKE_FIND_ROOT_PATH
|
||||
)
|
||||
if(${_ggml_lib_var})
|
||||
list(APPEND _llama_link_deps "${${_ggml_lib_var}}")
|
||||
set(${backend_id} ON)
|
||||
message(STATUS "Found backend ${${_ggml_lib_var}}")
|
||||
else()
|
||||
set(${backend_id} OFF)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if (NOT LLAMA_SHARED_LIB)
|
||||
if (APPLE AND GGML_ACCELERATE)
|
||||
find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED)
|
||||
list(APPEND _llama_link_deps ${ACCELERATE_FRAMEWORK})
|
||||
endif()
|
||||
|
||||
if (GGML_OPENMP)
|
||||
find_package(OpenMP REQUIRED)
|
||||
list(APPEND _llama_link_deps OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
|
||||
endif()
|
||||
|
||||
if (GGML_CPU_HBM)
|
||||
find_library(memkind memkind REQUIRED)
|
||||
list(APPEND _llama_link_deps memkind)
|
||||
endif()
|
||||
|
||||
if (GGML_BLAS)
|
||||
find_package(BLAS REQUIRED)
|
||||
list(APPEND _llama_link_deps ${BLAS_LIBRARIES})
|
||||
list(APPEND _llama_link_opts ${BLAS_LINKER_FLAGS})
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA)
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
endif()
|
||||
|
||||
if (GGML_METAL)
|
||||
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
|
||||
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
||||
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
||||
list(APPEND _llama_link_deps ${FOUNDATION_LIBRARY}
|
||||
${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK})
|
||||
endif()
|
||||
|
||||
if (GGML_VULKAN)
|
||||
find_package(Vulkan REQUIRED)
|
||||
list(APPEND _llama_link_deps Vulkan::Vulkan)
|
||||
endif()
|
||||
|
||||
if (GGML_HIP)
|
||||
find_package(hip REQUIRED)
|
||||
find_package(hipblas REQUIRED)
|
||||
find_package(rocblas REQUIRED)
|
||||
list(APPEND _llama_link_deps hip::host roc::rocblas roc::hipblas)
|
||||
endif()
|
||||
|
||||
if (GGML_SYCL)
|
||||
find_package(DNNL)
|
||||
if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
|
||||
list(APPEND _llama_link_deps DNNL::dnnl)
|
||||
endif()
|
||||
if (WIN32)
|
||||
find_package(IntelSYCL REQUIRED)
|
||||
find_package(MKL REQUIRED)
|
||||
list(APPEND _llama_link_deps IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
find_package(ggml REQUIRED HINTS ${LLAMA_LIB_DIR}/cmake)
|
||||
|
||||
find_library(llama_LIBRARY llama
|
||||
REQUIRED
|
||||
@@ -167,12 +21,10 @@ add_library(llama UNKNOWN IMPORTED)
|
||||
set_target_properties(llama
|
||||
PROPERTIES
|
||||
INTERFACE_INCLUDE_DIRECTORIES "${LLAMA_INCLUDE_DIR}"
|
||||
INTERFACE_LINK_LIBRARIES "${_llama_link_deps}"
|
||||
INTERFACE_LINK_OPTIONS "${_llama_link_opts}"
|
||||
INTERFACE_COMPILE_DEFINITIONS "${_llama_transient_defines}"
|
||||
INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base;"
|
||||
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
|
||||
IMPORTED_LOCATION "${llama_LIBRARY}"
|
||||
INTERFACE_COMPILE_FEATURES cxx_std_11
|
||||
POSITION_INDEPENDENT_CODE ON )
|
||||
INTERFACE_COMPILE_FEATURES c_std_90
|
||||
POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
check_required_components(Llama)
|
||||
|
||||
@@ -56,6 +56,8 @@ add_library(${TARGET} STATIC
|
||||
arg.cpp
|
||||
arg.h
|
||||
base64.hpp
|
||||
chat.cpp
|
||||
chat.hpp
|
||||
chat-template.hpp
|
||||
common.cpp
|
||||
common.h
|
||||
@@ -63,6 +65,7 @@ add_library(${TARGET} STATIC
|
||||
console.h
|
||||
json-schema-to-grammar.cpp
|
||||
json.hpp
|
||||
llguidance.cpp
|
||||
log.cpp
|
||||
log.h
|
||||
minja.hpp
|
||||
@@ -89,6 +92,33 @@ if (LLAMA_CURL)
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY})
|
||||
endif ()
|
||||
|
||||
if (LLAMA_LLGUIDANCE)
|
||||
include(ExternalProject)
|
||||
set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
|
||||
set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release)
|
||||
ExternalProject_Add(llguidance_ext
|
||||
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
|
||||
# v0.6.12:
|
||||
GIT_TAG ced1c9023d47ec194fa977932d35ce65c2ebfc09
|
||||
PREFIX ${CMAKE_BINARY_DIR}/llguidance
|
||||
SOURCE_DIR ${LLGUIDANCE_SRC}
|
||||
BUILD_IN_SOURCE TRUE
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND cargo build --release
|
||||
INSTALL_COMMAND ""
|
||||
BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/libllguidance.a ${LLGUIDANCE_PATH}/llguidance.h
|
||||
UPDATE_COMMAND ""
|
||||
)
|
||||
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_LLGUIDANCE)
|
||||
|
||||
add_library(llguidance STATIC IMPORTED)
|
||||
set_target_properties(llguidance PROPERTIES IMPORTED_LOCATION ${LLGUIDANCE_PATH}/libllguidance.a)
|
||||
add_dependencies(llguidance llguidance_ext)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH})
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance)
|
||||
endif ()
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC .)
|
||||
target_compile_features (${TARGET} PUBLIC cxx_std_17)
|
||||
target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
|
||||
|
||||
@@ -877,7 +877,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.warmup = false;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING}));
|
||||
add_opt(common_arg(
|
||||
{"--spm-infill"},
|
||||
string_format(
|
||||
@@ -1465,15 +1465,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"--list-devices"},
|
||||
"print list of available devices and exit",
|
||||
[](common_params &) {
|
||||
printf("Available devices:\n");
|
||||
std::vector<ggml_backend_dev_t> rpc_devices;
|
||||
std::vector<ggml_backend_dev_t> all_devices;
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
auto * dev = ggml_backend_dev_get(i);
|
||||
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
|
||||
size_t free, total;
|
||||
ggml_backend_dev_memory(dev, &free, &total);
|
||||
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
|
||||
ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
|
||||
if (ggml_backend_reg_name(reg) == std::string("RPC")) {
|
||||
rpc_devices.push_back(dev);
|
||||
} else {
|
||||
all_devices.push_back(dev);
|
||||
}
|
||||
}
|
||||
}
|
||||
// insert RPC devices in front
|
||||
all_devices.insert(all_devices.begin(), rpc_devices.begin(), rpc_devices.end());
|
||||
printf("Available devices:\n");
|
||||
for (size_t i = 0; i < all_devices.size(); ++i) {
|
||||
auto * dev = all_devices[i];
|
||||
size_t free, total;
|
||||
ggml_backend_dev_memory(dev, &free, &total);
|
||||
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
|
||||
}
|
||||
exit(0);
|
||||
}
|
||||
));
|
||||
|
||||
@@ -17,21 +17,54 @@ using json = nlohmann::ordered_json;
|
||||
|
||||
namespace minja {
|
||||
|
||||
struct chat_template_caps {
|
||||
bool supports_tools = false;
|
||||
bool supports_tool_calls = false;
|
||||
bool supports_tool_responses = false;
|
||||
bool supports_system_role = false;
|
||||
bool supports_parallel_tool_calls = false;
|
||||
bool supports_tool_call_id = false;
|
||||
// meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object.
|
||||
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
|
||||
bool requires_object_arguments = false;
|
||||
// CohereForAI/c4ai-command-r-plus simple variant
|
||||
bool requires_non_null_content = false;
|
||||
// MiniMaxAI/MiniMax-Text-01 special
|
||||
bool requires_typed_content = false;
|
||||
};
|
||||
|
||||
struct chat_template_inputs {
|
||||
nlohmann::ordered_json messages;
|
||||
nlohmann::ordered_json tools;
|
||||
bool add_generation_prompt = true;
|
||||
nlohmann::ordered_json extra_context;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
};
|
||||
|
||||
struct chat_template_options {
|
||||
bool apply_polyfills = true;
|
||||
bool use_bos_token = true;
|
||||
bool use_eos_token = true;
|
||||
bool define_strftime_now = true;
|
||||
|
||||
bool polyfill_tools = true;
|
||||
bool polyfill_tool_call_examples = true;
|
||||
bool polyfill_tool_calls = true;
|
||||
bool polyfill_tool_responses = true;
|
||||
bool polyfill_system_role = true;
|
||||
bool polyfill_object_arguments = true;
|
||||
bool polyfill_typed_content = true;
|
||||
};
|
||||
|
||||
class chat_template {
|
||||
public:
|
||||
|
||||
private:
|
||||
bool supports_tools_ = true;
|
||||
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
|
||||
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
|
||||
bool requires_object_arguments_ = false;
|
||||
bool requires_typed_content_ = false;
|
||||
bool supports_system_role_ = true;
|
||||
bool supports_parallel_tool_calls_ = false;
|
||||
chat_template_caps caps_;
|
||||
std::string source_;
|
||||
std::string bos_token_;
|
||||
std::string eos_token_;
|
||||
std::shared_ptr<minja::TemplateNode> template_root_;
|
||||
std::string tool_call_example_;
|
||||
|
||||
std::string try_raw_render(
|
||||
const nlohmann::ordered_json & messages,
|
||||
@@ -40,16 +73,28 @@ class chat_template {
|
||||
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
|
||||
{
|
||||
try {
|
||||
auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false);
|
||||
// fprintf(stderr, "Prompt: %s\n", prompt.c_str());
|
||||
chat_template_inputs inputs;
|
||||
inputs.messages = messages;
|
||||
inputs.tools = tools;
|
||||
inputs.add_generation_prompt = add_generation_prompt;
|
||||
inputs.extra_context = extra_context;
|
||||
// Use fixed date for tests
|
||||
inputs.now = std::chrono::system_clock::from_time_t(0);
|
||||
|
||||
chat_template_options opts;
|
||||
opts.apply_polyfills = false;
|
||||
|
||||
auto prompt = apply(inputs, opts);
|
||||
// fprintf(stderr, "try_raw_render: %s\n", prompt.c_str());
|
||||
return prompt;
|
||||
} catch (const std::exception & e) {
|
||||
// fprintf(stderr, "Error: %s\n", e.what());
|
||||
// fprintf(stderr, "try_raw_render error: %s\n", e.what());
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
|
||||
: source_(source), bos_token_(bos_token), eos_token_(eos_token)
|
||||
{
|
||||
@@ -58,86 +103,239 @@ class chat_template {
|
||||
/* .lstrip_blocks = */ true,
|
||||
/* .keep_trailing_newline = */ false,
|
||||
});
|
||||
supports_tools_ = source.find("tools") != std::string::npos;
|
||||
|
||||
auto renders_string_arguments =
|
||||
try_raw_render({
|
||||
auto contains = [](const std::string & haystack, const std::string & needle) {
|
||||
return haystack.find(needle) != std::string::npos;
|
||||
};
|
||||
|
||||
const std::string user_needle = "<User Needle>";
|
||||
const std::string sys_needle = "<System Needle>";
|
||||
const json dummy_str_user_msg = {{"role", "user"}, {"content", user_needle}};
|
||||
const json dummy_typed_user_msg = {{"role", "user"}, {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}};
|
||||
|
||||
caps_.requires_typed_content =
|
||||
!contains(try_raw_render(json::array({dummy_str_user_msg}), {}, false), user_needle)
|
||||
&& contains(try_raw_render(json::array({dummy_typed_user_msg}), {}, false), user_needle);
|
||||
|
||||
const auto dummy_user_msg = caps_.requires_typed_content
|
||||
? dummy_typed_user_msg
|
||||
: dummy_str_user_msg;
|
||||
const json needle_system_msg = {
|
||||
{"role", "system"},
|
||||
{"content", caps_.requires_typed_content ? json::array({{{"type", "text"}, {"text", sys_needle}}}) : json(sys_needle)},
|
||||
};
|
||||
|
||||
caps_.supports_system_role = contains(try_raw_render({needle_system_msg, dummy_user_msg,}, {}, false), sys_needle);
|
||||
|
||||
auto out = try_raw_render(json::array({
|
||||
dummy_user_msg
|
||||
}), json::array({
|
||||
{
|
||||
{"name", "some_tool"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "some_tool"},
|
||||
{"description", "Some tool."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"arg", {
|
||||
{"type", "string"},
|
||||
{"description", "Some argument."},
|
||||
}},
|
||||
}},
|
||||
{"required", json::array({ "arg" })},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
}), false);
|
||||
caps_.supports_tools = contains(out, "some_tool");
|
||||
|
||||
auto make_tool_calls_msg = [&](const json & tool_calls) {
|
||||
return json {
|
||||
{"role", "assistant"},
|
||||
{"content", nullptr},
|
||||
{"tool_calls", tool_calls},
|
||||
};
|
||||
};
|
||||
auto make_tool_call = [](const std::string & tool_name, const json & arguments) {
|
||||
return json {
|
||||
{"id", "call_1___"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"arguments", arguments},
|
||||
{"name", tool_name},
|
||||
}},
|
||||
};
|
||||
};
|
||||
const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}};
|
||||
|
||||
// Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
|
||||
out = try_raw_render(json::array({
|
||||
dummy_user_msg,
|
||||
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
|
||||
}), {}, false);
|
||||
auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
|
||||
out = try_raw_render(json::array({
|
||||
dummy_user_msg,
|
||||
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
|
||||
}), {}, false);
|
||||
auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
|
||||
|
||||
caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
|
||||
caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
|
||||
auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false);
|
||||
auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false);
|
||||
caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle);
|
||||
|
||||
if (caps_.supports_tool_calls) {
|
||||
auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump());
|
||||
auto tc1 = make_tool_call("test_tool1", dummy_args);
|
||||
auto tc2 = make_tool_call("test_tool2", dummy_args);
|
||||
auto out = try_raw_render(json::array({
|
||||
dummy_user_msg,
|
||||
make_tool_calls_msg(json::array({tc1, tc2})),
|
||||
}), {}, false);
|
||||
caps_.supports_parallel_tool_calls = contains(out, "test_tool1") && contains(out, "test_tool2");
|
||||
|
||||
out = try_raw_render(json::array({
|
||||
dummy_user_msg,
|
||||
make_tool_calls_msg(json::array({tc1})),
|
||||
{
|
||||
{"role", "tool"},
|
||||
{"name", "test_tool1"},
|
||||
{"content", "Some response!"},
|
||||
{"tool_call_id", "call_911_"},
|
||||
}
|
||||
}), {}, false);
|
||||
caps_.supports_tool_responses = contains(out, "Some response!");
|
||||
caps_.supports_tool_call_id = contains(out, "call_911_");
|
||||
}
|
||||
|
||||
try {
|
||||
if (!caps_.supports_tools) {
|
||||
const json user_msg {
|
||||
{"role", "user"},
|
||||
{"content", "Hey"}
|
||||
},
|
||||
{
|
||||
{"content", "Hey"},
|
||||
};
|
||||
const json args {
|
||||
{"arg1", "some_value"},
|
||||
};
|
||||
const json tool_call_msg {
|
||||
{"role", "assistant"},
|
||||
{"content", nullptr},
|
||||
{"tool_calls", json::array({
|
||||
{
|
||||
// TODO: detect if requires numerical id or fixed length == 6 like Nemo
|
||||
{"id", "call_1___"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"arguments", "{\"code\": \"print('Hello, World!')\"}"},
|
||||
{"name", "ipython"},
|
||||
{"name", "tool_name"},
|
||||
{"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))},
|
||||
}},
|
||||
},
|
||||
})},
|
||||
};
|
||||
std::string prefix, full;
|
||||
{
|
||||
chat_template_inputs inputs;
|
||||
inputs.messages = json::array({user_msg});
|
||||
inputs.add_generation_prompt = true;
|
||||
prefix = apply(inputs);
|
||||
}
|
||||
}, {}, false).find("{\"code\": \"print") != std::string::npos;
|
||||
if (!renders_string_arguments) {
|
||||
auto renders_object_arguments =
|
||||
try_raw_render({
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "Hey"}
|
||||
},
|
||||
{
|
||||
{"role", "assistant"},
|
||||
{"tool_calls", json::array({
|
||||
{
|
||||
{"id", "call_1___"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"arguments", {
|
||||
{"code", "print('Hello, World!')"},
|
||||
}},
|
||||
{"name", "ipython"},
|
||||
}},
|
||||
},
|
||||
})},
|
||||
{
|
||||
chat_template_inputs inputs;
|
||||
inputs.messages = json::array({user_msg, tool_call_msg});
|
||||
inputs.add_generation_prompt = false;
|
||||
full = apply(inputs);
|
||||
}
|
||||
|
||||
if (full.find(prefix) != 0) {
|
||||
if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) {
|
||||
prefix = prefix.substr(0, prefix.size() - eos_token_.size());
|
||||
}
|
||||
}, {}, false).find("{\"code\": \"print") != std::string::npos;
|
||||
requires_object_arguments_ = renders_object_arguments;
|
||||
}
|
||||
if (full.find(prefix) != 0) {
|
||||
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
|
||||
}
|
||||
tool_call_example_ = full.substr(prefix.size());
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
|
||||
}
|
||||
supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;
|
||||
|
||||
supports_system_role_ = try_raw_render({
|
||||
{{"role", "system"}, {"content", "<System Needle>"}},
|
||||
{{"role", "user"}, {"content", "Hey"}}
|
||||
}, {}, false).find("<System Needle>") != std::string::npos;
|
||||
|
||||
requires_typed_content_ = try_raw_render({{{"role", "user"}, {"content", "Hey"}}}, {}, false).find("Hey") == std::string::npos
|
||||
&& try_raw_render({{{"role", "user"}, {"content", {{{"type", "text"}, {"text", "Hey"}}}}}}, {}, false).find("Hey") != std::string::npos;
|
||||
}
|
||||
|
||||
const std::string & source() const { return source_; }
|
||||
const std::string & bos_token() const { return bos_token_; }
|
||||
const std::string & eos_token() const { return eos_token_; }
|
||||
bool supports_tools() const { return supports_tools_; }
|
||||
bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }
|
||||
const chat_template_caps & original_caps() const { return caps_; }
|
||||
|
||||
// Deprecated, please use the form with chat_template_inputs and chat_template_options
|
||||
std::string apply(
|
||||
const nlohmann::ordered_json & messages,
|
||||
const nlohmann::ordered_json & tools,
|
||||
bool add_generation_prompt,
|
||||
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
|
||||
bool adjust_inputs = true) const
|
||||
bool apply_polyfills = true)
|
||||
{
|
||||
fprintf(stderr, "[%s] Deprecated!\n", __func__);
|
||||
chat_template_inputs inputs;
|
||||
inputs.messages = messages;
|
||||
inputs.tools = tools;
|
||||
inputs.add_generation_prompt = add_generation_prompt;
|
||||
inputs.extra_context = extra_context;
|
||||
inputs.now = std::chrono::system_clock::now();
|
||||
|
||||
chat_template_options opts;
|
||||
opts.apply_polyfills = apply_polyfills;
|
||||
|
||||
return apply(inputs, opts);
|
||||
}
|
||||
|
||||
std::string apply(
|
||||
const chat_template_inputs & inputs,
|
||||
const chat_template_options & opts = chat_template_options()) const
|
||||
{
|
||||
json actual_messages;
|
||||
|
||||
// First, "fix" messages so they have a chance to be rendered correctly by the template
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_tool_calls = false;
|
||||
auto has_tool_responses = false;
|
||||
auto has_string_content = false;
|
||||
for (const auto & message : inputs.messages) {
|
||||
if (message.contains("tool_calls") && !message["tool_calls"].is_null()) {
|
||||
has_tool_calls = true;
|
||||
}
|
||||
if (message.contains("role") && message["role"] == "tool") {
|
||||
has_tool_responses = true;
|
||||
}
|
||||
if (message.contains("content") && message["content"].is_string()) {
|
||||
has_string_content = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || requires_typed_content_)) {
|
||||
auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role;
|
||||
auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools;
|
||||
auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples;
|
||||
auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls;
|
||||
auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses;
|
||||
auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments;
|
||||
auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content;
|
||||
|
||||
auto needs_polyfills = opts.apply_polyfills && (false
|
||||
|| polyfill_system_role
|
||||
|| polyfill_tools
|
||||
|| polyfill_tool_calls
|
||||
|| polyfill_tool_responses
|
||||
|| polyfill_object_arguments
|
||||
|| polyfill_typed_content
|
||||
);
|
||||
|
||||
if (needs_polyfills) {
|
||||
actual_messages = json::array();
|
||||
|
||||
auto add_message = [&](const json & msg) {
|
||||
if (requires_typed_content_ && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
|
||||
if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
|
||||
actual_messages.push_back({
|
||||
{"role", msg.at("role")},
|
||||
{"content", {{
|
||||
@@ -160,7 +358,17 @@ class chat_template {
|
||||
pending_system.clear();
|
||||
}
|
||||
};
|
||||
for (const auto & message_ : messages) {
|
||||
|
||||
json adjusted_messages;
|
||||
if (polyfill_tools) {
|
||||
adjusted_messages = add_system(inputs.messages,
|
||||
"You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
|
||||
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_));
|
||||
} else {
|
||||
adjusted_messages = inputs.messages;
|
||||
}
|
||||
|
||||
for (const auto & message_ : adjusted_messages) {
|
||||
auto message = message_;
|
||||
if (!message.contains("role") || !message.contains("content")) {
|
||||
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
|
||||
@@ -168,16 +376,22 @@ class chat_template {
|
||||
std::string role = message.at("role");
|
||||
|
||||
if (message.contains("tool_calls")) {
|
||||
if (requires_object_arguments_ || !supports_tools_) {
|
||||
if (polyfill_object_arguments || polyfill_tool_calls) {
|
||||
for (auto & tool_call : message.at("tool_calls")) {
|
||||
if (tool_call["type"] == "function") {
|
||||
auto & function = tool_call.at("function");
|
||||
std::string arguments = function.at("arguments");
|
||||
function["arguments"] = json::parse(arguments);
|
||||
auto & arguments = function.at("arguments");
|
||||
if (arguments.is_string()) {
|
||||
try {
|
||||
arguments = json::parse(arguments.get<std::string>());
|
||||
} catch (const std::exception & ecvt) {
|
||||
fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!supports_tools_) {
|
||||
if (polyfill_tool_calls) {
|
||||
auto content = message.at("content");
|
||||
auto tool_calls = json::array();
|
||||
for (const auto & tool_call : message.at("tool_calls")) {
|
||||
@@ -204,14 +418,16 @@ class chat_template {
|
||||
message.erase("tool_calls");
|
||||
}
|
||||
}
|
||||
if (!supports_tools_ && role == "tool") {
|
||||
if (polyfill_tool_responses && role == "tool") {
|
||||
message["role"] = "user";
|
||||
auto obj = json {
|
||||
{"tool_response", {
|
||||
{"tool", message.at("name")},
|
||||
{"content", message.at("content")},
|
||||
}},
|
||||
};
|
||||
if (message.contains("name")) {
|
||||
obj["tool_response"]["name"] = message.at("name");
|
||||
}
|
||||
if (message.contains("tool_call_id")) {
|
||||
obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
|
||||
}
|
||||
@@ -219,7 +435,7 @@ class chat_template {
|
||||
message.erase("name");
|
||||
}
|
||||
|
||||
if (!message["content"].is_null() && !supports_system_role_) {
|
||||
if (!message["content"].is_null() && polyfill_system_role) {
|
||||
std::string content = message.at("content");
|
||||
if (role == "system") {
|
||||
if (!pending_system.empty()) pending_system += "\n";
|
||||
@@ -240,28 +456,59 @@ class chat_template {
|
||||
}
|
||||
flush_sys();
|
||||
} else {
|
||||
actual_messages = messages;
|
||||
actual_messages = inputs.messages;
|
||||
}
|
||||
|
||||
auto context = minja::Context::make(json({
|
||||
{"messages", actual_messages},
|
||||
{"add_generation_prompt", add_generation_prompt},
|
||||
{"bos_token", bos_token_},
|
||||
{"eos_token", eos_token_},
|
||||
{"add_generation_prompt", inputs.add_generation_prompt},
|
||||
}));
|
||||
context->set("bos_token", opts.use_bos_token ? bos_token_ : "");
|
||||
context->set("eos_token", opts.use_eos_token ? eos_token_ : "");
|
||||
if (opts.define_strftime_now) {
|
||||
auto now = inputs.now;
|
||||
context->set("strftime_now", Value::callable([now](const std::shared_ptr<minja::Context> &, minja::ArgumentsValue & args) {
|
||||
args.expectArgs("strftime_now", {1, 1}, {0, 0});
|
||||
auto format = args.args[0].get<std::string>();
|
||||
|
||||
if (!tools.is_null()) {
|
||||
auto tools_val = minja::Value(tools);
|
||||
context->set("tools", tools_val);
|
||||
auto time = std::chrono::system_clock::to_time_t(now);
|
||||
auto local_time = *std::localtime(&time);
|
||||
std::ostringstream ss;
|
||||
ss << std::put_time(&local_time, format.c_str());
|
||||
return ss.str();
|
||||
}));
|
||||
}
|
||||
if (!extra_context.is_null()) {
|
||||
for (auto & kv : extra_context.items()) {
|
||||
minja::Value val(kv.value());
|
||||
context->set(kv.key(), val);
|
||||
if (!inputs.tools.is_null()) {
|
||||
context->set("tools", minja::Value(inputs.tools));
|
||||
}
|
||||
if (!inputs.extra_context.is_null()) {
|
||||
for (auto & kv : inputs.extra_context.items()) {
|
||||
context->set(kv.key(), minja::Value(kv.value()));
|
||||
}
|
||||
}
|
||||
|
||||
return template_root_->render(context);
|
||||
auto ret = template_root_->render(context);
|
||||
// fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str());
|
||||
// fprintf(stderr, "apply: %s\n\n", ret.c_str());
|
||||
return ret;
|
||||
}
|
||||
|
||||
static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
|
||||
json messages_with_system = messages;
|
||||
|
||||
if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
|
||||
std::string existing_system = messages_with_system.at(0).at("content");
|
||||
messages_with_system[0] = json {
|
||||
{"role", "system"},
|
||||
{"content", existing_system + "\n\n" + system_prompt},
|
||||
};
|
||||
} else {
|
||||
messages_with_system.insert(messages_with_system.begin(), json {
|
||||
{"role", "system"},
|
||||
{"content", system_prompt},
|
||||
});
|
||||
}
|
||||
return messages_with_system;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
966
common/chat.cpp
Normal file
966
common/chat.cpp
Normal file
@@ -0,0 +1,966 @@
|
||||
#include "chat.hpp"
|
||||
#include "chat-template.hpp"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "minja.hpp"
|
||||
|
||||
std::string common_chat_format_name(common_chat_format format) {
|
||||
switch (format) {
|
||||
case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only";
|
||||
case COMMON_CHAT_FORMAT_GENERIC: return "Generic";
|
||||
case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo";
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
|
||||
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
|
||||
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
|
||||
default:
|
||||
throw std::runtime_error("Unknown chat format");
|
||||
}
|
||||
}
|
||||
|
||||
const common_grammar_options grammar_options {
|
||||
/* .dotall = */ false,
|
||||
/* .compact_spaces = */ false,
|
||||
// /* .compact_spaces = */ true,
|
||||
};
|
||||
|
||||
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
|
||||
// // https://json.nlohmann.me/features/parsing/sax_interface/
|
||||
struct json_error_locator : public nlohmann::json_sax<json> {
|
||||
std::size_t position;
|
||||
bool found_error;
|
||||
|
||||
json_error_locator() : position(0), found_error(false) {}
|
||||
|
||||
bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
|
||||
this->position = position - 1;
|
||||
this->found_error = true;
|
||||
return false;
|
||||
}
|
||||
bool null() override { return true; }
|
||||
bool boolean(bool) override { return true; }
|
||||
bool number_integer(number_integer_t) override { return true; }
|
||||
bool number_unsigned(number_unsigned_t) override { return true; }
|
||||
bool number_float(number_float_t, const string_t &) override { return true; }
|
||||
bool string(string_t &) override { return true; }
|
||||
bool binary(binary_t &) override { return true; }
|
||||
bool start_object(std::size_t) override { return true; }
|
||||
bool key(string_t &) override { return true; }
|
||||
bool end_object() override { return true; }
|
||||
bool start_array(std::size_t) override { return true; }
|
||||
bool end_array() override { return true; }
|
||||
};
|
||||
json_error_locator err_loc;
|
||||
json::sax_parse(it, end, &err_loc);
|
||||
|
||||
std::string::const_iterator temptative_end;
|
||||
if (err_loc.found_error) {
|
||||
temptative_end = it + err_loc.position;
|
||||
} else {
|
||||
temptative_end = end;
|
||||
}
|
||||
std::string json_sub {it, temptative_end};
|
||||
try {
|
||||
out = json::parse(json_sub);
|
||||
it = temptative_end;
|
||||
return true;
|
||||
} catch (const std::exception &) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
|
||||
* Aggregates the prefix, suffix and in-between text into the content.
|
||||
*/
|
||||
static common_chat_msg parse_json_tool_calls(
|
||||
const std::string& input,
|
||||
const std::optional<std::regex> & trigger_opt,
|
||||
const std::regex & function_regex,
|
||||
const std::regex & close_regex) {
|
||||
std::smatch match;
|
||||
|
||||
common_chat_msg result;
|
||||
result.role = "assistant";
|
||||
|
||||
|
||||
auto end = input.end();
|
||||
auto it = input.begin();
|
||||
|
||||
if (trigger_opt) {
|
||||
if (!std::regex_search(it, end, match, *trigger_opt)) {
|
||||
result.content = input;
|
||||
return result;
|
||||
}
|
||||
result.content = match.prefix().str();
|
||||
it = match.suffix().first;
|
||||
}
|
||||
|
||||
while (it != end) {
|
||||
std::sregex_iterator rend;
|
||||
std::sregex_iterator rit(it, end, function_regex);
|
||||
if (rit == rend) {
|
||||
fprintf(stderr, "No more tool calls found\n");
|
||||
result.content += std::string(it, end);
|
||||
break;
|
||||
}
|
||||
auto name = rit->str(1);
|
||||
result.content += std::string(it, rit->prefix().second);
|
||||
it = rit->suffix().first;
|
||||
|
||||
json arguments;
|
||||
if (!parse_json(it, end, arguments)) {
|
||||
throw std::runtime_error("Failed to parse json tool call arguments");
|
||||
}
|
||||
if (!std::regex_search(it, end, match, close_regex)) {
|
||||
throw std::runtime_error("Malformed input, missing closing pattern");
|
||||
}
|
||||
it = match.suffix().first;
|
||||
result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
|
||||
auto content_end = input.find(prefix);
|
||||
size_t tc_start = std::string::npos;
|
||||
|
||||
common_chat_msg result;
|
||||
result.role = "assistant";
|
||||
const auto process_tool_calls = [&](const json & tool_calls) {
|
||||
for (const auto & tool_call : tool_calls) {
|
||||
const auto & arguments = tool_call["arguments"];
|
||||
result.tool_calls.push_back({
|
||||
tool_call["name"],
|
||||
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
tool_call.contains("id") ? tool_call["id"] : "",
|
||||
});
|
||||
}
|
||||
};
|
||||
if (content_end == std::string::npos) {
|
||||
result.content = input;
|
||||
} else {
|
||||
tc_start = content_end + prefix.size() - rstrip_prefix;
|
||||
result.content = input.substr(0, content_end);
|
||||
auto tool_calls = json::parse(input.substr(tc_start));
|
||||
process_tool_calls(tool_calls);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
|
||||
for (const auto & tool : tools) {
|
||||
if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
|
||||
LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str());
|
||||
continue;
|
||||
}
|
||||
fn(tool);
|
||||
}
|
||||
}
|
||||
|
||||
static std::string apply(
|
||||
const common_chat_template & tmpl,
|
||||
const nlohmann::ordered_json & messages,
|
||||
const nlohmann::ordered_json & tools,
|
||||
bool add_generation_prompt,
|
||||
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json())
|
||||
{
|
||||
minja::chat_template_inputs tmpl_inputs;
|
||||
tmpl_inputs.messages = messages;
|
||||
tmpl_inputs.tools = tools;
|
||||
tmpl_inputs.add_generation_prompt = add_generation_prompt;
|
||||
tmpl_inputs.extra_context = extra_context;
|
||||
// TODO: add flag to control date/time, if only for testing purposes.
|
||||
// tmpl_inputs.now = std::chrono::system_clock::now();
|
||||
|
||||
minja::chat_template_options tmpl_opts;
|
||||
tmpl_opts.use_bos_token = false;
|
||||
tmpl_opts.use_eos_token = false;
|
||||
|
||||
return tmpl.apply(tmpl_inputs, tmpl_opts);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
auto tool_call_schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
auto tool_schema = json {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"name", {
|
||||
{"type", "string"},
|
||||
{"const", function["name"]},
|
||||
}},
|
||||
{"arguments", function["parameters"]},
|
||||
}},
|
||||
{"required", json::array({"name", "arguments"})},
|
||||
};
|
||||
if (function.contains("description")) {
|
||||
tool_schema["description"] = function["description"];
|
||||
}
|
||||
if (inputs.parallel_tool_calls) {
|
||||
tool_schema["properties"]["id"] = {
|
||||
{"type", "string"},
|
||||
{"minLength", 4},
|
||||
};
|
||||
tool_schema["required"].push_back("id");
|
||||
}
|
||||
tool_call_schemas.emplace_back(tool_schema);
|
||||
});
|
||||
const auto tool_call =
|
||||
inputs.parallel_tool_calls
|
||||
? json {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"tool_calls", {
|
||||
{"type", "array"},
|
||||
{"items", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
|
||||
{"anyOf", tool_call_schemas},
|
||||
}},
|
||||
{"minItems", 1},
|
||||
}},
|
||||
}},
|
||||
{"required", json::array({"tool_calls"})},
|
||||
}
|
||||
: json {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"tool_call", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
|
||||
{"anyOf", tool_call_schemas},
|
||||
}},
|
||||
}},
|
||||
{"required", json::array({"tool_call"})},
|
||||
};
|
||||
const auto schema =
|
||||
inputs.tool_choice != "required"
|
||||
? json {
|
||||
{"anyOf", json::array({
|
||||
tool_call,
|
||||
{
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"response", inputs.json_schema.is_null()
|
||||
? json {{"type", "string"}}
|
||||
: inputs.json_schema
|
||||
},
|
||||
}},
|
||||
{"required", json::array({"response"})},
|
||||
},
|
||||
})}
|
||||
}
|
||||
: tool_call;
|
||||
|
||||
data.grammar_lazy = false;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
builder.add_schema("root", schema);
|
||||
}, grammar_options);
|
||||
|
||||
auto tweaked_messages = common_chat_template::add_system(
|
||||
inputs.messages,
|
||||
"Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");
|
||||
|
||||
data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_GENERIC;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_generic(const std::string & input) {
|
||||
json data = json::parse(input);
|
||||
common_chat_msg result;
|
||||
result.role = "assistant";
|
||||
if (data.contains("tool_calls")) {
|
||||
for (const auto & tool_call : data["tool_calls"]) {
|
||||
result.tool_calls.push_back({
|
||||
tool_call["name"],
|
||||
tool_call["arguments"].dump(),
|
||||
tool_call.contains("id") ? tool_call["id"] : "",
|
||||
});
|
||||
}
|
||||
} else if (data.contains("tool_call")) {
|
||||
result.tool_calls.push_back({
|
||||
data["tool_call"]["name"],
|
||||
data["tool_call"]["arguments"].dump(),
|
||||
/* id= */ "",
|
||||
});
|
||||
} else if (data.contains("response")) {
|
||||
const auto & response = data["response"];
|
||||
result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
schemas.push_back({
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
// Important note: the model is probably trained to take a JSON stringified arguments value.
|
||||
// It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
|
||||
{"name", {
|
||||
{"type", "string"},
|
||||
{"const", function["name"]},
|
||||
}},
|
||||
{"arguments", function["parameters"]},
|
||||
{"id", {
|
||||
{"type", "string"},
|
||||
// Nemo's template expects a 9-character alphanumeric ID.
|
||||
{"pattern", "^[a-zA-Z0-9]{9}$"},
|
||||
}},
|
||||
}},
|
||||
{"required", json::array({"name", "arguments", "id"})},
|
||||
});
|
||||
});
|
||||
auto schema = json {
|
||||
{"type", "array"},
|
||||
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
||||
{"minItems", 1},
|
||||
};
|
||||
if (!inputs.parallel_tool_calls) {
|
||||
schema["maxItems"] = 1;
|
||||
}
|
||||
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
|
||||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input) {
|
||||
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
schemas.push_back({
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"tool_call_id", {
|
||||
{"type", "string"},
|
||||
// Command-R's template expects an integer string.
|
||||
{"pattern", "^[0-9]{1,10}$"},
|
||||
}},
|
||||
{"tool_name", {
|
||||
{"type", "string"},
|
||||
{"const", function["name"]},
|
||||
}},
|
||||
{"parameters", function["parameters"]},
|
||||
}},
|
||||
{"required", json::array({"tool_call_id", "tool_name", "parameters"})},
|
||||
});
|
||||
});
|
||||
auto schema = json {
|
||||
{"type", "array"},
|
||||
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
||||
{"minItems", 1},
|
||||
};
|
||||
if (!inputs.parallel_tool_calls) {
|
||||
schema["maxItems"] = 1;
|
||||
}
|
||||
builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
|
||||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({"<|START_ACTION|>", /* .at_start = */ false});
|
||||
data.preserved_tokens = {
|
||||
"<|START_RESPONSE|>",
|
||||
"<|END_RESPONSE|>",
|
||||
"<|START_THINKING|>",
|
||||
"<|END_THINKING|>",
|
||||
"<|END_ACTION|>",
|
||||
};
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_command_r7b(const std::string & input) {
|
||||
static std::regex response_regex("<\\|START_RESPONSE\\|>([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
|
||||
static std::regex thought_action_regex("<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|><\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
|
||||
std::smatch match;
|
||||
|
||||
common_chat_msg result;
|
||||
result.role = "assistant";
|
||||
if (std::regex_match(input, match, response_regex)) {
|
||||
result.content = match[1].str();
|
||||
} else if (std::regex_match(input, match, thought_action_regex)) {
|
||||
result.tool_plan = match[1].str();
|
||||
auto actions_str = match[2].str();
|
||||
auto actions = json::parse(actions_str);
|
||||
for (const auto & action : actions) {
|
||||
result.tool_calls.push_back({
|
||||
/* .name = */ action["tool_name"],
|
||||
/* .arguments = */ action["parameters"].dump(),
|
||||
/* .id = */ action["tool_call_id"],
|
||||
});
|
||||
}
|
||||
} else {
|
||||
LOG_ERR("Failed to parse command_r output");
|
||||
result.content = input;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
|
||||
if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
|
||||
throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
|
||||
}
|
||||
const auto & parameters_properties = parameters.at("properties");
|
||||
const auto & parameters_required = parameters.at("required");
|
||||
for (const auto & prop : expected_properties) {
|
||||
if (!parameters_properties.contains(prop)) {
|
||||
throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
|
||||
}
|
||||
if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
|
||||
throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
|
||||
}
|
||||
}
|
||||
if (parameters_properties.size() != expected_properties.size()) {
|
||||
throw std::runtime_error("Parameters of tool " + name + " must only have these properties:" + string_join(expected_properties, ", "));
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) {
|
||||
auto builtin_tools = json::array();
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
|
||||
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
|
||||
if (name == "wolfram_alpha") {
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
|
||||
expect_tool_parameters(name, parameters, {"query"});
|
||||
} else if (name == "web_search" || name == "brave_search") {
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
|
||||
expect_tool_parameters(name, parameters, {"query"});
|
||||
} else if (name == "python" || name == "code_interpreter") {
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
|
||||
expect_tool_parameters(name, parameters, {"code"});
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::string> kvs;
|
||||
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
|
||||
}
|
||||
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
|
||||
builtin_tools.push_back(name);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
builder.resolve_refs(parameters);
|
||||
|
||||
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
|
||||
if (allow_python_tag_builtin_tools) {
|
||||
handle_builtin_tool(name, parameters);
|
||||
}
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"{\" space "
|
||||
"( \"\\\"type\\\":\" space \"\\\"function\\\",\" space )? "
|
||||
"\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
|
||||
builder.add_schema(name + "-args", parameters) +
|
||||
" \"}\""));
|
||||
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
|
||||
});
|
||||
data.grammar_triggers.push_back({"{\"name\":", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"{\n \"name\":", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"{\n \"name\":", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"{\"type\": \"function\"", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"{\n \"type\": \"function\"", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"{\n \"type\": \"function\"", /* .at_start = */ true});
|
||||
if (!builtin_tools.empty()) {
|
||||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||
}
|
||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||
}, grammar_options);
|
||||
data.additional_stops.push_back("<|eom_id|>");
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
|
||||
{"tools_in_user_message", false},
|
||||
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
|
||||
});
|
||||
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
|
||||
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
|
||||
: COMMON_CHAT_FORMAT_LLAMA_3_X;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
|
||||
// TODO: tighten & simplify the parser, don't accept leading text context.
|
||||
static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
|
||||
static std::regex close_regex("\\}");
|
||||
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
|
||||
|
||||
if (with_builtin_tools) {
|
||||
std::smatch match;
|
||||
if (std::regex_match(input, match, builtin_call_regex)) {
|
||||
auto name = match[1].str();
|
||||
auto raw_args = match[2].str();
|
||||
|
||||
// TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
|
||||
auto it_eq = raw_args.find('=');
|
||||
auto arg_name = raw_args.substr(0, it_eq);
|
||||
auto arg_value_str = raw_args.substr(it_eq + 1);
|
||||
auto arg_value = json::parse(arg_value_str);
|
||||
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ match.prefix().str(),
|
||||
/* .tool_calls = */ {
|
||||
{
|
||||
/* .name = */ match[1],
|
||||
/* .arguments = */ (json {
|
||||
{arg_name, arg_value},
|
||||
}).dump(),
|
||||
/* .id = */ "",
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
tool_rules.push_back(builder.add_rule(name + "-call",
|
||||
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
|
||||
});
|
||||
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
|
||||
data.preserved_tokens = {
|
||||
"<|tool▁sep|>",
|
||||
"<|tool▁call▁end|>",
|
||||
};
|
||||
builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
|
||||
}, grammar_options);
|
||||
auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.prompt = prompt;
|
||||
data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) {
|
||||
static std::regex trigger_regex("<|tool▁calls▁begin|>");
|
||||
static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
|
||||
static std::regex close_regex("```<|tool▁call▁end|>");
|
||||
return parse_json_tool_calls(input, trigger_regex, function_regex, close_regex);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
fprintf(stderr, "%s\n", __func__);
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
|
||||
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
||||
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||
});
|
||||
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
schemas.push_back({
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"name", {
|
||||
{"type", "string"},
|
||||
{"const", function["name"]},
|
||||
}},
|
||||
{"arguments", function["parameters"]},
|
||||
}},
|
||||
{"required", json::array({"name", "arguments", "id"})},
|
||||
});
|
||||
});
|
||||
auto schema = json {
|
||||
{"type", "array"},
|
||||
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
||||
{"minItems", 1},
|
||||
};
|
||||
if (!inputs.parallel_tool_calls) {
|
||||
schema["maxItems"] = 1;
|
||||
}
|
||||
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
|
||||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
|
||||
data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2;
|
||||
} else {
|
||||
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
}
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_firefunction_v2(const std::string & input) {
|
||||
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
||||
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
|
||||
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> first_tool_rules;
|
||||
std::vector<std::string> subsequent_tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
|
||||
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
|
||||
data.grammar_triggers.push_back({name, /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
|
||||
});
|
||||
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
||||
if (inputs.parallel_tool_calls) {
|
||||
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
|
||||
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
|
||||
} else {
|
||||
builder.add_rule("root", first_rule);
|
||||
}
|
||||
|
||||
}, grammar_options);
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
static bool consume(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
|
||||
auto expected_it = expected.begin();
|
||||
auto tmp_it = it;
|
||||
while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
|
||||
++tmp_it;
|
||||
++expected_it;
|
||||
}
|
||||
if (expected_it == expected.end()) {
|
||||
it = tmp_it;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) {
|
||||
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
||||
static std::regex close_regex(R"($|(?=>>>))");
|
||||
|
||||
std::string content;
|
||||
auto it = input.begin();
|
||||
const auto end = input.end();
|
||||
|
||||
if (consume(it, end, "all\n")) {
|
||||
std::smatch match;
|
||||
if (std::regex_search(it, end, match, function_regex)) {
|
||||
auto fun_it = match.prefix().second;
|
||||
content = std::string(it, fun_it);
|
||||
it = fun_it;
|
||||
} else {
|
||||
common_chat_msg res;
|
||||
res.role = "assistant";
|
||||
res.content = std::string(it, end);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
// TODO: tighten & simplify.
|
||||
try {
|
||||
auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex);
|
||||
res.content = content + res.content;
|
||||
return res;
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("Failed to parse functionary v3.2 input: %s\n", e.what());
|
||||
common_chat_msg res;
|
||||
res.role = "assistant";
|
||||
res.content = input;
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||
common_chat_params data;
|
||||
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
|
||||
std::string python_code_argument_name;
|
||||
auto has_raw_python = false;
|
||||
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
const auto & parameters = function["parameters"];
|
||||
std::string name = function["name"];
|
||||
if (name == "python" || name == "ipython") {
|
||||
if (!parameters.contains("type")) {
|
||||
throw std::runtime_error("Missing type in python tool");
|
||||
}
|
||||
has_raw_python = true;
|
||||
auto type = parameters.at("type");
|
||||
if (type == "object") {
|
||||
auto properties = parameters.at("properties");
|
||||
for (auto it = properties.begin(); it != properties.end(); ++it) {
|
||||
if (it.value().at("type") == "string") {
|
||||
if (!python_code_argument_name.empty()) {
|
||||
throw std::runtime_error("Multiple string arguments found in python tool");
|
||||
}
|
||||
python_code_argument_name = it.key();
|
||||
}
|
||||
}
|
||||
if (python_code_argument_name.empty()) {
|
||||
throw std::runtime_error("No string argument found in python tool");
|
||||
}
|
||||
} else if (type != "string") {
|
||||
throw std::runtime_error("Invalid type in python tool: " + type.dump());
|
||||
}
|
||||
}
|
||||
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
|
||||
});
|
||||
if (has_raw_python) {
|
||||
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
||||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||
}
|
||||
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
|
||||
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
|
||||
}, grammar_options);
|
||||
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
// TODO: if (has_raw_python)
|
||||
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
|
||||
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
||||
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
||||
std::smatch match;
|
||||
if (std::regex_search(input, match, python_tag_regex)) {
|
||||
auto code = match[1].str();
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ match.prefix().str(),
|
||||
/* .tool_calls = */ {
|
||||
{
|
||||
/* .name = */ "python",
|
||||
/* .arguments = */ (json {{"code", code}}).dump(),
|
||||
/* .id = */ "",
|
||||
},
|
||||
}
|
||||
};
|
||||
}
|
||||
static std::regex function_regex(R"(<function=(\w+)>)");
|
||||
static std::regex close_regex(R"(</function>)");
|
||||
// TODO: tighten & simplify.
|
||||
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
builder.resolve_refs(parameters);
|
||||
tool_rules.push_back(builder.add_schema(name + "-call", {
|
||||
{"type", "object"},
|
||||
{"properties", json {
|
||||
{"name", json {{"const", name}}},
|
||||
{"arguments", parameters},
|
||||
}},
|
||||
{"required", json::array({"name", "arguments"})},
|
||||
}));
|
||||
});
|
||||
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
|
||||
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
|
||||
data.preserved_tokens = { "</tool_call>" };
|
||||
}, grammar_options);
|
||||
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) {
|
||||
try {
|
||||
std::regex start_pattern(R"([\n\s]*<tool_call>)");
|
||||
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
|
||||
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
|
||||
|
||||
auto end = input.end();
|
||||
std::sregex_iterator rend;
|
||||
std::sregex_iterator rit(input.begin(), end, start_pattern);
|
||||
if (rit == rend) {
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ input,
|
||||
/* .tool_calls = */ {},
|
||||
};
|
||||
}
|
||||
|
||||
common_chat_msg result;
|
||||
result.role = "assistant";
|
||||
result.content = rit->prefix();
|
||||
|
||||
auto it = rit->suffix().first;
|
||||
while (it != end) {
|
||||
json call;
|
||||
if (!parse_json(it, end, call)) {
|
||||
throw std::runtime_error("Failed to parse json tool call");
|
||||
}
|
||||
const auto & arguments = call["arguments"];
|
||||
result.tool_calls.push_back({
|
||||
call["name"],
|
||||
arguments.dump(),
|
||||
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
/* id= */ "",
|
||||
});
|
||||
rit = {it, end, middle_pattern};
|
||||
if (rit != rend) {
|
||||
it = rit->suffix().first;
|
||||
} else {
|
||||
rit = {it, end, end_pattern};
|
||||
if (rit == rend) {
|
||||
throw std::runtime_error("Malformed input, missing </tool_call>");
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
} catch (const std::exception & e) {
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ input,
|
||||
/* .tool_calls = */ {},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
data.grammar_lazy = false;
|
||||
if (!inputs.json_schema.is_null()) {
|
||||
if (!inputs.grammar.empty()) {
|
||||
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
||||
}
|
||||
data.grammar = json_schema_to_grammar(inputs.json_schema);
|
||||
} else {
|
||||
data.grammar = inputs.grammar.empty();
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
auto has_tools = !inputs.tools.is_null() && inputs.tool_choice != "none";
|
||||
LOG_DBG("[%s] has_tools=%s\n", __func__, has_tools ? "true" : "false");
|
||||
|
||||
if (has_tools && !inputs.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot specify grammar with tools");
|
||||
}
|
||||
|
||||
const auto & src = tmpl.source();
|
||||
if (src.find(">>>all") != std::string::npos) {
|
||||
// Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when
|
||||
return common_chat_params_init_functionary_v3_2(tmpl, inputs);
|
||||
}
|
||||
if (src.find(" functools[") != std::string::npos) {
|
||||
// Firefunction v2 requires datetime and functions in the context, even w/o tools.
|
||||
return common_chat_params_init_firefunction_v2(tmpl, inputs);
|
||||
}
|
||||
|
||||
if (!has_tools) {
|
||||
return common_chat_params_init_without_tools(tmpl, inputs);
|
||||
}
|
||||
|
||||
if (src.find("<tool_call>") != std::string::npos) {
|
||||
return common_chat_params_init_hermes_2_pro(tmpl, inputs);
|
||||
}
|
||||
if (src.find("<|start_header_id|>") != std::string::npos
|
||||
&& src.find("<function=") != std::string::npos) {
|
||||
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
|
||||
}
|
||||
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
||||
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
||||
return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
|
||||
}
|
||||
if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
|
||||
return common_chat_params_init_deepseek_r1(tmpl, inputs);
|
||||
}
|
||||
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
||||
return common_chat_params_init_mistral_nemo(tmpl, inputs);
|
||||
}
|
||||
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos) {
|
||||
return common_chat_params_init_command_r7b(tmpl, inputs);
|
||||
}
|
||||
return common_chat_params_init_generic(tmpl, inputs);
|
||||
}
|
||||
|
||||
static common_chat_msg common_chat_parse_content_only(const std::string & input) {
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ input,
|
||||
/* .tool_calls = */ {},
|
||||
};
|
||||
}
|
||||
|
||||
common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
|
||||
switch (format) {
|
||||
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
|
||||
return common_chat_parse_content_only(input);
|
||||
case COMMON_CHAT_FORMAT_GENERIC:
|
||||
return common_chat_parse_generic(input);
|
||||
case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
|
||||
return common_chat_parse_mistral_nemo(input);
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X:
|
||||
return common_chat_parse_llama_3_1(input);
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
|
||||
return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true);
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
|
||||
return common_chat_parse_deepseek_r1(input);
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
|
||||
return common_chat_parse_functionary_v3_2(input);
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
|
||||
return common_chat_parse_functionary_v3_1_llama_3_1(input);
|
||||
case COMMON_CHAT_FORMAT_HERMES_2_PRO:
|
||||
return common_chat_parse_hermes_2_pro(input);
|
||||
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
|
||||
return common_chat_parse_firefunction_v2(input);
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B:
|
||||
return common_chat_parse_command_r7b(input);
|
||||
default:
|
||||
throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
|
||||
}
|
||||
}
|
||||
52
common/chat.hpp
Normal file
52
common/chat.hpp
Normal file
@@ -0,0 +1,52 @@
|
||||
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include <json.hpp>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
struct common_chat_inputs {
|
||||
json messages;
|
||||
json tools;
|
||||
json tool_choice;
|
||||
json json_schema;
|
||||
bool parallel_tool_calls;
|
||||
bool stream;
|
||||
std::string grammar;
|
||||
bool add_generation_prompt = true;
|
||||
};
|
||||
|
||||
enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
COMMON_CHAT_FORMAT_GENERIC,
|
||||
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X,
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
|
||||
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO,
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
||||
struct common_chat_params {
|
||||
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
json prompt;
|
||||
std::string grammar;
|
||||
bool grammar_lazy = false;
|
||||
std::vector<common_grammar_trigger> grammar_triggers;
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
};
|
||||
|
||||
struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);
|
||||
std::string common_chat_format_name(common_chat_format format);
|
||||
common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "json.hpp"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "llama.h"
|
||||
#include "chat.hpp"
|
||||
#include "chat-template.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
@@ -1774,11 +1775,13 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
|
||||
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
try {
|
||||
auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
|
||||
chat_template.apply({{
|
||||
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
|
||||
common_chat_inputs inputs;
|
||||
inputs.messages = json::array({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}}, json(), true);
|
||||
}});
|
||||
common_chat_params_init(chat_template, inputs);
|
||||
return true;
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
|
||||
@@ -1800,7 +1803,10 @@ std::string common_chat_apply_template(
|
||||
for (const auto & msg : msgs) {
|
||||
messages.push_back({{"role", msg.role}, {"content", msg.content}});
|
||||
}
|
||||
return tmpl.apply(messages, /* tools= */ json(), add_ass);
|
||||
common_chat_inputs inputs;
|
||||
inputs.messages = messages;
|
||||
inputs.add_generation_prompt = add_ass;
|
||||
return common_chat_params_init(tmpl, inputs).prompt;
|
||||
}
|
||||
|
||||
int alloc_size = 0;
|
||||
@@ -1855,19 +1861,27 @@ std::string common_chat_format_single(
|
||||
|
||||
std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
|
||||
std::vector<common_chat_msg> msgs = {
|
||||
{"system", "You are a helpful assistant"},
|
||||
{"user", "Hello"},
|
||||
{"assistant", "Hi there"},
|
||||
{"user", "How are you?"},
|
||||
{"system", "You are a helpful assistant", {}},
|
||||
{"user", "Hello", {}},
|
||||
{"assistant", "Hi there", {}},
|
||||
{"user", "How are you?", {}},
|
||||
};
|
||||
return common_chat_apply_template(tmpl, msgs, true, use_jinja);
|
||||
}
|
||||
|
||||
#define CHATML_TEMPLATE_SRC \
|
||||
"{%- for message in messages -%}\n" \
|
||||
" {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
|
||||
"{%- endfor -%}\n" \
|
||||
"{%- if add_generation_prompt -%}\n" \
|
||||
" {{- '<|im_start|>assistant\n' -}}\n" \
|
||||
"{%- endif -%}"
|
||||
|
||||
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
|
||||
{
|
||||
auto vocab = llama_model_get_vocab(model);
|
||||
std::string default_template_src = chat_template_override;
|
||||
std::string template_tool_use_src = chat_template_override;
|
||||
std::string default_template_src;
|
||||
std::string template_tool_use_src;
|
||||
|
||||
bool has_explicit_template = !chat_template_override.empty();
|
||||
if (chat_template_override.empty()) {
|
||||
auto str = llama_model_chat_template(model, /* name */ nullptr);
|
||||
@@ -1880,21 +1894,17 @@ common_chat_templates common_chat_templates_from_model(const struct llama_model
|
||||
template_tool_use_src = str;
|
||||
has_explicit_template = true;
|
||||
}
|
||||
} else {
|
||||
default_template_src = chat_template_override;
|
||||
}
|
||||
if (default_template_src.empty() || default_template_src == "chatml") {
|
||||
if (!template_tool_use_src.empty()) {
|
||||
default_template_src = template_tool_use_src;
|
||||
} else {
|
||||
default_template_src = R"(
|
||||
{%- for message in messages -%}
|
||||
{{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{- "<|im_start|>assistant\n" -}}
|
||||
{%- endif -%}
|
||||
)";
|
||||
default_template_src = CHATML_TEMPLATE_SRC;
|
||||
}
|
||||
}
|
||||
auto vocab = llama_model_get_vocab(model);
|
||||
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
|
||||
if (token == LLAMA_TOKEN_NULL) {
|
||||
if (default_template_src.find(jinja_variable_name) != std::string::npos
|
||||
@@ -1908,13 +1918,22 @@ common_chat_templates common_chat_templates_from_model(const struct llama_model
|
||||
};
|
||||
auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
|
||||
auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
|
||||
return {
|
||||
has_explicit_template,
|
||||
std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
|
||||
template_tool_use_src.empty()
|
||||
? nullptr
|
||||
: std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos)
|
||||
};
|
||||
try {
|
||||
return {
|
||||
has_explicit_template,
|
||||
std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
|
||||
template_tool_use_src.empty()
|
||||
? nullptr
|
||||
: std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos),
|
||||
};
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to parse chat template: %s\n", __func__, e.what());
|
||||
return {
|
||||
has_explicit_template,
|
||||
std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos),
|
||||
nullptr,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "llama-cpp.h"
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
@@ -109,6 +110,11 @@ enum common_conversation_mode {
|
||||
COMMON_CONVERSATION_MODE_AUTO = 2,
|
||||
};
|
||||
|
||||
struct common_grammar_trigger {
|
||||
std::string word;
|
||||
bool at_start;
|
||||
};
|
||||
|
||||
// sampling parameters
|
||||
struct common_params_sampling {
|
||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
||||
@@ -154,7 +160,11 @@ struct common_params_sampling {
|
||||
COMMON_SAMPLER_TYPE_TEMPERATURE,
|
||||
};
|
||||
|
||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||
bool grammar_lazy = false;
|
||||
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
|
||||
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
|
||||
std::set<llama_token> preserved_tokens;
|
||||
|
||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
|
||||
@@ -602,10 +612,18 @@ std::string common_detokenize(
|
||||
// Chat template utils
|
||||
//
|
||||
|
||||
struct common_tool_call {
|
||||
std::string name;
|
||||
std::string arguments;
|
||||
std::string id;
|
||||
};
|
||||
|
||||
// same with llama_chat_message, but uses std::string
|
||||
struct common_chat_msg {
|
||||
std::string role;
|
||||
std::string content;
|
||||
std::vector<common_tool_call> tool_calls;
|
||||
std::string tool_plan = "";
|
||||
};
|
||||
|
||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
|
||||
@@ -343,7 +343,7 @@ static std::string format_literal(const std::string & literal) {
|
||||
|
||||
class SchemaConverter {
|
||||
private:
|
||||
friend std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);
|
||||
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
|
||||
std::function<json(const std::string &)> _fetch_json;
|
||||
bool _dotall;
|
||||
std::map<std::string, std::string> _rules;
|
||||
@@ -764,10 +764,11 @@ private:
|
||||
public:
|
||||
SchemaConverter(
|
||||
const std::function<json(const std::string &)> & fetch_json,
|
||||
bool dotall)
|
||||
bool dotall,
|
||||
bool compact_spaces)
|
||||
: _fetch_json(fetch_json), _dotall(dotall)
|
||||
{
|
||||
_rules["space"] = SPACE_RULE;
|
||||
_rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE;
|
||||
}
|
||||
|
||||
void resolve_refs(json & schema, const std::string & url) {
|
||||
@@ -990,17 +991,24 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
std::string json_schema_to_grammar(const json & schema) {
|
||||
return build_grammar([&](const llama_grammar_builder & callbacks) {
|
||||
std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
|
||||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
if (!force_gbnf) {
|
||||
return "%llguidance {}\nstart: %json " + schema.dump();
|
||||
}
|
||||
#else
|
||||
(void)force_gbnf;
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
return build_grammar([&](const common_grammar_builder & callbacks) {
|
||||
auto copy = schema;
|
||||
callbacks.resolve_refs(copy);
|
||||
callbacks.add_schema("", copy);
|
||||
});
|
||||
}
|
||||
|
||||
std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb) {
|
||||
SchemaConverter converter([&](const std::string &) { return json(); }, /* dotall= */ false);
|
||||
llama_grammar_builder builder {
|
||||
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
|
||||
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces);
|
||||
common_grammar_builder builder {
|
||||
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
|
||||
return converter._add_rule(name, rule);
|
||||
},
|
||||
|
||||
@@ -5,12 +5,18 @@
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "json.hpp"
|
||||
|
||||
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
|
||||
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
|
||||
bool force_gbnf = false);
|
||||
|
||||
struct llama_grammar_builder {
|
||||
struct common_grammar_builder {
|
||||
std::function<std::string(const std::string &, const std::string &)> add_rule;
|
||||
std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
|
||||
std::function<void(nlohmann::ordered_json &)> resolve_refs;
|
||||
};
|
||||
|
||||
std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);
|
||||
struct common_grammar_options {
|
||||
bool dotall = false;
|
||||
bool compact_spaces = false;
|
||||
};
|
||||
|
||||
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
|
||||
|
||||
270
common/llguidance.cpp
Normal file
270
common/llguidance.cpp
Normal file
@@ -0,0 +1,270 @@
|
||||
#include "sampling.h"
|
||||
#include "log.h"
|
||||
|
||||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
|
||||
# include "llguidance.h"
|
||||
# include <cmath>
|
||||
|
||||
struct llama_sampler_llg {
|
||||
const llama_vocab * vocab;
|
||||
std::string grammar_kind;
|
||||
std::string grammar_data;
|
||||
LlgTokenizer * tokenizer;
|
||||
LlgConstraint * grammar;
|
||||
LlgMaskResult llg_res;
|
||||
bool has_llg_res;
|
||||
};
|
||||
|
||||
static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
|
||||
const char * grammar_data) {
|
||||
LlgConstraintInit cinit;
|
||||
llg_constraint_init_set_defaults(&cinit, tokenizer);
|
||||
const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
|
||||
if (log_level && *log_level) {
|
||||
cinit.log_stderr_level = atoi(log_level);
|
||||
}
|
||||
auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
|
||||
if (llg_get_error(c)) {
|
||||
LOG_ERR("llg error: %s\n", llg_get_error(c));
|
||||
llg_free_constraint(c);
|
||||
return nullptr;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
|
||||
return "llguidance";
|
||||
}
|
||||
|
||||
static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
|
||||
auto * ctx = (llama_sampler_llg *) smpl->ctx;
|
||||
if (ctx->grammar) {
|
||||
LlgCommitResult res;
|
||||
llg_commit_token(ctx->grammar, token, &res);
|
||||
ctx->has_llg_res = false;
|
||||
}
|
||||
}
|
||||
|
||||
static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (llama_sampler_llg *) smpl->ctx;
|
||||
if (ctx->grammar) {
|
||||
if (!ctx->has_llg_res) {
|
||||
if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
|
||||
ctx->has_llg_res = true;
|
||||
} else {
|
||||
LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
|
||||
llg_free_constraint(ctx->grammar);
|
||||
ctx->grammar = nullptr;
|
||||
}
|
||||
}
|
||||
if (ctx->has_llg_res) {
|
||||
if (ctx->llg_res.is_stop) {
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const uint32_t * mask = ctx->llg_res.sample_mask;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
auto token = cur_p->data[i].id;
|
||||
if ((mask[token / 32] & (1 << (token % 32))) == 0) {
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void llama_sampler_llg_reset(llama_sampler * smpl) {
|
||||
auto * ctx = (llama_sampler_llg *) smpl->ctx;
|
||||
if (!ctx->grammar) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
|
||||
llg_free_constraint(ctx->grammar);
|
||||
ctx->grammar = grammar_new;
|
||||
ctx->has_llg_res = false;
|
||||
}
|
||||
|
||||
static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_llg *) smpl->ctx;
|
||||
|
||||
auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);
|
||||
|
||||
// copy the state
|
||||
{
|
||||
auto * result_ctx = (llama_sampler_llg *) result->ctx;
|
||||
|
||||
if (ctx->grammar) {
|
||||
result_ctx->grammar_kind = ctx->grammar_kind;
|
||||
result_ctx->grammar_data = ctx->grammar_data;
|
||||
result_ctx->grammar = llg_clone_constraint(ctx->grammar);
|
||||
result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static void llama_sampler_llg_free(llama_sampler * smpl) {
|
||||
const auto * ctx = (llama_sampler_llg *) smpl->ctx;
|
||||
|
||||
if (ctx->grammar) {
|
||||
llg_free_constraint(ctx->grammar);
|
||||
llg_free_tokenizer(ctx->tokenizer);
|
||||
}
|
||||
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
static llama_sampler_i llama_sampler_llg_i = {
|
||||
/* .name = */ llama_sampler_llg_name,
|
||||
/* .accept = */ llama_sampler_llg_accept_impl,
|
||||
/* .apply = */ llama_sampler_llg_apply,
|
||||
/* .reset = */ llama_sampler_llg_reset,
|
||||
/* .clone = */ llama_sampler_llg_clone,
|
||||
/* .free = */ llama_sampler_llg_free,
|
||||
};
|
||||
|
||||
static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
|
||||
uint32_t * output_tokens, size_t output_tokens_len) {
|
||||
const llama_vocab * vocab = (const llama_vocab *) user_data;
|
||||
int r = 0;
|
||||
try {
|
||||
r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
|
||||
true);
|
||||
} catch (const std::exception & e) {
|
||||
GGML_ABORT("llama_tokenize failed: %s\n", e.what());
|
||||
}
|
||||
if (r < 0) {
|
||||
return -r;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
|
||||
// TODO store the tokenizer in the vocab somehow
|
||||
static const llama_vocab * vocab_cache;
|
||||
static LlgTokenizer * tokenizer_cache;
|
||||
|
||||
if (vocab_cache == vocab) {
|
||||
return llg_clone_tokenizer(tokenizer_cache);
|
||||
}
|
||||
|
||||
auto tok_eos = llama_vocab_eot(vocab);
|
||||
if (tok_eos == LLAMA_TOKEN_NULL) {
|
||||
tok_eos = llama_vocab_eos(vocab);
|
||||
}
|
||||
|
||||
size_t vocab_size = llama_vocab_n_tokens(vocab);
|
||||
|
||||
auto token_lens = new uint32_t[vocab_size];
|
||||
// we typically have ~7 bytes per token; let's go on the safe side here
|
||||
auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
|
||||
auto token_bytes = new uint8_t[token_bytes_size];
|
||||
|
||||
size_t offset = 0;
|
||||
for (size_t i = 0; i < vocab_size; i++) {
|
||||
size_t max_token = 1024;
|
||||
if (token_bytes_size - offset < max_token) {
|
||||
GGML_ABORT("token_bytes buffer too small\n");
|
||||
}
|
||||
|
||||
llama_token token = i;
|
||||
auto dp = (char *) token_bytes + offset;
|
||||
auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
|
||||
if (size < 0) {
|
||||
GGML_ABORT("llama_detokenize failed\n");
|
||||
}
|
||||
if (size == 0) {
|
||||
size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
|
||||
if (size < 0) {
|
||||
GGML_ABORT("llama_detokenize failed\n");
|
||||
}
|
||||
if (size != 0) {
|
||||
*dp = '\xff'; // special token prefix marker
|
||||
size += 1;
|
||||
}
|
||||
}
|
||||
|
||||
token_lens[i] = size;
|
||||
offset += size;
|
||||
}
|
||||
|
||||
LlgTokenizerInit tinit = {
|
||||
/* .vocab_size = */ (uint32_t) vocab_size,
|
||||
/* .tok_eos = */ (uint32_t) tok_eos,
|
||||
/* .token_lens = */ token_lens,
|
||||
/* .token_bytes = */ token_bytes,
|
||||
/* .tokenizer_json = */ nullptr,
|
||||
/* .tokenize_assumes_string = */ true,
|
||||
/* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
|
||||
/* .use_approximate_greedy_tokenize_fn = */ false,
|
||||
/* .tokenize_user_data = */ vocab,
|
||||
};
|
||||
|
||||
char error_buffer[1024];
|
||||
LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));
|
||||
|
||||
delete[] token_bytes;
|
||||
delete[] token_lens;
|
||||
|
||||
if (tokenizer == nullptr) {
|
||||
LOG_ERR("llg tokenizer error: %s\n", error_buffer);
|
||||
return tokenizer;
|
||||
}
|
||||
|
||||
if (tokenizer_cache) {
|
||||
llg_free_tokenizer(tokenizer_cache);
|
||||
}
|
||||
vocab_cache = vocab;
|
||||
tokenizer_cache = tokenizer;
|
||||
|
||||
return llg_clone_tokenizer(tokenizer_cache);
|
||||
}
|
||||
|
||||
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
|
||||
const char * grammar_data) {
|
||||
auto * ctx = new llama_sampler_llg;
|
||||
|
||||
if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
|
||||
auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
|
||||
*ctx = {
|
||||
/* .vocab = */ vocab,
|
||||
/* .grammar_kind = */ grammar_kind,
|
||||
/* .grammar_data = */ grammar_data,
|
||||
/* .tokenizer = */ tokenizer,
|
||||
/* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
|
||||
/* .llg_res = */ {},
|
||||
/* .has_llg_res = */ false,
|
||||
};
|
||||
} else {
|
||||
*ctx = {
|
||||
/* .vocab = */ vocab,
|
||||
/* .grammar_kind = */ {},
|
||||
/* .grammar_data = */ {},
|
||||
/* .tokenizer = */ nullptr,
|
||||
/* .grammar = */ nullptr,
|
||||
/* .llg_res = */ {},
|
||||
/* .has_llg_res = */ false,
|
||||
};
|
||||
}
|
||||
|
||||
return new llama_sampler{
|
||||
/* .iface = */ &llama_sampler_llg_i,
|
||||
/* .ctx = */ ctx,
|
||||
};
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) {
|
||||
LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
@@ -14,16 +14,6 @@ void common_log_set_verbosity_thold(int verbosity) {
|
||||
common_log_verbosity_thold = verbosity;
|
||||
}
|
||||
|
||||
#define LOG_COL_DEFAULT "\033[0m"
|
||||
#define LOG_COL_BOLD "\033[1m"
|
||||
#define LOG_COL_RED "\033[31m"
|
||||
#define LOG_COL_GREEN "\033[32m"
|
||||
#define LOG_COL_YELLOW "\033[33m"
|
||||
#define LOG_COL_BLUE "\033[34m"
|
||||
#define LOG_COL_MAGENTA "\033[35m"
|
||||
#define LOG_COL_CYAN "\033[36m"
|
||||
#define LOG_COL_WHITE "\033[37m"
|
||||
|
||||
static int64_t t_us() {
|
||||
return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
|
||||
}
|
||||
@@ -206,6 +196,7 @@ public:
|
||||
vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args_copy);
|
||||
}
|
||||
#endif
|
||||
va_end(args_copy);
|
||||
}
|
||||
|
||||
entry.level = level;
|
||||
|
||||
10
common/log.h
10
common/log.h
@@ -2,6 +2,16 @@
|
||||
|
||||
#include "ggml.h" // for ggml_log_level
|
||||
|
||||
#define LOG_COL_DEFAULT "\033[0m"
|
||||
#define LOG_COL_BOLD "\033[1m"
|
||||
#define LOG_COL_RED "\033[31m"
|
||||
#define LOG_COL_GREEN "\033[32m"
|
||||
#define LOG_COL_YELLOW "\033[33m"
|
||||
#define LOG_COL_BLUE "\033[34m"
|
||||
#define LOG_COL_MAGENTA "\033[35m"
|
||||
#define LOG_COL_CYAN "\033[36m"
|
||||
#define LOG_COL_WHITE "\033[37m"
|
||||
|
||||
#ifndef __GNUC__
|
||||
# define LOG_ATTRIBUTE_FORMAT(...)
|
||||
#elif defined(__MINGW32__)
|
||||
|
||||
174
common/minja.hpp
174
common/minja.hpp
@@ -628,7 +628,7 @@ class Context : public std::enable_shared_from_this<Context> {
|
||||
if (parent_) return parent_->contains(key);
|
||||
return false;
|
||||
}
|
||||
virtual void set(const Value & key, Value & value) {
|
||||
virtual void set(const Value & key, const Value & value) {
|
||||
values_.set(key, value);
|
||||
}
|
||||
};
|
||||
@@ -693,7 +693,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline };
|
||||
|
||||
class TemplateToken {
|
||||
public:
|
||||
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter };
|
||||
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue };
|
||||
|
||||
static std::string typeToString(Type t) {
|
||||
switch (t) {
|
||||
@@ -714,6 +714,8 @@ public:
|
||||
case Type::EndFilter: return "endfilter";
|
||||
case Type::Generation: return "generation";
|
||||
case Type::EndGeneration: return "endgeneration";
|
||||
case Type::Break: return "break";
|
||||
case Type::Continue: return "continue";
|
||||
}
|
||||
return "Unknown";
|
||||
}
|
||||
@@ -815,6 +817,22 @@ struct CommentTemplateToken : public TemplateToken {
|
||||
CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {}
|
||||
};
|
||||
|
||||
enum class LoopControlType { Break, Continue };
|
||||
|
||||
class LoopControlException : public std::runtime_error {
|
||||
public:
|
||||
LoopControlType control_type;
|
||||
LoopControlException(const std::string & message, LoopControlType control_type) : std::runtime_error(message), control_type(control_type) {}
|
||||
LoopControlException(LoopControlType control_type)
|
||||
: std::runtime_error((control_type == LoopControlType::Continue ? "continue" : "break") + std::string(" outside of a loop")),
|
||||
control_type(control_type) {}
|
||||
};
|
||||
|
||||
struct LoopControlTemplateToken : public TemplateToken {
|
||||
LoopControlType control_type;
|
||||
LoopControlTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, location, pre, post), control_type(control_type) {}
|
||||
};
|
||||
|
||||
class TemplateNode {
|
||||
Location location_;
|
||||
protected:
|
||||
@@ -825,6 +843,12 @@ public:
|
||||
void render(std::ostringstream & out, const std::shared_ptr<Context> & context) const {
|
||||
try {
|
||||
do_render(out, context);
|
||||
} catch (const LoopControlException & e) {
|
||||
// TODO: make stack creation lazy. Only needed if it was thrown outside of a loop.
|
||||
std::ostringstream err;
|
||||
err << e.what();
|
||||
if (location_.source) err << error_location_suffix(*location_.source, location_.pos);
|
||||
throw LoopControlException(err.str(), e.control_type);
|
||||
} catch (const std::exception & e) {
|
||||
std::ostringstream err;
|
||||
err << e.what();
|
||||
@@ -897,6 +921,15 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class LoopControlNode : public TemplateNode {
|
||||
LoopControlType control_type_;
|
||||
public:
|
||||
LoopControlNode(const Location & location, LoopControlType control_type) : TemplateNode(location), control_type_(control_type) {}
|
||||
void do_render(std::ostringstream &, const std::shared_ptr<Context> &) const override {
|
||||
throw LoopControlException(control_type_);
|
||||
}
|
||||
};
|
||||
|
||||
class ForNode : public TemplateNode {
|
||||
std::vector<std::string> var_names;
|
||||
std::shared_ptr<Expression> iterable;
|
||||
@@ -961,7 +994,12 @@ public:
|
||||
loop.set("last", i == (n - 1));
|
||||
loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value());
|
||||
loop.set("nextitem", i < n - 1 ? filtered_items.at(i + 1) : Value());
|
||||
body->render(out, loop_context);
|
||||
try {
|
||||
body->render(out, loop_context);
|
||||
} catch (const LoopControlException & e) {
|
||||
if (e.control_type == LoopControlType::Break) break;
|
||||
if (e.control_type == LoopControlType::Continue) continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -2156,10 +2194,10 @@ private:
|
||||
}
|
||||
|
||||
TemplateTokenVector tokenize() {
|
||||
static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})");
|
||||
static std::regex comment_tok(R"(\{#([-~]?)([\s\S\r\n]*?)([-~]?)#\})");
|
||||
static std::regex expr_open_regex(R"(\{\{([-~])?)");
|
||||
static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)");
|
||||
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)");
|
||||
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
|
||||
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
|
||||
static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})");
|
||||
static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})");
|
||||
@@ -2291,6 +2329,9 @@ private:
|
||||
} else if (keyword == "endfilter") {
|
||||
auto post_space = parseBlockClose();
|
||||
tokens.push_back(std::make_unique<EndFilterTemplateToken>(location, pre_space, post_space));
|
||||
} else if (keyword == "break" || keyword == "continue") {
|
||||
auto post_space = parseBlockClose();
|
||||
tokens.push_back(std::make_unique<LoopControlTemplateToken>(location, pre_space, post_space, keyword == "break" ? LoopControlType::Break : LoopControlType::Continue));
|
||||
} else {
|
||||
throw std::runtime_error("Unexpected block: " + keyword);
|
||||
}
|
||||
@@ -2414,6 +2455,8 @@ private:
|
||||
children.emplace_back(std::make_shared<FilterNode>(token->location, std::move(filter_token->filter), std::move(body)));
|
||||
} else if (dynamic_cast<CommentTemplateToken*>(token.get())) {
|
||||
// Ignore comments
|
||||
} else if (auto ctrl_token = dynamic_cast<LoopControlTemplateToken*>(token.get())) {
|
||||
children.emplace_back(std::make_shared<LoopControlNode>(token->location, ctrl_token->control_type));
|
||||
} else if (dynamic_cast<EndForTemplateToken*>(token.get())
|
||||
|| dynamic_cast<EndSetTemplateToken*>(token.get())
|
||||
|| dynamic_cast<EndMacroTemplateToken*>(token.get())
|
||||
@@ -2572,6 +2615,7 @@ inline std::shared_ptr<Context> Context::builtins() {
|
||||
}));
|
||||
globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr<Context> &, Value & args) {
|
||||
auto do_join = [](Value & items, const std::string & sep) {
|
||||
if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump());
|
||||
std::ostringstream oss;
|
||||
auto first = true;
|
||||
for (size_t i = 0, n = items.size(); i < n; ++i) {
|
||||
@@ -2648,31 +2692,38 @@ inline std::shared_ptr<Context> Context::builtins() {
|
||||
return filter.call(context, actual_args);
|
||||
});
|
||||
};
|
||||
// https://jinja.palletsprojects.com/en/3.0.x/templates/#jinja-filters.reject
|
||||
globals.set("reject", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
|
||||
args.expectArgs("reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
|
||||
auto & items = args.args[0];
|
||||
auto filter_fn = context->get(args.args[1]);
|
||||
if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
|
||||
auto select_or_reject = [make_filter](bool is_select) {
|
||||
return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
|
||||
args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
|
||||
auto & items = args.args[0];
|
||||
if (items.is_null())
|
||||
return Value::array();
|
||||
if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump());
|
||||
|
||||
auto filter_args = Value::array();
|
||||
for (size_t i = 2, n = args.args.size(); i < n; i++) {
|
||||
filter_args.push_back(args.args[i]);
|
||||
}
|
||||
auto filter = make_filter(filter_fn, filter_args);
|
||||
auto filter_fn = context->get(args.args[1]);
|
||||
if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
|
||||
|
||||
auto res = Value::array();
|
||||
for (size_t i = 0, n = items.size(); i < n; i++) {
|
||||
auto & item = items.at(i);
|
||||
ArgumentsValue filter_args;
|
||||
filter_args.args.emplace_back(item);
|
||||
auto pred_res = filter.call(context, filter_args);
|
||||
if (!pred_res.to_bool()) {
|
||||
res.push_back(item);
|
||||
auto filter_args = Value::array();
|
||||
for (size_t i = 2, n = args.args.size(); i < n; i++) {
|
||||
filter_args.push_back(args.args[i]);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}));
|
||||
auto filter = make_filter(filter_fn, filter_args);
|
||||
|
||||
auto res = Value::array();
|
||||
for (size_t i = 0, n = items.size(); i < n; i++) {
|
||||
auto & item = items.at(i);
|
||||
ArgumentsValue filter_args;
|
||||
filter_args.args.emplace_back(item);
|
||||
auto pred_res = filter.call(context, filter_args);
|
||||
if (pred_res.to_bool() == (is_select ? true : false)) {
|
||||
res.push_back(item);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
});
|
||||
};
|
||||
globals.set("select", select_or_reject(/* is_select= */ true));
|
||||
globals.set("reject", select_or_reject(/* is_select= */ false));
|
||||
globals.set("map", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
|
||||
auto res = Value::array();
|
||||
if (args.args.size() == 1 &&
|
||||
@@ -2720,41 +2771,46 @@ inline std::shared_ptr<Context> Context::builtins() {
|
||||
if (!text.empty() && text.back() == '\n') out += "\n";
|
||||
return out;
|
||||
}));
|
||||
globals.set("selectattr", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
|
||||
args.expectArgs("selectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
|
||||
auto & items = args.args[0];
|
||||
if (items.is_null())
|
||||
return Value::array();
|
||||
auto attr_name = args.args[1].get<std::string>();
|
||||
auto select_or_reject_attr = [](bool is_select) {
|
||||
return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
|
||||
args.expectArgs(is_select ? "selectattr" : "rejectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
|
||||
auto & items = args.args[0];
|
||||
if (items.is_null())
|
||||
return Value::array();
|
||||
if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump());
|
||||
auto attr_name = args.args[1].get<std::string>();
|
||||
|
||||
bool has_test = false;
|
||||
Value test_fn;
|
||||
ArgumentsValue test_args {{Value()}, {}};
|
||||
if (args.args.size() >= 3) {
|
||||
has_test = true;
|
||||
test_fn = context->get(args.args[2]);
|
||||
if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
|
||||
for (size_t i = 3, n = args.args.size(); i < n; i++) {
|
||||
test_args.args.emplace_back(args.args[i]);
|
||||
}
|
||||
test_args.kwargs = args.kwargs;
|
||||
}
|
||||
|
||||
auto res = Value::array();
|
||||
for (size_t i = 0, n = items.size(); i < n; i++) {
|
||||
auto & item = items.at(i);
|
||||
auto attr = item.get(attr_name);
|
||||
if (has_test) {
|
||||
test_args.args[0] = attr;
|
||||
if (test_fn.call(context, test_args).to_bool()) {
|
||||
res.push_back(item);
|
||||
bool has_test = false;
|
||||
Value test_fn;
|
||||
ArgumentsValue test_args {{Value()}, {}};
|
||||
if (args.args.size() >= 3) {
|
||||
has_test = true;
|
||||
test_fn = context->get(args.args[2]);
|
||||
if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
|
||||
for (size_t i = 3, n = args.args.size(); i < n; i++) {
|
||||
test_args.args.emplace_back(args.args[i]);
|
||||
}
|
||||
} else {
|
||||
res.push_back(attr);
|
||||
test_args.kwargs = args.kwargs;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}));
|
||||
|
||||
auto res = Value::array();
|
||||
for (size_t i = 0, n = items.size(); i < n; i++) {
|
||||
auto & item = items.at(i);
|
||||
auto attr = item.get(attr_name);
|
||||
if (has_test) {
|
||||
test_args.args[0] = attr;
|
||||
if (test_fn.call(context, test_args).to_bool() == (is_select ? true : false)) {
|
||||
res.push_back(item);
|
||||
}
|
||||
} else {
|
||||
res.push_back(attr);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
});
|
||||
};
|
||||
globals.set("selectattr", select_or_reject_attr(/* is_select= */ true));
|
||||
globals.set("rejectattr", select_or_reject_attr(/* is_select= */ false));
|
||||
globals.set("range", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
|
||||
std::vector<int64_t> startEndStep(3);
|
||||
std::vector<bool> param_set(3);
|
||||
|
||||
@@ -151,9 +151,30 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
|
||||
lparams.no_perf = params.no_perf;
|
||||
|
||||
std::vector<const char *> trigger_words;
|
||||
trigger_words.reserve(params.grammar_trigger_words.size());
|
||||
for (const auto & str : params.grammar_trigger_words) {
|
||||
trigger_words.push_back(str.word.c_str());
|
||||
}
|
||||
|
||||
struct llama_sampler * grmr;
|
||||
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
|
||||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
|
||||
#else
|
||||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
} else {
|
||||
grmr = params.grammar_lazy
|
||||
? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
|
||||
trigger_words.data(), trigger_words.size(),
|
||||
params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size())
|
||||
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
|
||||
}
|
||||
|
||||
auto * result = new common_sampler {
|
||||
/* .params = */ params,
|
||||
/* .grmr = */ llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"),
|
||||
/* .grmr = */ grmr,
|
||||
/* .chain = */ llama_sampler_chain_init(lparams),
|
||||
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
||||
/* .cur = */ {},
|
||||
|
||||
@@ -102,3 +102,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
|
||||
|
||||
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
|
||||
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
|
||||
|
||||
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
|
||||
const char * grammar_kind, const char * grammar_data);
|
||||
|
||||
@@ -648,7 +648,7 @@ class Model:
|
||||
if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a":
|
||||
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code
|
||||
res = "jina-v2-code"
|
||||
if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
|
||||
if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b" or chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516":
|
||||
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
|
||||
res = "chatglm-bpe"
|
||||
if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
|
||||
@@ -4513,7 +4513,7 @@ class JaisModel(Model):
|
||||
self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
|
||||
|
||||
|
||||
@Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration")
|
||||
@Model.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
|
||||
class ChatGLMModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.CHATGLM
|
||||
|
||||
@@ -4619,47 +4619,15 @@ class ChatGLMModel(Model):
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
|
||||
vocab_size = hparams["padded_vocab_size"]
|
||||
vocab_size = hparams.get("padded_vocab_size",hparams["vocab_size"])
|
||||
assert max(tokenizer.get_vocab().values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
merges = []
|
||||
vocab = {}
|
||||
mergeable_ranks = tokenizer.mergeable_ranks
|
||||
for token, rank in mergeable_ranks.items():
|
||||
vocab[ChatGLMModel.token_bytes_to_string(token)] = rank
|
||||
if len(token) == 1:
|
||||
continue
|
||||
merged = ChatGLMModel.bpe(mergeable_ranks, token, max_rank=rank)
|
||||
assert len(merged) >= 2 and len(merged) <= 7
|
||||
merges.append(' '.join(map(ChatGLMModel.token_bytes_to_string, merged)))
|
||||
|
||||
# for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
tokens.append(f"[PAD{i}]")
|
||||
toktypes.append(gguf.TokenType.UNUSED)
|
||||
elif reverse_vocab[i] in added_vocab:
|
||||
tokens.append(reverse_vocab[i])
|
||||
if tokenizer.added_tokens_decoder[i].special:
|
||||
toktypes.append(gguf.TokenType.CONTROL)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||
else:
|
||||
tokens.append(reverse_vocab[i])
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
|
||||
tokens, toktypes, tokpre = self.get_vocab_base()
|
||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
|
||||
special_vocab.merges = merges
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
# only add special tokens when they were not already loaded from config.json
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
|
||||
@@ -4670,16 +4638,20 @@ class ChatGLMModel(Model):
|
||||
def set_gguf_parameters(self):
|
||||
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
|
||||
n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
|
||||
n_head_kv = self.hparams.get("multi_query_group_num", n_head)
|
||||
n_head_kv = self.hparams.get("multi_query_group_num", self.hparams.get("num_key_value_heads", n_head))
|
||||
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
|
||||
self.gguf_writer.add_embedding_length(n_embed)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", 4 * n_embed))
|
||||
self.gguf_writer.add_block_count(self.hparams["num_layers"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed)))
|
||||
self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams["num_hidden_layers"]))
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layernorm_epsilon"])
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5))
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
self.gguf_writer.add_rope_dimension_count(64)
|
||||
if "attention_dim" in self.hparams:
|
||||
rope_dim = self.hparams["attention_dim"]
|
||||
else:
|
||||
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
|
||||
self.gguf_writer.add_add_bos_token(False)
|
||||
rope_freq = 10000
|
||||
if "rope_ratio" in self.hparams:
|
||||
@@ -4689,7 +4661,7 @@ class ChatGLMModel(Model):
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
if name.endswith(".rotary_pos_emb.inv_freq"):
|
||||
if name.endswith(".rotary_pos_emb.inv_freq") or name.startswith("model.vision."):
|
||||
return []
|
||||
|
||||
name = name.removeprefix("transformer.")
|
||||
|
||||
@@ -133,7 +133,7 @@ The docker build option is currently limited to *intel GPU* targets.
|
||||
### Build image
|
||||
```sh
|
||||
# Using FP16
|
||||
docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" -f .devops/llama-cli-intel.Dockerfile .
|
||||
docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f .devops/intel.Dockerfile .
|
||||
```
|
||||
|
||||
*Notes*:
|
||||
|
||||
@@ -286,7 +286,7 @@ You don't need to install Vulkan SDK. It will be installed inside the container.
|
||||
|
||||
```sh
|
||||
# Build the image
|
||||
docker build -t llama-cpp-vulkan -f .devops/llama-cli-vulkan.Dockerfile .
|
||||
docker build -t llama-cpp-vulkan --target light -f .devops/vulkan.Dockerfile .
|
||||
|
||||
# Then, use it:
|
||||
docker run -it --rm -v "$(pwd):/app:Z" --device /dev/dri/renderD128:/dev/dri/renderD128 --device /dev/dri/card1:/dev/dri/card1 llama-cpp-vulkan -m "/app/models/YOUR_MODEL_FILE" -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33
|
||||
|
||||
@@ -60,9 +60,9 @@ Assuming one has the [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia
|
||||
## Building Docker locally
|
||||
|
||||
```bash
|
||||
docker build -t local/llama.cpp:full-cuda -f .devops/full-cuda.Dockerfile .
|
||||
docker build -t local/llama.cpp:light-cuda -f .devops/llama-cli-cuda.Dockerfile .
|
||||
docker build -t local/llama.cpp:server-cuda -f .devops/llama-server-cuda.Dockerfile .
|
||||
docker build -t local/llama.cpp:full-cuda --target full -f .devops/cuda.Dockerfile .
|
||||
docker build -t local/llama.cpp:light-cuda --target light -f .devops/cuda.Dockerfile .
|
||||
docker build -t local/llama.cpp:server-cuda --target server -f .devops/cuda.Dockerfile .
|
||||
```
|
||||
|
||||
You may want to pass in some different `ARGS`, depending on the CUDA environment supported by your container host, as well as the GPU architecture.
|
||||
@@ -95,9 +95,9 @@ Assuming one has the [mt-container-toolkit](https://developer.mthreads.com/musa/
|
||||
## Building Docker locally
|
||||
|
||||
```bash
|
||||
docker build -t local/llama.cpp:full-musa -f .devops/full-musa.Dockerfile .
|
||||
docker build -t local/llama.cpp:light-musa -f .devops/llama-cli-musa.Dockerfile .
|
||||
docker build -t local/llama.cpp:server-musa -f .devops/llama-server-musa.Dockerfile .
|
||||
docker build -t local/llama.cpp:full-musa --target full -f .devops/musa.Dockerfile .
|
||||
docker build -t local/llama.cpp:light-musa --target light -f .devops/musa.Dockerfile .
|
||||
docker build -t local/llama.cpp:server-musa --target server -f .devops/musa.Dockerfile .
|
||||
```
|
||||
|
||||
You may want to pass in some different `ARGS`, depending on the MUSA environment supported by your container host, as well as the GPU architecture.
|
||||
|
||||
51
docs/llguidance.md
Normal file
51
docs/llguidance.md
Normal file
@@ -0,0 +1,51 @@
|
||||
# LLGuidance Support in llama.cpp
|
||||
|
||||
[LLGuidance](https://github.com/guidance-ai/llguidance) is a library for constrained decoding (also called constrained sampling or structured outputs) for Large Language Models (LLMs). Initially developed as the backend for the [Guidance](https://github.com/guidance-ai/guidance) library, it can also be used independently.
|
||||
|
||||
LLGuidance supports JSON Schemas and arbitrary context-free grammars (CFGs) written in a [variant](https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md) of Lark syntax. It is [very fast](https://github.com/guidance-ai/jsonschemabench/tree/main/maskbench) and has [excellent](https://github.com/guidance-ai/llguidance/blob/main/docs/json_schema.md) JSON Schema coverage but requires the Rust compiler, which complicates the llama.cpp build process.
|
||||
|
||||
## Building
|
||||
|
||||
To enable LLGuidance support, build llama.cpp with the `LLAMA_LLGUIDANCE` option:
|
||||
|
||||
```sh
|
||||
cmake -B build -DLLAMA_LLGUIDANCE=ON
|
||||
make -C build -j
|
||||
```
|
||||
|
||||
This requires the Rust compiler and the `cargo` tool to be [installed](https://www.rust-lang.org/tools/install).
|
||||
|
||||
## Interface
|
||||
|
||||
There are no new command-line arguments or modifications to `common_params`. When enabled, grammars starting with `%llguidance` are passed to LLGuidance instead of the [current](../grammars/README.md) llama.cpp grammars. Additionally, JSON Schema requests (e.g., using the `-j` argument in `llama-cli`) are also passed to LLGuidance.
|
||||
|
||||
For your existing GBNF grammars, you can use [gbnf_to_lark.py script](https://github.com/guidance-ai/llguidance/blob/main/scripts/gbnf_to_lark.py) to convert them to LLGuidance Lark-like format.
|
||||
|
||||
## Performance
|
||||
|
||||
Computing a "token mask" (i.e., the set of allowed tokens) for a llama3 tokenizer with 128k tokens takes, on average, 50μs of single-core CPU time for the [JSON Schema Bench](https://github.com/guidance-ai/jsonschemabench). The p99 time is 0.5ms, and the p100 time is 20ms. These results are due to the lexer/parser split and several [optimizations](https://github.com/guidance-ai/llguidance/blob/main/docs/optimizations.md).
|
||||
|
||||
## JSON Schema
|
||||
|
||||
LLGuidance adheres closely to the JSON Schema specification. For example:
|
||||
|
||||
- `additionalProperties` defaults to `true`, unlike current grammars, though you can set `"additionalProperties": false` if needed.
|
||||
- any whitespace is allowed.
|
||||
- The definition order in the `"properties": {}` object is maintained, regardless of whether properties are required (current grammars always puts required properties first).
|
||||
|
||||
Unsupported schemas result in an error message—no keywords are silently ignored.
|
||||
|
||||
## Why Not Reuse GBNF Format?
|
||||
|
||||
GBNF lacks the concept of a lexer.
|
||||
|
||||
Most programming languages, including JSON, use a two-step process: a lexer (built with regular expressions) converts a byte stream into lexemes, which are then processed by a CFG parser. This approach is faster because lexers are cheaper to evaluate, and there is ~10x fewer lexemes than bytes.
|
||||
LLM tokens often align with lexemes, so the parser is engaged in under 0.5% of tokens, with the lexer handling the rest.
|
||||
|
||||
However, the user has to provide the distinction between lexemes and CFG symbols. In [Lark](https://github.com/lark-parser/lark), lexeme names are uppercase, while CFG symbols are lowercase.
|
||||
The [gbnf_to_lark.py script](https://github.com/guidance-ai/llguidance/blob/main/scripts/gbnf_to_lark.py) can often take care of this automatically.
|
||||
See [LLGuidance syntax docs](https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#terminals-vs-rules) for more details.
|
||||
|
||||
## Error Handling
|
||||
|
||||
Errors are currently printed to `stderr`, and generation continues. Improved error handling may be added in the future.
|
||||
@@ -31,6 +31,11 @@ defer {
|
||||
llama_model_free(model)
|
||||
}
|
||||
|
||||
guard let vocab = llama_model_get_vocab(model) else {
|
||||
print("Failed to get vocab")
|
||||
exit(1)
|
||||
}
|
||||
|
||||
var tokens = tokenize(text: prompt, add_bos: true)
|
||||
|
||||
let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)
|
||||
@@ -41,7 +46,7 @@ context_params.n_batch = UInt32(max(n_len, n_parallel))
|
||||
context_params.n_threads = 8
|
||||
context_params.n_threads_batch = 8
|
||||
|
||||
let context = llama_new_context_with_model(model, context_params)
|
||||
let context = llama_init_from_model(model, context_params)
|
||||
guard context != nil else {
|
||||
print("Failed to initialize context")
|
||||
exit(1)
|
||||
@@ -141,7 +146,7 @@ while n_cur <= n_len {
|
||||
let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
|
||||
|
||||
// is it an end of stream? -> mark the stream as finished
|
||||
if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len {
|
||||
if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
|
||||
i_batch[i] = -1
|
||||
// print("")
|
||||
if n_parallel > 1 {
|
||||
@@ -207,7 +212,7 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
|
||||
let utf8Count = text.utf8.count
|
||||
let n_tokens = utf8Count + (add_bos ? 1 : 0)
|
||||
let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
|
||||
let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false)
|
||||
let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false)
|
||||
var swiftTokens: [llama_token] = []
|
||||
for i in 0 ..< tokenCount {
|
||||
swiftTokens.append(tokens[Int(i)])
|
||||
@@ -218,12 +223,12 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
|
||||
|
||||
private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? {
|
||||
var result = [CChar](repeating: 0, count: 8)
|
||||
let nTokens = llama_token_to_piece(model, token, &result, Int32(result.count), 0, false)
|
||||
let nTokens = llama_token_to_piece(vocab, token, &result, Int32(result.count), 0, false)
|
||||
if nTokens < 0 {
|
||||
let actualTokensCount = -Int(nTokens)
|
||||
result = .init(repeating: 0, count: actualTokensCount)
|
||||
let check = llama_token_to_piece(
|
||||
model,
|
||||
vocab,
|
||||
token,
|
||||
&result,
|
||||
Int32(result.count),
|
||||
|
||||
@@ -76,7 +76,7 @@ int main(int argc, char** argv) {
|
||||
grammar_str = buffer.str();
|
||||
}
|
||||
|
||||
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
|
||||
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0);
|
||||
if (grammar == nullptr) {
|
||||
fprintf(stdout, "Failed to initialize llama_grammar\n");
|
||||
return 1;
|
||||
|
||||
@@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
|
||||
actor LlamaContext {
|
||||
private var model: OpaquePointer
|
||||
private var context: OpaquePointer
|
||||
private var vocab: OpaquePointer
|
||||
private var sampling: UnsafeMutablePointer<llama_sampler>
|
||||
private var batch: llama_batch
|
||||
private var tokens_list: [llama_token]
|
||||
@@ -47,6 +48,7 @@ actor LlamaContext {
|
||||
self.sampling = llama_sampler_chain_init(sparams)
|
||||
llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
|
||||
llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
|
||||
vocab = llama_model_get_vocab(model)
|
||||
}
|
||||
|
||||
deinit {
|
||||
@@ -79,7 +81,7 @@ actor LlamaContext {
|
||||
ctx_params.n_threads = Int32(n_threads)
|
||||
ctx_params.n_threads_batch = Int32(n_threads)
|
||||
|
||||
let context = llama_new_context_with_model(model, ctx_params)
|
||||
let context = llama_init_from_model(model, ctx_params)
|
||||
guard let context else {
|
||||
print("Could not load context!")
|
||||
throw LlamaError.couldNotInitializeContext
|
||||
@@ -151,7 +153,7 @@ actor LlamaContext {
|
||||
|
||||
new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
|
||||
|
||||
if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len {
|
||||
if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
|
||||
print("\n")
|
||||
is_done = true
|
||||
let new_token_str = String(cString: temporary_invalid_cchars + [0])
|
||||
@@ -297,7 +299,7 @@ actor LlamaContext {
|
||||
let utf8Count = text.utf8.count
|
||||
let n_tokens = utf8Count + (add_bos ? 1 : 0) + 1
|
||||
let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
|
||||
let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)
|
||||
let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)
|
||||
|
||||
var swiftTokens: [llama_token] = []
|
||||
for i in 0..<tokenCount {
|
||||
@@ -316,7 +318,7 @@ actor LlamaContext {
|
||||
defer {
|
||||
result.deallocate()
|
||||
}
|
||||
let nTokens = llama_token_to_piece(model, token, result, 8, 0, false)
|
||||
let nTokens = llama_token_to_piece(vocab, token, result, 8, 0, false)
|
||||
|
||||
if nTokens < 0 {
|
||||
let newResult = UnsafeMutablePointer<Int8>.allocate(capacity: Int(-nTokens))
|
||||
@@ -324,7 +326,7 @@ actor LlamaContext {
|
||||
defer {
|
||||
newResult.deallocate()
|
||||
}
|
||||
let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens, 0, false)
|
||||
let nNewTokens = llama_token_to_piece(vocab, token, newResult, -nTokens, 0, false)
|
||||
let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens))
|
||||
return Array(bufferPointer)
|
||||
} else {
|
||||
|
||||
@@ -50,3 +50,10 @@ set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-qwen2vl-cli)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
set(TARGET llama-llava-clip-quantize-cli)
|
||||
add_executable(${TARGET} clip-quantize-cli.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-clip-quantize-cli)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
43
examples/llava/README-glmedge.md
Normal file
43
examples/llava/README-glmedge.md
Normal file
@@ -0,0 +1,43 @@
|
||||
# GLMV-EDGE
|
||||
|
||||
Currently this implementation supports [glm-edge-v-2b](https://huggingface.co/THUDM/glm-edge-v-2b) and [glm-edge-v-5b](https://huggingface.co/THUDM/glm-edge-v-5b).
|
||||
|
||||
## Usage
|
||||
Build with cmake or run `make llama-llava-cli` to build it.
|
||||
|
||||
After building, run: `./llama-llava-cli` to see the usage. For example:
|
||||
|
||||
```sh
|
||||
./llama-llava-cli -m model_path/ggml-model-f16.gguf --mmproj model_path/mmproj-model-f16.gguf --image img_path/image.jpg -p "<|system|>\n system prompt <image><|user|>\n prompt <|assistant|>\n"
|
||||
```
|
||||
|
||||
**note**: A lower temperature like 0.1 is recommended for better quality. add `--temp 0.1` to the command to do so.
|
||||
**note**: For GPU offloading ensure to use the `-ngl` flag just like usual
|
||||
|
||||
## GGUF conversion
|
||||
|
||||
1. Clone a GLMV-EDGE model ([2B](https://huggingface.co/THUDM/glm-edge-v-2b) or [5B](https://huggingface.co/THUDM/glm-edge-v-5b)). For example:
|
||||
|
||||
```sh
|
||||
git clone https://huggingface.co/THUDM/glm-edge-v-5b or https://huggingface.co/THUDM/glm-edge-v-2b
|
||||
```
|
||||
|
||||
2. Use `glmedge-surgery.py` to split the GLMV-EDGE model to LLM and multimodel projector constituents:
|
||||
|
||||
```sh
|
||||
python ./examples/llava/glmedge-surgery.py -m ../model_path
|
||||
```
|
||||
|
||||
4. Use `glmedge-convert-image-encoder-to-gguf.py` to convert the GLMV-EDGE image encoder to GGUF:
|
||||
|
||||
```sh
|
||||
python ./examples/llava/glmedge-convert-image-encoder-to-gguf.py -m ../model_path --llava-projector ../model_path/glm.projector --output-dir ../model_path
|
||||
```
|
||||
|
||||
5. Use `examples/convert_hf_to_gguf.py` to convert the LLM part of GLMV-EDGE to GGUF:
|
||||
|
||||
```sh
|
||||
python convert_hf_to_gguf.py ../model_path
|
||||
```
|
||||
|
||||
Now both the LLM part and the image encoder are in the `model_path` directory.
|
||||
44
examples/llava/README-quantize.md
Normal file
44
examples/llava/README-quantize.md
Normal file
@@ -0,0 +1,44 @@
|
||||
# Quantizing CLIP Visual Projector
|
||||
|
||||
This is the tool for quantizing the CLIP visual projector model. Quantization reduces the precision of the model's weights, which can significantly decrease the model size and improve inference speed, often with minimal impact on performance.
|
||||
|
||||
## Usage
|
||||
|
||||
To quantize a CLIP visual projector model, use the following command:
|
||||
|
||||
```sh
|
||||
./bin/llama-llava-clip-quantize-cli /path/to/ggml-model-f32.gguf /path/to/ggml-model-quantized.gguf <type>
|
||||
```
|
||||
|
||||
After the quantization, the visual projector can be used freely with the existing LLAVA cli (LLAVA, Qwen2VL, etc).
|
||||
|
||||
### Arguments
|
||||
|
||||
- `/path/to/ggml-model-f32.gguf`: The path to the input model file in FP32 or FP16 format.
|
||||
- `/path/to/ggml-model-quantized.gguf`: The path where the quantized model will be saved.
|
||||
- `<type>`: The quantization type to apply. This should be an integer corresponding to one of the quantization types defined in the `enum ggml_type`.
|
||||
|
||||
### Quantization Types
|
||||
|
||||
The following quantization types are supported, based on the `enum ggml_type` definition:
|
||||
|
||||
- `2` - `q4_0`: 4-bit quantization with a single scale value.
|
||||
- `3` - `q4_1`: 4-bit quantization with a separate scale value for each block.
|
||||
- `6` - `q5_0`: 5-bit quantization with a single scale value.
|
||||
- `7` - `q5_1`: 5-bit quantization with a separate scale value for each block.
|
||||
- `8` - `q8_0`: 8-bit quantization with a single scale value.
|
||||
|
||||
### Example
|
||||
|
||||
To quantize a model using the `q4_0` quantization type, you would run:
|
||||
|
||||
```sh
|
||||
./bin/llama-llava-clip-quantize-cli /path/to/ggml-model-f32.gguf /path/to/ggml-model-quantized.gguf 2
|
||||
```
|
||||
|
||||
This command will generate a quantized model at `/path/to/ggml-model-quantized.gguf` using the `q4_0` quantization method.
|
||||
|
||||
## Notes
|
||||
|
||||
- Quantization can lead to a loss in model accuracy, depending on the chosen quantization type. It is recommended to evaluate the quantized model's performance on your specific task to ensure it meets your requirements.
|
||||
- The quantized model will typically be smaller in size and faster to run, making it more suitable for deployment in resource-constrained environments.
|
||||
59
examples/llava/clip-quantize-cli.cpp
Normal file
59
examples/llava/clip-quantize-cli.cpp
Normal file
@@ -0,0 +1,59 @@
|
||||
#include "arg.h"
|
||||
#include "base64.hpp"
|
||||
#include "log.h"
|
||||
#include "common.h"
|
||||
#include "sampling.h"
|
||||
#include "clip.h"
|
||||
#include "llava.h"
|
||||
#include "llama.h"
|
||||
#include "ggml.h"
|
||||
|
||||
static void print_usage(int argc, char ** argv) {
|
||||
(void) argc;
|
||||
|
||||
fprintf(stderr, "usage: %s /path/to/ggml-model-f32.gguf /path/to/ggml-model-quantized.gguf type\n", argv[0]);
|
||||
fprintf(stderr, " type = 2 - q4_0\n");
|
||||
fprintf(stderr, " type = 3 - q4_1\n");
|
||||
fprintf(stderr, " type = 6 - q5_0\n");
|
||||
fprintf(stderr, " type = 7 - q5_1\n");
|
||||
fprintf(stderr, " type = 8 - q8_0\n");
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
if (argc != 4) {
|
||||
print_usage(argc, argv);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const std::string fname_inp = argv[1];
|
||||
const std::string fname_out = argv[2];
|
||||
|
||||
const int itype = atoi(argv[3]);
|
||||
|
||||
const int64_t t_main_start_us = ggml_time_us();
|
||||
|
||||
int64_t t_quantize_us = 0;
|
||||
|
||||
// load the model
|
||||
{
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
if (!clip_model_quantize(fname_inp.c_str(), fname_out.c_str(), itype)) {
|
||||
fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
t_quantize_us = ggml_time_us() - t_start_us;
|
||||
}
|
||||
|
||||
// report timing
|
||||
{
|
||||
const int64_t t_main_end_us = ggml_time_us();
|
||||
|
||||
printf("\n");
|
||||
printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us / 1000.0f);
|
||||
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -102,6 +102,7 @@ static std::string format(const char * fmt, ...) {
|
||||
#define KEY_HAS_VIS_ENC "clip.has_vision_encoder"
|
||||
#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
|
||||
#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector"
|
||||
#define KEY_HAS_GLM_PROJ "clip.has_glm_projector"
|
||||
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
|
||||
#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger"
|
||||
#define KEY_USE_GELU "clip.use_gelu"
|
||||
@@ -160,6 +161,15 @@ static std::string format(const char * fmt, ...) {
|
||||
#define TN_MINICPMV_ATTN "resampler.attn.%s.%s"
|
||||
#define TN_MINICPMV_LN "resampler.ln_%s.%s"
|
||||
|
||||
#define TN_GLM_ADAPER_CONV "adapter.conv.%s"
|
||||
#define TN_GLM_ADAPTER_LINEAR "adapter.linear.linear.%s"
|
||||
#define TN_GLM_ADAPTER_NORM_1 "adapter.linear.norm1.%s"
|
||||
#define TN_GLM_ADAPTER_D_H_2_4H "adapter.linear.dense_h_to_4h.%s"
|
||||
#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s"
|
||||
#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s"
|
||||
#define TN_GLM_BOI_W "adapter.boi"
|
||||
#define TN_GLM_EOI_W "adapter.eoi"
|
||||
|
||||
|
||||
enum projector_type {
|
||||
PROJECTOR_TYPE_MLP,
|
||||
@@ -167,6 +177,7 @@ enum projector_type {
|
||||
PROJECTOR_TYPE_LDP,
|
||||
PROJECTOR_TYPE_LDPV2,
|
||||
PROJECTOR_TYPE_RESAMPLER,
|
||||
PROJECTOR_TYPE_GLM_EDGE,
|
||||
PROJECTOR_TYPE_MERGER,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
@@ -176,6 +187,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
{ PROJECTOR_TYPE_LDP, "ldp" },
|
||||
{ PROJECTOR_TYPE_LDPV2, "ldpv2"},
|
||||
{ PROJECTOR_TYPE_RESAMPLER, "resampler"},
|
||||
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
|
||||
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
|
||||
};
|
||||
|
||||
@@ -500,6 +512,12 @@ struct clip_vision_model {
|
||||
struct ggml_tensor * mm_4_w = NULL;
|
||||
struct ggml_tensor * mm_4_b = NULL;
|
||||
|
||||
//GLMV-Edge projection
|
||||
struct ggml_tensor * mm_model_adapter_conv_w;
|
||||
struct ggml_tensor * mm_model_adapter_conv_b;
|
||||
struct ggml_tensor * boi_w;
|
||||
struct ggml_tensor * eoi_w;
|
||||
|
||||
// MobileVLM projection
|
||||
struct ggml_tensor * mm_model_mlp_1_w;
|
||||
struct ggml_tensor * mm_model_mlp_1_b;
|
||||
@@ -560,6 +578,7 @@ struct clip_ctx {
|
||||
bool has_vision_encoder = false;
|
||||
bool has_llava_projector = false;
|
||||
bool has_minicpmv_projector = false;
|
||||
bool has_glm_projector = false;
|
||||
bool has_qwen2vl_merger = false;
|
||||
int minicpmv_version = 2;
|
||||
|
||||
@@ -638,7 +657,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
|
||||
const int batch_size = imgs->size;
|
||||
|
||||
if (ctx->has_llava_projector || ctx->has_minicpmv_projector) {
|
||||
if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) {
|
||||
GGML_ASSERT(batch_size == 1);
|
||||
}
|
||||
|
||||
@@ -734,8 +753,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
}
|
||||
|
||||
// loop over layers
|
||||
if (ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger) {
|
||||
// TODO: figure out why we doing thing in this way ???
|
||||
if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) {
|
||||
n_layer += 1;
|
||||
}
|
||||
for (int il = 0; il < n_layer - 1; il++) {
|
||||
@@ -1095,7 +1113,33 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
// glm projector
|
||||
else if (ctx->has_glm_projector) {
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
|
||||
size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
|
||||
embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3));
|
||||
embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
|
||||
embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
|
||||
embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size);
|
||||
embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3));
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
|
||||
//GLU
|
||||
{
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
|
||||
embeddings = ggml_norm(ctx0, embeddings, eps);
|
||||
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
|
||||
embeddings = ggml_gelu_inplace(ctx0, embeddings);
|
||||
struct ggml_tensor * x = embeddings;
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
|
||||
x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
|
||||
embeddings = ggml_silu_inplace(ctx0, embeddings);
|
||||
embeddings = ggml_mul(ctx0, embeddings,x);
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("fatel error");
|
||||
}
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
|
||||
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
||||
@@ -1284,6 +1328,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx);
|
||||
}
|
||||
|
||||
idx = gguf_find_key(ctx, KEY_HAS_GLM_PROJ);
|
||||
if (idx != -1) {
|
||||
new_clip->has_glm_projector = gguf_get_val_bool(ctx, idx);
|
||||
}
|
||||
|
||||
idx = gguf_find_key(ctx, KEY_HAS_QWEN2VL_MERGER);
|
||||
if (idx != -1) {
|
||||
new_clip->has_qwen2vl_merger = gguf_get_val_bool(ctx, idx);
|
||||
@@ -1308,6 +1357,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
|
||||
LOG_INF("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector);
|
||||
LOG_INF("%s: minicpmv_projector: %d\n", __func__, new_clip->has_minicpmv_projector);
|
||||
LOG_INF("%s: glm_projector: %d\n", __func__, new_clip->has_glm_projector);
|
||||
LOG_INF("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0);
|
||||
LOG_INF("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
|
||||
}
|
||||
@@ -1575,6 +1625,18 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight"));
|
||||
vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias"));
|
||||
}
|
||||
else if (new_clip->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
|
||||
vision_model.mm_model_adapter_conv_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPER_CONV, "weight"));
|
||||
vision_model.mm_model_adapter_conv_b = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPER_CONV, "bias"));
|
||||
vision_model.mm_model_mlp_0_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_LINEAR,"weight"));
|
||||
vision_model.mm_model_ln_q_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_NORM_1,"weight"));
|
||||
vision_model.mm_model_ln_q_b = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_NORM_1,"bias"));
|
||||
vision_model.mm_model_mlp_1_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_D_H_2_4H,"weight"));
|
||||
vision_model.mm_model_mlp_2_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_GATE,"weight"));
|
||||
vision_model.mm_model_mlp_3_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_D_4H_2_H,"weight"));
|
||||
vision_model.boi_w = get_tensor(new_clip->ctx_data, TN_GLM_BOI_W);
|
||||
vision_model.eoi_w = get_tensor(new_clip->ctx_data, TN_GLM_EOI_W);
|
||||
}
|
||||
else if (new_clip->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight"));
|
||||
vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias"));
|
||||
@@ -2115,6 +2177,20 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ctx->has_glm_projector) {
|
||||
res_imgs->size = 1;
|
||||
res_imgs->data = new clip_image_f32[res_imgs->size];
|
||||
clip_image_u8 resized_image;
|
||||
int32_t sz=ctx->vision_model.hparams.image_size;
|
||||
bicubic_resize(*img, resized_image,sz,sz);
|
||||
clip_image_f32 * res = clip_image_f32_init();
|
||||
//clip_image_save_to_bmp(resized_image, "resized.bmp");
|
||||
normalize_image_u8_to_f32(&resized_image, res, ctx->image_mean, ctx->image_std);
|
||||
res_imgs->data[0] = *res;
|
||||
clip_image_f32_free(res);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool pad_to_square = true;
|
||||
if (!ctx->has_vision_encoder) {
|
||||
LOG_ERR("This gguf file seems to have no vision encoder\n");
|
||||
@@ -2300,7 +2376,8 @@ void clip_free(clip_ctx * ctx) {
|
||||
}
|
||||
|
||||
size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
|
||||
return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
|
||||
int extra_tokens = ctx->has_glm_projector ? 2 : 0;
|
||||
return (clip_n_patches(ctx) + extra_tokens) * clip_n_mmproj_embd(ctx) * sizeof(float);
|
||||
}
|
||||
|
||||
size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) {
|
||||
@@ -2342,7 +2419,7 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
|
||||
|
||||
int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
|
||||
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) {
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
|
||||
n_patches /= 4;
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
|
||||
if (ctx->minicpmv_version == 2) {
|
||||
@@ -2475,6 +2552,12 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
if (ctx->has_minicpmv_projector) {
|
||||
GGML_ASSERT(batch_size == 1);
|
||||
}
|
||||
if (ctx->has_glm_projector) {
|
||||
GGML_ASSERT(batch_size == 1);
|
||||
ggml_tensor * boi = ctx->vision_model.boi_w;
|
||||
ggml_backend_tensor_get(boi,vec,0,ggml_nbytes(boi));
|
||||
vec = (float*)(vec+ggml_nelements(boi)); //offset for boi
|
||||
}
|
||||
|
||||
// build the inference graph
|
||||
ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true);
|
||||
@@ -2627,7 +2710,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
|
||||
free(positions_data);
|
||||
|
||||
{
|
||||
if (!ctx->has_glm_projector) {
|
||||
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
|
||||
int* patches_data = (int*)malloc(ggml_nbytes(patches));
|
||||
for (int i = 0; i < num_patches; i++) {
|
||||
@@ -2651,14 +2734,19 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
// copy the embeddings to the location passed by the user
|
||||
ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
|
||||
|
||||
if (ctx->has_glm_projector) {
|
||||
//eoi
|
||||
ggml_tensor * eoi = ctx->vision_model.eoi_w;
|
||||
int offset = ggml_nelements(embeddings);
|
||||
ggml_backend_tensor_get(eoi, vec+offset, 0, ggml_nbytes(eoi));
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) {
|
||||
ggml_type type = GGML_TYPE_Q4_1;
|
||||
|
||||
assert(itype < GGML_TYPE_COUNT);
|
||||
type = static_cast<ggml_type>(itype);
|
||||
ggml_type type = static_cast<ggml_type>(itype);
|
||||
|
||||
auto * ctx_clip = clip_model_load(fname_inp, 2);
|
||||
|
||||
@@ -2711,8 +2799,8 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
|
||||
}
|
||||
}
|
||||
|
||||
// quantize only 2D tensors
|
||||
quantize &= (ggml_n_dims(cur) == 2);
|
||||
// quantize only 2D tensors and bigger than block size
|
||||
quantize &= (ggml_n_dims(cur) == 2) && cur->ne[0] > ggml_blck_size(type);
|
||||
|
||||
if (quantize) {
|
||||
new_type = type;
|
||||
@@ -2812,6 +2900,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
return 3584;
|
||||
}
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE){
|
||||
return ctx->vision_model.mm_model_mlp_3_w->ne[1];
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
return ctx->vision_model.mm_1_b->ne[0];
|
||||
}
|
||||
@@ -2827,6 +2918,9 @@ int clip_is_minicpmv(const struct clip_ctx * ctx) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool clip_is_glm(const struct clip_ctx * ctx) {
|
||||
return ctx->has_glm_projector;
|
||||
}
|
||||
bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
|
||||
return ctx->has_qwen2vl_merger;
|
||||
}
|
||||
|
||||
@@ -93,6 +93,8 @@ CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
|
||||
|
||||
CLIP_API bool clip_is_glm(const struct clip_ctx * ctx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
280
examples/llava/glmedge-convert-image-encoder-to-gguf.py
Normal file
280
examples/llava/glmedge-convert-image-encoder-to-gguf.py
Normal file
@@ -0,0 +1,280 @@
|
||||
import argparse
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from gguf import *
|
||||
|
||||
TEXT = "clip.text"
|
||||
VISION = "clip.vision"
|
||||
from transformers import SiglipVisionModel, SiglipVisionConfig
|
||||
|
||||
def k(raw_key: str, arch: str) -> str:
|
||||
return raw_key.format(arch=arch)
|
||||
|
||||
|
||||
def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: bool) -> bool:
|
||||
if name in (
|
||||
"logit_scale",
|
||||
"text_model.embeddings.position_ids",
|
||||
"vision_model.embeddings.position_ids",
|
||||
):
|
||||
return True
|
||||
|
||||
if name in (
|
||||
"vision_model.head.probe",
|
||||
"vision_model.head.attention.in_proj_weight",
|
||||
"vision_model.head.attention.in_proj_bias",
|
||||
"vision_model.head.attention.out_proj.weight",
|
||||
"vision_model.head.attention.out_proj.bias",
|
||||
"vision_model.head.layernorm.weight",
|
||||
"vision_model.head.layernorm.bias",
|
||||
"vision_model.head.mlp.fc1.weight",
|
||||
"vision_model.head.mlp.fc1.bias",
|
||||
"vision_model.head.mlp.fc2.weight",
|
||||
"vision_model.head.mlp.fc2.bias"
|
||||
):
|
||||
return True
|
||||
|
||||
if name.startswith("v") and not has_vision:
|
||||
return True
|
||||
|
||||
if name.startswith("t") and not has_text:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_tensor_name(name: str) -> str:
|
||||
if "projection" in name:
|
||||
return name
|
||||
if "mm_projector" in name:
|
||||
name = name.replace("model.mm_projector", "mm")
|
||||
name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1)
|
||||
name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1)
|
||||
return name
|
||||
|
||||
return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
|
||||
|
||||
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a significant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
bs = (
|
||||
list(range(ord("!"), ord("~") + 1))
|
||||
+ list(range(ord("¡"), ord("¬") + 1))
|
||||
+ list(range(ord("®"), ord("ÿ") + 1))
|
||||
)
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8 + n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
|
||||
ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
|
||||
ap.add_argument("--text-only", action="store_true", required=False,
|
||||
help="Save a text-only model. It can't be used to encode images")
|
||||
ap.add_argument("--vision-only", action="store_true", required=False,
|
||||
help="Save a vision-only model. It can't be used to encode texts")
|
||||
ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
|
||||
help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
|
||||
ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
|
||||
help="The clip model is from openclip (for ViT-SO400M type))")
|
||||
ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
|
||||
ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2","adapter"], default="adapter")
|
||||
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
|
||||
# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
|
||||
# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5
|
||||
default_image_mean = [0.5, 0.5, 0.5]
|
||||
default_image_std = [0.5, 0.5, 0.5]
|
||||
ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
|
||||
ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
|
||||
|
||||
# with proper
|
||||
args = ap.parse_args()
|
||||
|
||||
|
||||
if args.text_only and args.vision_only:
|
||||
print("--text-only and --image-only arguments cannot be specified at the same time.")
|
||||
exit(1)
|
||||
|
||||
if args.use_f32:
|
||||
print("WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.")
|
||||
|
||||
# output in the same directory as the model if output_dir is None
|
||||
dir_model = args.model_dir
|
||||
|
||||
if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip:
|
||||
vocab = None
|
||||
tokens = None
|
||||
else:
|
||||
with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f:
|
||||
vocab = json.load(f)
|
||||
tokens = [key for key in vocab]
|
||||
|
||||
with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
if args.clip_model_is_vision:
|
||||
v_hparams = config
|
||||
t_hparams = None
|
||||
else:
|
||||
v_hparams = config["vision_config"]
|
||||
t_hparams = None
|
||||
|
||||
# possible data types
|
||||
# ftype == 0 -> float32
|
||||
# ftype == 1 -> float16
|
||||
#
|
||||
# map from ftype to string
|
||||
ftype_str = ["f32", "f16"]
|
||||
|
||||
ftype = 1
|
||||
if args.use_f32:
|
||||
ftype = 0
|
||||
|
||||
vision_config = SiglipVisionConfig(**v_hparams)
|
||||
model = SiglipVisionModel(vision_config)
|
||||
model.load_state_dict(torch.load(os.path.join(dir_model, "glm.clip")))
|
||||
|
||||
fname_middle = None
|
||||
has_text_encoder = False
|
||||
has_vision_encoder = True
|
||||
has_glm_projector = True
|
||||
if args.text_only:
|
||||
fname_middle = "text-"
|
||||
has_vision_encoder = False
|
||||
elif args.llava_projector is not None:
|
||||
fname_middle = "mmproj-"
|
||||
has_text_encoder = False
|
||||
has_glm_projector = True
|
||||
elif args.vision_only:
|
||||
fname_middle = "vision-"
|
||||
has_text_encoder = False
|
||||
else:
|
||||
fname_middle = ""
|
||||
|
||||
output_dir = args.output_dir if args.output_dir is not None else dir_model
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_prefix = os.path.basename(output_dir).replace("ggml_", "")
|
||||
fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf")
|
||||
fout = GGUFWriter(path=fname_out, arch="clip")
|
||||
|
||||
fout.add_bool("clip.has_text_encoder", has_text_encoder)
|
||||
fout.add_bool("clip.has_vision_encoder", has_vision_encoder)
|
||||
fout.add_bool("clip.has_glm_projector", has_glm_projector)
|
||||
fout.add_file_type(ftype)
|
||||
model_name = config["_name_or_path"] if "_name_or_path" in config else os.path.basename(dir_model)
|
||||
fout.add_name(model_name)
|
||||
if has_glm_projector:
|
||||
fout.add_description("image encoder for glm4v")
|
||||
fout.add_string("clip.projector_type", "adapter")
|
||||
else:
|
||||
fout.add_description("two-tower CLIP model")
|
||||
|
||||
if has_text_encoder:
|
||||
assert t_hparams is not None
|
||||
assert tokens is not None
|
||||
# text_model hparams
|
||||
fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"])
|
||||
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"])
|
||||
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, TEXT), t_hparams["intermediate_size"])
|
||||
fout.add_uint32("clip.text.projection_dim", t_hparams.get("projection_dim", config["projection_dim"]))
|
||||
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, TEXT), t_hparams["num_attention_heads"])
|
||||
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, TEXT), t_hparams["layer_norm_eps"])
|
||||
fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"])
|
||||
fout.add_token_list(tokens)
|
||||
|
||||
if has_vision_encoder:
|
||||
# vision_model hparams
|
||||
fout.add_uint32("clip.vision.image_size", v_hparams["image_size"])
|
||||
fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"])
|
||||
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"])
|
||||
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"])
|
||||
fout.add_uint32("clip.vision.projection_dim", 0)
|
||||
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"])
|
||||
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
|
||||
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), v_hparams["num_hidden_layers"])
|
||||
|
||||
image_mean = args.image_mean if args.image_mean is not None else default_image_mean
|
||||
image_std = args.image_std if args.image_std is not None else default_image_std
|
||||
fout.add_array("clip.vision.image_mean", image_mean)
|
||||
fout.add_array("clip.vision.image_std", image_std)
|
||||
|
||||
fout.add_bool("clip.use_gelu", True)
|
||||
|
||||
|
||||
if has_glm_projector:
|
||||
# model.vision_model.encoder.layers.pop(-1) # pyright: ignore[reportAttributeAccessIssue]
|
||||
projector = torch.load(args.llava_projector)
|
||||
for name, data in projector.items():
|
||||
name = get_tensor_name(name)
|
||||
# pw and dw conv ndim==4
|
||||
if data.ndim == 2 or data.ndim == 4:
|
||||
data = data.squeeze().numpy().astype(np.float16)
|
||||
else:
|
||||
data = data.squeeze().numpy().astype(np.float32)
|
||||
if name.startswith("vision."):
|
||||
name=name.replace("vision.","")
|
||||
fout.add_tensor(name, data)
|
||||
print(f"Projector {name} - {data.dtype} - shape = {data.shape}")
|
||||
# print(f"Projector {name} tensors added\n")
|
||||
|
||||
state_dict = model.state_dict() # pyright: ignore[reportAttributeAccessIssue]
|
||||
for name, data in state_dict.items():
|
||||
if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_glm_projector):
|
||||
# we don't need this
|
||||
print(f"skipping parameter: {name}")
|
||||
continue
|
||||
|
||||
name = get_tensor_name(name)
|
||||
data = data.squeeze().numpy()
|
||||
|
||||
n_dims = len(data.shape)
|
||||
|
||||
# ftype == 0 -> float32, ftype == 1 -> float16
|
||||
ftype_cur = 0
|
||||
if n_dims == 4:
|
||||
print(f"tensor {name} is always saved in f16")
|
||||
data = data.astype(np.float16)
|
||||
ftype_cur = 1
|
||||
elif ftype == 1:
|
||||
if name[-7:] == ".weight" and n_dims == 2:
|
||||
# print(" Converting to float16")
|
||||
data = data.astype(np.float16)
|
||||
ftype_cur = 1
|
||||
else:
|
||||
# print(" Converting to float32")
|
||||
data = data.astype(np.float32)
|
||||
ftype_cur = 0
|
||||
else:
|
||||
if data.dtype != np.float32:
|
||||
# print(" Converting to float32")
|
||||
data = data.astype(np.float32)
|
||||
ftype_cur = 0
|
||||
print(f"siglip {name} - {data.dtype} - shape = {data.shape}")
|
||||
# print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
|
||||
fout.add_tensor(name, data)
|
||||
|
||||
|
||||
fout.write_header_to_file()
|
||||
fout.write_kv_data_to_file()
|
||||
fout.write_tensors_to_file()
|
||||
fout.close()
|
||||
|
||||
print("Done. Output file: " + fname_out)
|
||||
33
examples/llava/glmedge-surgery.py
Normal file
33
examples/llava/glmedge-surgery.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
from transformers import AutoModel
|
||||
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("-m", "--model", help="Path to GLM model")
|
||||
args = ap.parse_args()
|
||||
|
||||
# find the model part that includes the the multimodal projector weights
|
||||
model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True)
|
||||
checkpoint = model.state_dict()
|
||||
|
||||
# get a list of mm tensor names
|
||||
mm_tensors = [k for k, v in checkpoint.items() if k.startswith("vision.adapter.")]
|
||||
|
||||
# store these tensors in a new dictionary and torch.save them
|
||||
projector = {name: checkpoint[name].float() for name in mm_tensors}
|
||||
torch.save(projector, f"{args.model}/glm.projector")
|
||||
|
||||
clip_tensors = [k for k, v in checkpoint.items() if k.startswith("vision.vit.model.vision_model.")]
|
||||
if len(clip_tensors) > 0:
|
||||
clip = {name.replace("vision.vit.model.", ""): checkpoint[name].float() for name in clip_tensors}
|
||||
torch.save(clip, f"{args.model}/glm.clip")
|
||||
|
||||
# added tokens should be removed to be able to convert Mistral models
|
||||
if os.path.exists(f"{args.model}/added_tokens.json"):
|
||||
with open(f"{args.model}/added_tokens.json", "w") as f:
|
||||
f.write("{}\n")
|
||||
|
||||
print("Done!")
|
||||
print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
|
||||
print(f"Also, use {args.model}glm.projector to prepare a glm-encoder.gguf file.")
|
||||
@@ -311,6 +311,20 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
||||
img_res_v.size = 0;
|
||||
img_res_v.data = nullptr;
|
||||
}
|
||||
else if (clip_is_glm(ctx_clip)){
|
||||
struct clip_image_size * load_image_size = clip_image_size_init();
|
||||
load_image_size->width = img_res_v.data[0].nx;
|
||||
load_image_size->height = img_res_v.data[0].ny;
|
||||
clip_add_load_image_size(ctx_clip, load_image_size);
|
||||
|
||||
bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd);
|
||||
int pos = int(load_image_size->width/clip_patch_size(ctx_clip)/2);
|
||||
*n_img_pos = (pos * pos + 2);
|
||||
if (!encoded){
|
||||
LOG_ERR("Unable to encode image \n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
|
||||
// flat / default llava-1.5 type embedding
|
||||
*n_img_pos = clip_n_patches(ctx_clip);
|
||||
@@ -395,6 +409,9 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
|
||||
if (clip_is_minicpmv(ctx_clip)) {
|
||||
num_max_patches = 10;
|
||||
}
|
||||
if (clip_is_glm(ctx_clip)) {
|
||||
num_max_patches = 1;
|
||||
}
|
||||
float * image_embd;
|
||||
if (clip_is_qwen2vl(ctx_clip)) {
|
||||
// qwen2vl don't split image into chunks, so `num_max_patches` is not needed.
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.12)
|
||||
project("llama-cli-cmake-pkg" C CXX)
|
||||
set(TARGET llama-cli-cmake-pkg)
|
||||
|
||||
find_package(Llama 0.0.1 REQUIRED)
|
||||
|
||||
# Bake common functionality in with target. Because applications
|
||||
# using the relocatable Llama package should be outside of the
|
||||
# source tree, llama-cli-cmake-pkg pretends the dependencies are built-in.
|
||||
set(_common_path "${CMAKE_CURRENT_LIST_DIR}/../../common")
|
||||
add_library(common OBJECT)
|
||||
file(GLOB _common_files
|
||||
"${_common_path}/*.h"
|
||||
"${_common_path}/*.cpp"
|
||||
)
|
||||
target_sources(common PRIVATE ${_common_files})
|
||||
|
||||
# If the common project was part of "llama-cli-cmake-pkg" the transient
|
||||
# defines would automatically be attached. Because the common func-
|
||||
# tionality is separate, but dependent upon the defines, it must be
|
||||
# explicitly extracted from the "llama" target.
|
||||
#
|
||||
get_target_property(_llama_transient_defines llama
|
||||
INTERFACE_COMPILE_DEFINITIONS)
|
||||
|
||||
target_compile_definitions(common PRIVATE "${_llama_transient_defines}")
|
||||
|
||||
add_executable(${TARGET} ${CMAKE_CURRENT_LIST_DIR}/../main/main.cpp)
|
||||
target_include_directories(${TARGET} PRIVATE ${_common_path})
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
@@ -1,31 +0,0 @@
|
||||
# llama.cpp/example/main-cmake-pkg
|
||||
|
||||
This program builds [llama-cli](../main) using a relocatable CMake package. It serves as an example of using the `find_package()` CMake command to conveniently include [llama.cpp](https://github.com/ggerganov/llama.cpp) in projects which live outside of the source tree.
|
||||
|
||||
## Building
|
||||
|
||||
Because this example is "outside of the source tree", it is important to first build/install llama.cpp using CMake. An example is provided here, but please see the [llama.cpp build instructions](../..) for more detailed build instructions.
|
||||
|
||||
### Considerations
|
||||
|
||||
When hardware acceleration libraries are used (e.g. CUDA, Metal, etc.), CMake must be able to locate the associated CMake package.
|
||||
|
||||
### Build llama.cpp and install to C:\LlamaCPP directory
|
||||
|
||||
```cmd
|
||||
git clone https://github.com/ggerganov/llama.cpp
|
||||
cd llama.cpp
|
||||
cmake -B build -DBUILD_SHARED_LIBS=OFF -G "Visual Studio 17 2022" -A x64
|
||||
cmake --build build --config Release
|
||||
cmake --install build --prefix C:/LlamaCPP
|
||||
```
|
||||
|
||||
### Build llama-cli-cmake-pkg
|
||||
|
||||
|
||||
```cmd
|
||||
cd ..\examples\main-cmake-pkg
|
||||
cmake -B build -DBUILD_SHARED_LIBS=OFF -DCMAKE_PREFIX_PATH="C:/LlamaCPP/lib/cmake/Llama" -G "Visual Studio 17 2022" -A x64
|
||||
cmake --build build --config Release
|
||||
cmake --install build --prefix C:/MyLlamaApp
|
||||
```
|
||||
@@ -310,9 +310,9 @@ These options help improve the performance and memory usage of the LLaMA models.
|
||||
|
||||
### Batch Size
|
||||
|
||||
- `-b N, --batch-size N`: Set the batch size for prompt processing (default: `2048`). This large batch size benefits users who have BLAS installed and enabled it during the build. If you don't have BLAS enabled ("BLAS=0"), you can use a smaller number, such as 8, to see the prompt progress as it's evaluated in some situations.
|
||||
- `-ub N`, `--ubatch-size N`: Physical batch size. This is the maximum number of tokens that may be processed at a time. Increasing this value may improve performance during prompt processing, at the expense of higher memory usage. Default: `512`.
|
||||
|
||||
- `-ub N`, `--ubatch-size N`: physical maximum batch size. This is for pipeline parallelization. Default: `512`.
|
||||
- `-b N`, `--batch-size N`: Logical batch size. Increasing this value above the value of the physical batch size may improve prompt processing performance when using multiple GPUs with pipeline parallelism. Default: `2048`.
|
||||
|
||||
### Prompt Caching
|
||||
|
||||
|
||||
@@ -254,7 +254,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
const bool add_bos = llama_vocab_get_add_bos(vocab);
|
||||
const bool add_bos = llama_vocab_get_add_bos(vocab) && !params.use_jinja;
|
||||
if (!llama_model_has_encoder(model)) {
|
||||
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
|
||||
}
|
||||
@@ -264,9 +264,9 @@ int main(int argc, char ** argv) {
|
||||
std::vector<llama_token> embd_inp;
|
||||
|
||||
auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
|
||||
common_chat_msg new_msg{role, content};
|
||||
common_chat_msg new_msg{role, content, {}};
|
||||
auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja);
|
||||
chat_msgs.push_back({role, content});
|
||||
chat_msgs.push_back({role, content, {}});
|
||||
LOG_DBG("formatted: '%s'\n", formatted.c_str());
|
||||
return formatted;
|
||||
};
|
||||
@@ -503,12 +503,14 @@ int main(int argc, char ** argv) {
|
||||
|
||||
std::vector<llama_token> embd;
|
||||
|
||||
// tokenized antiprompts
|
||||
std::vector<std::vector<llama_token>> antiprompt_ids;
|
||||
// single-token antiprompts
|
||||
std::vector<llama_token> antiprompt_token;
|
||||
|
||||
antiprompt_ids.reserve(params.antiprompt.size());
|
||||
for (const std::string & antiprompt : params.antiprompt) {
|
||||
antiprompt_ids.emplace_back(::common_tokenize(ctx, antiprompt, false, true));
|
||||
auto ids = ::common_tokenize(ctx, antiprompt, false, true);
|
||||
if (ids.size() == 1) {
|
||||
antiprompt_token.push_back(ids[0]);
|
||||
}
|
||||
}
|
||||
|
||||
if (llama_model_has_encoder(model)) {
|
||||
@@ -753,14 +755,11 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// check for reverse prompt using special tokens
|
||||
llama_token last_token = common_sampler_last(smpl);
|
||||
for (std::vector<llama_token> ids : antiprompt_ids) {
|
||||
if (ids.size() == 1 && last_token == ids[0]) {
|
||||
if (params.interactive) {
|
||||
is_interacting = true;
|
||||
}
|
||||
is_antiprompt = true;
|
||||
break;
|
||||
if (std::find(antiprompt_token.begin(), antiprompt_token.end(), last_token) != antiprompt_token.end()) {
|
||||
if (params.interactive) {
|
||||
is_interacting = true;
|
||||
}
|
||||
is_antiprompt = true;
|
||||
}
|
||||
|
||||
if (is_antiprompt) {
|
||||
|
||||
@@ -3,11 +3,10 @@
|
||||
The purpose of this example is to demonstrate a minimal usage of llama.cpp for running models.
|
||||
|
||||
```bash
|
||||
llama-run granite-code
|
||||
llama-run granite3-moe
|
||||
```
|
||||
|
||||
```bash
|
||||
llama-run -h
|
||||
Description:
|
||||
Runs a llm
|
||||
|
||||
@@ -17,7 +16,7 @@ Usage:
|
||||
Options:
|
||||
-c, --context-size <value>
|
||||
Context size (default: 2048)
|
||||
-n, --ngl <value>
|
||||
-n, -ngl, --ngl <value>
|
||||
Number of GPU layers (default: 0)
|
||||
--temp <value>
|
||||
Temperature (default: 0.8)
|
||||
|
||||
@@ -24,15 +24,16 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "chat-template.hpp"
|
||||
#include "common.h"
|
||||
#include "json.hpp"
|
||||
#include "linenoise.cpp/linenoise.h"
|
||||
#include "llama-cpp.h"
|
||||
#include "chat-template.hpp"
|
||||
#include "log.h"
|
||||
|
||||
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32)
|
||||
[[noreturn]] static void sigint_handler(int) {
|
||||
printf("\n\033[0m");
|
||||
printf("\n" LOG_COL_DEFAULT);
|
||||
exit(0); // not ideal, but it's the only way to guarantee exit in all cases
|
||||
}
|
||||
#endif
|
||||
@@ -65,6 +66,13 @@ static int printe(const char * fmt, ...) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
static std::string strftime_fmt(const char * fmt, const std::tm & tm) {
|
||||
std::ostringstream oss;
|
||||
oss << std::put_time(&tm, fmt);
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
class Opt {
|
||||
public:
|
||||
int init(int argc, const char ** argv) {
|
||||
@@ -147,7 +155,8 @@ class Opt {
|
||||
if (handle_option_with_value(argc, argv, i, context_size) == 1) {
|
||||
return 1;
|
||||
}
|
||||
} else if (options_parsing && (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0)) {
|
||||
} else if (options_parsing &&
|
||||
(strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "-ngl") == 0 || strcmp(argv[i], "--ngl") == 0)) {
|
||||
if (handle_option_with_value(argc, argv, i, ngl) == 1) {
|
||||
return 1;
|
||||
}
|
||||
@@ -180,6 +189,10 @@ class Opt {
|
||||
}
|
||||
}
|
||||
|
||||
if (model_.empty()){
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -194,7 +207,7 @@ class Opt {
|
||||
"Options:\n"
|
||||
" -c, --context-size <value>\n"
|
||||
" Context size (default: %d)\n"
|
||||
" -n, --ngl <value>\n"
|
||||
" -n, -ngl, --ngl <value>\n"
|
||||
" Number of GPU layers (default: %d)\n"
|
||||
" --temp <value>\n"
|
||||
" Temperature (default: %.1f)\n"
|
||||
@@ -318,6 +331,10 @@ class HttpClient {
|
||||
public:
|
||||
int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
|
||||
const bool progress, std::string * response_str = nullptr) {
|
||||
if (std::filesystem::exists(output_file)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::string output_file_partial;
|
||||
curl = curl_easy_init();
|
||||
if (!curl) {
|
||||
@@ -345,7 +362,11 @@ class HttpClient {
|
||||
data.file_size = set_resume_point(output_file_partial);
|
||||
set_progress_options(progress, data);
|
||||
set_headers(headers);
|
||||
perform(url);
|
||||
CURLcode res = perform(url);
|
||||
if (res != CURLE_OK){
|
||||
printe("Fetching resource '%s' failed: %s\n", url.c_str(), curl_easy_strerror(res));
|
||||
return 1;
|
||||
}
|
||||
if (!output_file.empty()) {
|
||||
std::filesystem::rename(output_file_partial, output_file);
|
||||
}
|
||||
@@ -410,16 +431,12 @@ class HttpClient {
|
||||
}
|
||||
}
|
||||
|
||||
void perform(const std::string & url) {
|
||||
CURLcode res;
|
||||
CURLcode perform(const std::string & url) {
|
||||
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
|
||||
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
|
||||
curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https");
|
||||
curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L);
|
||||
res = curl_easy_perform(curl);
|
||||
if (res != CURLE_OK) {
|
||||
printe("curl_easy_perform() failed: %s\n", curl_easy_strerror(res));
|
||||
}
|
||||
return curl_easy_perform(curl);
|
||||
}
|
||||
|
||||
static std::string human_readable_time(double seconds) {
|
||||
@@ -557,13 +574,14 @@ class LlamaData {
|
||||
}
|
||||
|
||||
sampler = initialize_sampler(opt);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
#ifdef LLAMA_USE_CURL
|
||||
int download(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
|
||||
const bool progress, std::string * response_str = nullptr) {
|
||||
int download(const std::string & url, const std::string & output_file, const bool progress,
|
||||
const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
|
||||
HttpClient http;
|
||||
if (http.init(url, headers, output_file, progress, response_str)) {
|
||||
return 1;
|
||||
@@ -572,48 +590,85 @@ class LlamaData {
|
||||
return 0;
|
||||
}
|
||||
#else
|
||||
int download(const std::string &, const std::vector<std::string> &, const std::string &, const bool,
|
||||
int download(const std::string &, const std::string &, const bool, const std::vector<std::string> & = {},
|
||||
std::string * = nullptr) {
|
||||
printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
|
||||
|
||||
return 1;
|
||||
}
|
||||
#endif
|
||||
|
||||
int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
|
||||
// Find the second occurrence of '/' after protocol string
|
||||
size_t pos = model.find('/');
|
||||
pos = model.find('/', pos + 1);
|
||||
if (pos == std::string::npos) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
const std::string hfr = model.substr(0, pos);
|
||||
const std::string hff = model.substr(pos + 1);
|
||||
const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
|
||||
return download(url, headers, bn, true);
|
||||
}
|
||||
|
||||
int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
|
||||
if (model.find('/') == std::string::npos) {
|
||||
model = "library/" + model;
|
||||
}
|
||||
|
||||
std::string model_tag = "latest";
|
||||
size_t colon_pos = model.find(':');
|
||||
// Helper function to handle model tag extraction and URL construction
|
||||
std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
|
||||
std::string model_tag = "latest";
|
||||
const size_t colon_pos = model.find(':');
|
||||
if (colon_pos != std::string::npos) {
|
||||
model_tag = model.substr(colon_pos + 1);
|
||||
model = model.substr(0, colon_pos);
|
||||
}
|
||||
|
||||
std::string manifest_url = "https://registry.ollama.ai/v2/" + model + "/manifests/" + model_tag;
|
||||
std::string url = base_url + model + "/manifests/" + model_tag;
|
||||
|
||||
return { model, url };
|
||||
}
|
||||
|
||||
// Helper function to download and parse the manifest
|
||||
int download_and_parse_manifest(const std::string & url, const std::vector<std::string> & headers,
|
||||
nlohmann::json & manifest) {
|
||||
std::string manifest_str;
|
||||
const int ret = download(manifest_url, headers, "", false, &manifest_str);
|
||||
int ret = download(url, "", false, headers, &manifest_str);
|
||||
if (ret) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
nlohmann::json manifest = nlohmann::json::parse(manifest_str);
|
||||
std::string layer;
|
||||
manifest = nlohmann::json::parse(manifest_str);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int huggingface_dl(std::string & model, const std::string & bn) {
|
||||
// Find the second occurrence of '/' after protocol string
|
||||
size_t pos = model.find('/');
|
||||
pos = model.find('/', pos + 1);
|
||||
std::string hfr, hff;
|
||||
std::vector<std::string> headers = { "User-Agent: llama-cpp", "Accept: application/json" };
|
||||
std::string url;
|
||||
|
||||
if (pos == std::string::npos) {
|
||||
auto [model_name, manifest_url] = extract_model_and_tag(model, "https://huggingface.co/v2/");
|
||||
hfr = model_name;
|
||||
|
||||
nlohmann::json manifest;
|
||||
int ret = download_and_parse_manifest(manifest_url, headers, manifest);
|
||||
if (ret) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
hff = manifest["ggufFile"]["rfilename"];
|
||||
} else {
|
||||
hfr = model.substr(0, pos);
|
||||
hff = model.substr(pos + 1);
|
||||
}
|
||||
|
||||
url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
|
||||
|
||||
return download(url, bn, true, headers);
|
||||
}
|
||||
|
||||
int ollama_dl(std::string & model, const std::string & bn) {
|
||||
const std::vector<std::string> headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" };
|
||||
if (model.find('/') == std::string::npos) {
|
||||
model = "library/" + model;
|
||||
}
|
||||
|
||||
auto [model_name, manifest_url] = extract_model_and_tag(model, "https://registry.ollama.ai/v2/");
|
||||
nlohmann::json manifest;
|
||||
int ret = download_and_parse_manifest(manifest_url, {}, manifest);
|
||||
if (ret) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::string layer;
|
||||
for (const auto & l : manifest["layers"]) {
|
||||
if (l["mediaType"] == "application/vnd.ollama.image.model") {
|
||||
layer = l["digest"];
|
||||
@@ -621,8 +676,67 @@ class LlamaData {
|
||||
}
|
||||
}
|
||||
|
||||
std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
|
||||
return download(blob_url, headers, bn, true);
|
||||
std::string blob_url = "https://registry.ollama.ai/v2/" + model_name + "/blobs/" + layer;
|
||||
|
||||
return download(blob_url, bn, true, headers);
|
||||
}
|
||||
|
||||
int github_dl(const std::string & model, const std::string & bn) {
|
||||
std::string repository = model;
|
||||
std::string branch = "main";
|
||||
const size_t at_pos = model.find('@');
|
||||
if (at_pos != std::string::npos) {
|
||||
repository = model.substr(0, at_pos);
|
||||
branch = model.substr(at_pos + 1);
|
||||
}
|
||||
|
||||
const std::vector<std::string> repo_parts = string_split(repository, "/");
|
||||
if (repo_parts.size() < 3) {
|
||||
printe("Invalid GitHub repository format\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
const std::string & org = repo_parts[0];
|
||||
const std::string & project = repo_parts[1];
|
||||
std::string url = "https://raw.githubusercontent.com/" + org + "/" + project + "/" + branch;
|
||||
for (size_t i = 2; i < repo_parts.size(); ++i) {
|
||||
url += "/" + repo_parts[i];
|
||||
}
|
||||
|
||||
return download(url, bn, true);
|
||||
}
|
||||
|
||||
int s3_dl(const std::string & model, const std::string & bn) {
|
||||
const size_t slash_pos = model.find('/');
|
||||
if (slash_pos == std::string::npos) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
const std::string bucket = model.substr(0, slash_pos);
|
||||
const std::string key = model.substr(slash_pos + 1);
|
||||
const char * access_key = std::getenv("AWS_ACCESS_KEY_ID");
|
||||
const char * secret_key = std::getenv("AWS_SECRET_ACCESS_KEY");
|
||||
if (!access_key || !secret_key) {
|
||||
printe("AWS credentials not found in environment\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Generate AWS Signature Version 4 headers
|
||||
// (Implementation requires HMAC-SHA256 and date handling)
|
||||
// Get current timestamp
|
||||
const time_t now = time(nullptr);
|
||||
const tm tm = *gmtime(&now);
|
||||
const std::string date = strftime_fmt("%Y%m%d", tm);
|
||||
const std::string datetime = strftime_fmt("%Y%m%dT%H%M%SZ", tm);
|
||||
const std::vector<std::string> headers = {
|
||||
"Authorization: AWS4-HMAC-SHA256 Credential=" + std::string(access_key) + "/" + date +
|
||||
"/us-east-1/s3/aws4_request",
|
||||
"x-amz-content-sha256: UNSIGNED-PAYLOAD", "x-amz-date: " + datetime
|
||||
};
|
||||
|
||||
const std::string url = "https://" + bucket + ".s3.amazonaws.com/" + key;
|
||||
|
||||
return download(url, bn, true, headers);
|
||||
}
|
||||
|
||||
std::string basename(const std::string & path) {
|
||||
@@ -634,37 +748,44 @@ class LlamaData {
|
||||
return path.substr(pos + 1);
|
||||
}
|
||||
|
||||
int remove_proto(std::string & model_) {
|
||||
const std::string::size_type pos = model_.find("://");
|
||||
int rm_until_substring(std::string & model_, const std::string & substring) {
|
||||
const std::string::size_type pos = model_.find(substring);
|
||||
if (pos == std::string::npos) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
model_ = model_.substr(pos + 3); // Skip past "://"
|
||||
model_ = model_.substr(pos + substring.size()); // Skip past the substring
|
||||
return 0;
|
||||
}
|
||||
|
||||
int resolve_model(std::string & model_) {
|
||||
int ret = 0;
|
||||
if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
|
||||
remove_proto(model_);
|
||||
rm_until_substring(model_, "://");
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
const std::string bn = basename(model_);
|
||||
const std::vector<std::string> headers = { "--header",
|
||||
"Accept: application/vnd.docker.distribution.manifest.v2+json" };
|
||||
if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
|
||||
remove_proto(model_);
|
||||
ret = huggingface_dl(model_, headers, bn);
|
||||
} else if (string_starts_with(model_, "ollama://")) {
|
||||
remove_proto(model_);
|
||||
ret = ollama_dl(model_, headers, bn);
|
||||
} else if (string_starts_with(model_, "https://")) {
|
||||
download(model_, headers, bn, true);
|
||||
} else {
|
||||
ret = ollama_dl(model_, headers, bn);
|
||||
const std::string bn = basename(model_);
|
||||
if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://") ||
|
||||
string_starts_with(model_, "hf.co/")) {
|
||||
rm_until_substring(model_, "hf.co/");
|
||||
rm_until_substring(model_, "://");
|
||||
ret = huggingface_dl(model_, bn);
|
||||
} else if ((string_starts_with(model_, "https://") || string_starts_with(model_, "http://")) &&
|
||||
!string_starts_with(model_, "https://ollama.com/library/")) {
|
||||
ret = download(model_, bn, true);
|
||||
} else if (string_starts_with(model_, "github:") || string_starts_with(model_, "github://")) {
|
||||
rm_until_substring(model_, "github:");
|
||||
rm_until_substring(model_, "://");
|
||||
ret = github_dl(model_, bn);
|
||||
} else if (string_starts_with(model_, "s3://")) {
|
||||
rm_until_substring(model_, "://");
|
||||
ret = s3_dl(model_, bn);
|
||||
} else { // ollama:// or nothing
|
||||
rm_until_substring(model_, "ollama.com/library/");
|
||||
rm_until_substring(model_, "://");
|
||||
ret = ollama_dl(model_, bn);
|
||||
}
|
||||
|
||||
model_ = bn;
|
||||
@@ -727,7 +848,15 @@ static int apply_chat_template(const common_chat_template & tmpl, LlamaData & ll
|
||||
});
|
||||
}
|
||||
try {
|
||||
auto result = tmpl.apply(messages, /* tools= */ json(), append);
|
||||
minja::chat_template_inputs tmpl_inputs;
|
||||
tmpl_inputs.messages = messages;
|
||||
tmpl_inputs.add_generation_prompt = append;
|
||||
|
||||
minja::chat_template_options tmpl_opts;
|
||||
tmpl_opts.use_bos_token = false;
|
||||
tmpl_opts.use_eos_token = false;
|
||||
|
||||
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
|
||||
llama_data.fmtted.resize(result.size() + 1);
|
||||
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
|
||||
return result.size();
|
||||
@@ -770,7 +899,7 @@ static int check_context_size(const llama_context_ptr & ctx, const llama_batch &
|
||||
const int n_ctx = llama_n_ctx(ctx.get());
|
||||
const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get());
|
||||
if (n_ctx_used + batch.n_tokens > n_ctx) {
|
||||
printf("\033[0m\n");
|
||||
printf(LOG_COL_DEFAULT "\n");
|
||||
printe("context size exceeded\n");
|
||||
return 1;
|
||||
}
|
||||
@@ -833,7 +962,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
|
||||
batch = llama_batch_get_one(&new_token_id, 1);
|
||||
}
|
||||
|
||||
printf("\033[0m");
|
||||
printf(LOG_COL_DEFAULT);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -842,7 +971,7 @@ static int read_user_input(std::string & user_input) {
|
||||
#ifdef WIN32
|
||||
printf(
|
||||
"\r%*s"
|
||||
"\r\033[0m%s",
|
||||
"\r" LOG_COL_DEFAULT "%s",
|
||||
get_terminal_width(), " ", prompt_prefix);
|
||||
|
||||
std::getline(std::cin, user_input);
|
||||
@@ -879,7 +1008,7 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
|
||||
const bool stdout_a_terminal) {
|
||||
// Set response color
|
||||
if (stdout_a_terminal) {
|
||||
printf("\033[33m");
|
||||
printf(LOG_COL_YELLOW);
|
||||
}
|
||||
|
||||
if (generate(llama_data, prompt, response)) {
|
||||
@@ -888,7 +1017,7 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
|
||||
}
|
||||
|
||||
// End response with color reset and newline
|
||||
printf("\n%s", stdout_a_terminal ? "\033[0m" : "");
|
||||
printf("\n%s", stdout_a_terminal ? LOG_COL_DEFAULT : "");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ The project is under active development, and we are [looking for feedback and co
|
||||
| `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') |
|
||||
| `--grammar-file FNAME` | file to read grammar from |
|
||||
| `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object<br/>For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead |
|
||||
| `--jinja` | Enable experimental Jinja templating engine (needed for tool use) |
|
||||
| `--jinja` | Enable experimental Jinja templating engine (required for tool use) |
|
||||
|
||||
**Example-specific params**
|
||||
|
||||
@@ -236,9 +236,13 @@ npm i
|
||||
# to run the dev server
|
||||
npm run dev
|
||||
|
||||
# to build the public/index.html
|
||||
# to build the public/index.html.gz
|
||||
npm run build
|
||||
```
|
||||
After `public/index.html.gz` has been generated we need to generate the c++
|
||||
headers (like build/examples/server/index.html.gz.hpp) that will be included
|
||||
by server.cpp. This is done by building `llama-server` as described in the
|
||||
[build](#build) section above.
|
||||
|
||||
NOTE: if you are using the vite dev server, you can change the API base URL to llama.cpp. To do that, run this code snippet in browser's console:
|
||||
|
||||
@@ -456,7 +460,7 @@ These words will not be included in the completion, so make sure to add them to
|
||||
- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
|
||||
|
||||
- `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has a nested array `top_logprobs`. It contains at **maximum** `n_probs` elements:
|
||||
```json
|
||||
```
|
||||
{
|
||||
"content": "<the generated completion text>",
|
||||
"tokens": [ generated token ids if requested ],
|
||||
@@ -557,7 +561,7 @@ If `with_pieces` is `true`:
|
||||
```
|
||||
|
||||
With input 'á' (utf8 hex: C3 A1) on tinyllama/stories260k
|
||||
```json
|
||||
```
|
||||
{
|
||||
"tokens": [
|
||||
{"id": 198, "piece": [195]}, // hex C3
|
||||
@@ -572,6 +576,18 @@ With input 'á' (utf8 hex: C3 A1) on tinyllama/stories260k
|
||||
|
||||
`tokens`: Set the tokens to detokenize.
|
||||
|
||||
### POST `/apply-template`: Apply chat template to a conversation
|
||||
|
||||
Uses the server's prompt template formatting functionality to convert chat messages to a single string expected by a chat model as input, but does not perform inference. Instead, the prompt string is returned in the `prompt` field of the JSON response. The prompt can then be modified as desired (for example, to insert "Sure!" at the beginning of the model's response) before sending to `/completion` to generate the chat response.
|
||||
|
||||
*Options:*
|
||||
|
||||
`messages`: (Required) Chat turns in the same format as `/v1/chat/completions`.
|
||||
|
||||
**Response format**
|
||||
|
||||
Returns a JSON object with a field `prompt` containing a string of the input messages formatted according to the model's chat template format.
|
||||
|
||||
### POST `/embedding`: Generate embedding of a given text
|
||||
|
||||
> [!IMPORTANT]
|
||||
@@ -764,7 +780,7 @@ Same as the `/v1/embeddings` endpoint.
|
||||
|
||||
**Response format**
|
||||
|
||||
```json
|
||||
```
|
||||
[
|
||||
{
|
||||
"index": 0,
|
||||
@@ -1053,7 +1069,7 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte
|
||||
|
||||
*Options:*
|
||||
|
||||
See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). While some OpenAI-specific features such as function calling aren't supported, llama.cpp `/completion`-specific features such as `mirostat` are supported.
|
||||
See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). llama.cpp `/completion`-specific features such as `mirostat` are also supported.
|
||||
|
||||
The `response_format` parameter supports both plain JSON output (e.g. `{"type": "json_object"}`) and schema-constrained JSON (e.g. `{"type": "json_object", "schema": {"type": "string", "minLength": 10, "maxLength": 100}}` or `{"type": "json_schema", "schema": {"properties": { "name": { "title": "Name", "type": "string" }, "date": { "title": "Date", "type": "string" }, "participants": { "items": {"type: "string" }, "title": "Participants", "type": "string" } } } }`), similar to other OpenAI-inspired API providers.
|
||||
|
||||
@@ -1101,6 +1117,184 @@ curl http://localhost:8080/v1/chat/completions \
|
||||
}'
|
||||
```
|
||||
|
||||
*Tool call support*
|
||||
|
||||
[Function calling](https://platform.openai.com/docs/guides/function-calling) is supported for all models (see https://github.com/ggerganov/llama.cpp/pull/9639):
|
||||
|
||||
- Requires `--jinja` flag
|
||||
- Native tool call formats supported:
|
||||
- Llama 3.1 / 3.3 (including builtin tools support - tool names for `wolfram_alpha`, `web_search` / `brave_search`, `code_interpreter`), Llama 3.2
|
||||
- Functionary v3.1 / v3.2
|
||||
- Hermes 2/3, Qwen 2.5
|
||||
- Mistral Nemo
|
||||
- Firefunction v2
|
||||
- Command R7B
|
||||
- DeepSeek R1 (WIP / seems reluctant to call any tools?)
|
||||
|
||||
<details>
|
||||
<summary>Show some common templates and which format handler they use</summary>
|
||||
|
||||
| Template | Format |
|
||||
|----------|--------|
|
||||
| CohereForAI-c4ai-command-r-plus-default.jinja | generic tool calls |
|
||||
| CohereForAI-c4ai-command-r-plus-rag.jinja | generic tool calls |
|
||||
| CohereForAI-c4ai-command-r-plus-tool_use.jinja | generic tool calls |
|
||||
| MiniMaxAI-MiniMax-Text-01.jinja | generic tool calls |
|
||||
| NexaAIDev-Octopus-v2.jinja | generic tool calls |
|
||||
| NousResearch-Hermes-2-Pro-Llama-3-8B-default.jinja | generic tool calls |
|
||||
| NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja | hermes 2 pro tool calls |
|
||||
| NousResearch-Hermes-2-Pro-Mistral-7B-default.jinja | generic tool calls |
|
||||
| NousResearch-Hermes-2-Pro-Mistral-7B-tool_use.jinja | hermes 2 pro tool calls |
|
||||
| NousResearch-Hermes-3-Llama-3.1-70B-default.jinja | generic tool calls |
|
||||
| NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja | hermes 2 pro tool calls |
|
||||
| OrionStarAI-Orion-14B-Chat.jinja | generic tool calls |
|
||||
| Qwen-QwQ-32B-Preview.jinja | hermes 2 pro tool calls |
|
||||
| Qwen-Qwen2-7B-Instruct.jinja | generic tool calls |
|
||||
| Qwen-Qwen2-VL-7B-Instruct.jinja | generic tool calls |
|
||||
| Qwen-Qwen2.5-7B-Instruct.jinja | hermes 2 pro tool calls |
|
||||
| Qwen-Qwen2.5-Math-7B-Instruct.jinja | hermes 2 pro tool calls |
|
||||
| TheBloke-FusionNet_34Bx2_MoE-AWQ.jinja | generic tool calls |
|
||||
| abacusai-Fewshot-Metamath-OrcaVicuna-Mistral.jinja | generic tool calls |
|
||||
| bofenghuang-vigogne-2-70b-chat.jinja | generic tool calls |
|
||||
| databricks-dbrx-instruct.jinja | generic tool calls |
|
||||
| deepseek-ai-DeepSeek-Coder-V2-Instruct.jinja | generic tool calls |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja | deepseek r1 tool calls |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja | deepseek r1 tool calls |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Qwen-7B.jinja | deepseek r1 tool calls |
|
||||
| deepseek-ai-DeepSeek-V2.5.jinja | deepseek r1 tool calls |
|
||||
| deepseek-ai-deepseek-coder-33b-instruct.jinja | generic tool calls |
|
||||
| google-gemma-2-2b-it.jinja | generic tool calls |
|
||||
| google-gemma-7b-it.jinja | generic tool calls |
|
||||
| indischepartij-MiniCPM-3B-OpenHermes-2.5-v2.jinja | generic tool calls |
|
||||
| mattshumer-Reflection-Llama-3.1-70B.jinja | generic tool calls |
|
||||
| meetkai-functionary-medium-v3.2.jinja | functionary v3.2 tool calls |
|
||||
| meta-llama-Llama-3.1-8B-Instruct.jinja | llama 3.x tool calls (w/ builtin tools) |
|
||||
| meta-llama-Llama-3.2-3B-Instruct.jinja | llama 3.x tool calls |
|
||||
| meta-llama-Llama-3.3-70B-Instruct.jinja | llama 3.x tool calls (w/ builtin tools) |
|
||||
| meta-llama-Meta-Llama-3.1-8B-Instruct.jinja | llama 3.x tool calls (w/ builtin tools) |
|
||||
| microsoft-Phi-3-medium-4k-instruct.jinja | generic tool calls |
|
||||
| microsoft-Phi-3-mini-4k-instruct.jinja | generic tool calls |
|
||||
| microsoft-Phi-3-small-8k-instruct.jinja | generic tool calls |
|
||||
| microsoft-Phi-3.5-mini-instruct.jinja | generic tool calls |
|
||||
| microsoft-Phi-3.5-vision-instruct.jinja | generic tool calls |
|
||||
| mistralai-Mistral-7B-Instruct-v0.2.jinja | generic tool calls |
|
||||
| mistralai-Mistral-Large-Instruct-2407.jinja | mistral nemo tool calls |
|
||||
| mistralai-Mistral-Large-Instruct-2411.jinja | generic tool calls |
|
||||
| mistralai-Mistral-Nemo-Instruct-2407.jinja | mistral nemo tool calls |
|
||||
| mistralai-Mixtral-8x7B-Instruct-v0.1.jinja | generic tool calls |
|
||||
| mlabonne-AlphaMonarch-7B.jinja | generic tool calls |
|
||||
| nvidia-Llama-3.1-Nemotron-70B-Instruct-HF.jinja | llama 3.x tool calls (w/ builtin tools) |
|
||||
| openchat-openchat-3.5-0106.jinja | generic tool calls |
|
||||
| teknium-OpenHermes-2.5-Mistral-7B.jinja | generic tool calls |
|
||||
|
||||
This table can be generated with:
|
||||
|
||||
```bash
|
||||
./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
|
||||
|
||||
</details>
|
||||
|
||||
- Generic tool call is supported when the template isn't recognized by native format handlers (you'll see `Chat format: Generic` in the logs).
|
||||
- Use `--chat-template-file` to override the template when appropriate (see examples below)
|
||||
- Generic support may consume more tokens and be less efficient than a model's native format.
|
||||
|
||||
- Run with:
|
||||
|
||||
```shell
|
||||
# Native support:
|
||||
llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M
|
||||
llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L
|
||||
llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M
|
||||
llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M
|
||||
|
||||
# Native support requires the right template for these GGUFs:
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
|
||||
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use )
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M \
|
||||
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use )
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/firefunction-v2-GGUF -hff firefunction-v2-IQ1_M.gguf \
|
||||
--chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use )
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L \
|
||||
--chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use )
|
||||
|
||||
# Generic format support
|
||||
llama-server --jinja -fa -hf bartowski/phi-4-GGUF:Q4_0
|
||||
llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q8_0
|
||||
llama-server --jinja -fa -hf bartowski/c4ai-command-r-v01-GGUF:Q2_K
|
||||
```
|
||||
|
||||
- Test in CLI:
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/v1/chat/completions -d '{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"tools": [
|
||||
{
|
||||
"type":"function",
|
||||
"function":{
|
||||
"name":"get_current_weather",
|
||||
"description":"Get the current weather in a given location",
|
||||
"parameters":{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"location":{
|
||||
"type":"string",
|
||||
"description":"The city and state, e.g. San Francisco, CA"
|
||||
}
|
||||
},
|
||||
"required":["location"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Istanbul?."
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>Show output</summary>
|
||||
|
||||
```json
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "tool",
|
||||
"index": 0,
|
||||
"message": {
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"name": "python",
|
||||
"arguments": "{\"code\":\" \\nprint(\\\"Hello, World!\\\")\"}"
|
||||
}
|
||||
],
|
||||
"role": "assistant"
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1727287211,
|
||||
"model": "gpt-3.5-turbo",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"completion_tokens": 16,
|
||||
"prompt_tokens": 44,
|
||||
"total_tokens": 60
|
||||
},
|
||||
"id": "chatcmpl-Htbgh9feMmGM0LEH2hmQvwsCxq3c6Ni8"
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### POST `/v1/embeddings`: OpenAI-compatible embeddings API
|
||||
|
||||
This endpoint requires that the model uses a pooling different than type `none`. The embeddings are normalized using the Eucledian norm.
|
||||
|
||||
Binary file not shown.
@@ -14,7 +14,7 @@
|
||||
// mime type for sending response
|
||||
#define MIMETYPE_JSON "application/json; charset=utf-8"
|
||||
|
||||
// auto generated files (update with ./deps.sh)
|
||||
// auto generated files (see README.md for details)
|
||||
#include "index.html.gz.hpp"
|
||||
#include "loading.html.hpp"
|
||||
|
||||
@@ -113,10 +113,11 @@ struct slot_params {
|
||||
struct common_params_speculative speculative;
|
||||
|
||||
// OAI-compat fields
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
|
||||
json to_json() const {
|
||||
std::vector<std::string> samplers;
|
||||
@@ -130,6 +131,11 @@ struct slot_params {
|
||||
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
|
||||
}
|
||||
|
||||
std::vector<std::string> grammar_trigger_words;
|
||||
for (const auto & trigger : sampling.grammar_trigger_words) {
|
||||
grammar_trigger_words.push_back(trigger.word);
|
||||
}
|
||||
|
||||
return json {
|
||||
{"n_predict", n_predict}, // Server configured n_predict
|
||||
{"seed", sampling.seed},
|
||||
@@ -164,6 +170,9 @@ struct slot_params {
|
||||
{"n_probs", sampling.n_probs},
|
||||
{"min_keep", sampling.min_keep},
|
||||
{"grammar", sampling.grammar},
|
||||
{"grammar_trigger_words", grammar_trigger_words},
|
||||
{"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
|
||||
{"preserved_tokens", sampling.preserved_tokens},
|
||||
{"samplers", samplers},
|
||||
{"speculative.n_max", speculative.n_max},
|
||||
{"speculative.n_min", speculative.n_min},
|
||||
@@ -325,12 +334,64 @@ struct server_task {
|
||||
if (data.contains("json_schema") && !data.contains("grammar")) {
|
||||
try {
|
||||
auto schema = json_value(data, "json_schema", json::object());
|
||||
params.sampling.grammar = json_schema_to_grammar(schema);
|
||||
LOG_DBG("JSON schema: %s\n", schema.dump(2).c_str());
|
||||
params.sampling.grammar = json_schema_to_grammar(schema);
|
||||
LOG_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
|
||||
} catch (const std::exception & e) {
|
||||
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
|
||||
}
|
||||
} else {
|
||||
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
|
||||
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
|
||||
LOG_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
|
||||
params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
|
||||
LOG_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
|
||||
}
|
||||
|
||||
{
|
||||
auto it = data.find("chat_format");
|
||||
if (it != data.end()) {
|
||||
params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<int>());
|
||||
LOG_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
|
||||
} else {
|
||||
params.oaicompat_chat_format = defaults.oaicompat_chat_format;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
const auto grammar_triggers = data.find("grammar_triggers");
|
||||
if (grammar_triggers != data.end()) {
|
||||
for (const auto & t : *grammar_triggers) {
|
||||
common_grammar_trigger trigger;
|
||||
trigger.word = t.at("word");
|
||||
trigger.at_start = t.at("at_start");
|
||||
|
||||
auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
|
||||
if (ids.size() == 1) {
|
||||
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
|
||||
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
|
||||
params.sampling.preserved_tokens.insert(ids[0]);
|
||||
continue;
|
||||
}
|
||||
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
|
||||
params.sampling.grammar_trigger_words.push_back(trigger);
|
||||
}
|
||||
}
|
||||
const auto preserved_tokens = data.find("preserved_tokens");
|
||||
if (preserved_tokens != data.end()) {
|
||||
for (const auto & t : *preserved_tokens) {
|
||||
auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
|
||||
if (ids.size() == 1) {
|
||||
LOG_DBG("Preserved token: %d\n", ids[0]);
|
||||
params.sampling.preserved_tokens.insert(ids[0]);
|
||||
} else {
|
||||
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
|
||||
LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
if (params.sampling.grammar_lazy) {
|
||||
GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
@@ -382,22 +443,12 @@ struct server_task {
|
||||
}
|
||||
|
||||
{
|
||||
const auto & samplers = data.find("samplers");
|
||||
const auto samplers = data.find("samplers");
|
||||
if (samplers != data.end()) {
|
||||
if (samplers->is_array()) {
|
||||
std::vector<std::string> sampler_names;
|
||||
for (const auto & name : *samplers) {
|
||||
if (name.is_string()) {
|
||||
sampler_names.emplace_back(name);
|
||||
}
|
||||
}
|
||||
params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
|
||||
params.sampling.samplers = common_sampler_types_from_names(*samplers, false);
|
||||
} else if (samplers->is_string()){
|
||||
std::string sampler_string;
|
||||
for (const auto & name : *samplers) {
|
||||
sampler_string += name;
|
||||
}
|
||||
params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
|
||||
params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
|
||||
}
|
||||
} else {
|
||||
params.sampling.samplers = defaults.sampling.samplers;
|
||||
@@ -544,7 +595,7 @@ struct completion_token_output {
|
||||
struct server_task_result_cmpl_final : server_task_result {
|
||||
int index = 0;
|
||||
|
||||
std::string content;
|
||||
std::string content;
|
||||
llama_tokens tokens;
|
||||
|
||||
bool stream;
|
||||
@@ -566,10 +617,11 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
slot_params generation_params;
|
||||
|
||||
// OAI-compat fields
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
|
||||
virtual int get_index() override {
|
||||
return index;
|
||||
@@ -663,18 +715,44 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
|
||||
json to_json_oaicompat_chat() {
|
||||
std::string finish_reason = "length";
|
||||
common_chat_msg msg;
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
finish_reason = "stop";
|
||||
LOG_DBG("Parsing chat message: %s\n", content.c_str());
|
||||
msg = common_chat_parse(content, oaicompat_chat_format);
|
||||
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
|
||||
} else {
|
||||
msg.content = content;
|
||||
}
|
||||
|
||||
json choice = json{
|
||||
json tool_calls;
|
||||
if (!msg.tool_calls.empty()) {
|
||||
tool_calls = json::array();
|
||||
for (const auto & tc : msg.tool_calls) {
|
||||
tool_calls.push_back({
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", tc.name},
|
||||
{"arguments", tc.arguments},
|
||||
}},
|
||||
{"id", tc.id},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
json message {
|
||||
{"content", msg.content},
|
||||
{"tool_calls", tool_calls},
|
||||
{"role", "assistant"},
|
||||
};
|
||||
if (!msg.tool_plan.empty()) {
|
||||
message["tool_plan"] = msg.tool_plan;
|
||||
}
|
||||
|
||||
json choice {
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"message", json {
|
||||
{"content", content},
|
||||
{"role", "assistant"}
|
||||
}
|
||||
}};
|
||||
{"message", message},
|
||||
};
|
||||
|
||||
if (!stream && probs_output.size() > 0) {
|
||||
choice["logprobs"] = json{
|
||||
@@ -716,7 +794,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
finish_reason = "stop";
|
||||
}
|
||||
|
||||
json choice = json{
|
||||
json choice = json {
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"delta", json::object()}
|
||||
@@ -1191,6 +1269,8 @@ struct server_slot {
|
||||
|
||||
llama_token sampled;
|
||||
|
||||
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
|
||||
// stats
|
||||
size_t n_sent_text = 0; // number of sent text character
|
||||
|
||||
@@ -1427,6 +1507,10 @@ struct server_queue {
|
||||
int post(server_task task, bool front = false) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
GGML_ASSERT(task.id != -1);
|
||||
// if this is cancel task make sure to clean up pending tasks
|
||||
if (task.type == SERVER_TASK_TYPE_CANCEL) {
|
||||
cleanup_pending_task(task.id_target);
|
||||
}
|
||||
QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
|
||||
if (front) {
|
||||
queue_tasks.push_front(std::move(task));
|
||||
@@ -1444,6 +1528,10 @@ struct server_queue {
|
||||
if (task.id == -1) {
|
||||
task.id = id++;
|
||||
}
|
||||
// if this is cancel task make sure to clean up pending tasks
|
||||
if (task.type == SERVER_TASK_TYPE_CANCEL) {
|
||||
cleanup_pending_task(task.id_target);
|
||||
}
|
||||
QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
|
||||
if (front) {
|
||||
queue_tasks.push_front(std::move(task));
|
||||
@@ -1544,6 +1632,20 @@ struct server_queue {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void cleanup_pending_task(int id_target) {
|
||||
// no need lock because this is called exclusively by post()
|
||||
auto rm_func = [id_target](const server_task & task) {
|
||||
return task.id_target == id_target;
|
||||
};
|
||||
queue_tasks.erase(
|
||||
std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
|
||||
queue_tasks.end());
|
||||
queue_tasks_deferred.erase(
|
||||
std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
|
||||
queue_tasks_deferred.end());
|
||||
}
|
||||
};
|
||||
|
||||
struct server_response {
|
||||
@@ -1579,6 +1681,12 @@ struct server_response {
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
waiting_task_ids.erase(id_task);
|
||||
// make sure to clean up all pending results
|
||||
queue_results.erase(
|
||||
std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) {
|
||||
return res->id == id_task;
|
||||
}),
|
||||
queue_results.end());
|
||||
}
|
||||
|
||||
void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
|
||||
@@ -1598,7 +1706,7 @@ struct server_response {
|
||||
return !queue_results.empty();
|
||||
});
|
||||
|
||||
for (int i = 0; i < (int) queue_results.size(); i++) {
|
||||
for (size_t i = 0; i < queue_results.size(); i++) {
|
||||
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
|
||||
server_task_result_ptr res = std::move(queue_results[i]);
|
||||
queue_results.erase(queue_results.begin() + i);
|
||||
@@ -1615,12 +1723,6 @@ struct server_response {
|
||||
server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
bool cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout), [&]{
|
||||
return !queue_results.empty();
|
||||
});
|
||||
if (!cr_res) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (int i = 0; i < (int) queue_results.size(); i++) {
|
||||
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
|
||||
@@ -1629,6 +1731,11 @@ struct server_response {
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
|
||||
if (cr_res == std::cv_status::timeout) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// should never reach here
|
||||
@@ -1777,7 +1884,12 @@ struct server_context {
|
||||
llama_init_dft.context.reset();
|
||||
}
|
||||
|
||||
chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
|
||||
if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) {
|
||||
LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
|
||||
chat_templates = common_chat_templates_from_model(model, "chatml");
|
||||
} else {
|
||||
chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
|
||||
}
|
||||
GGML_ASSERT(chat_templates.template_default.get() != nullptr);
|
||||
|
||||
return true;
|
||||
@@ -1788,17 +1900,16 @@ struct server_context {
|
||||
|
||||
if (use_jinja) {
|
||||
auto templates = common_chat_templates_from_model(model, "");
|
||||
common_chat_inputs inputs;
|
||||
inputs.messages = json::array({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}});
|
||||
GGML_ASSERT(templates.template_default);
|
||||
try {
|
||||
templates.template_default->apply({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}}, json(), true);
|
||||
common_chat_params_init(*templates.template_default, inputs);
|
||||
if (templates.template_tool_use) {
|
||||
templates.template_tool_use->apply({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}}, json(), true);
|
||||
common_chat_params_init(*templates.template_tool_use, inputs);
|
||||
}
|
||||
return true;
|
||||
} catch (const std::exception & e) {
|
||||
@@ -2248,11 +2359,11 @@ struct server_context {
|
||||
res->id_slot = slot.id;
|
||||
|
||||
res->index = slot.index;
|
||||
res->content = slot.generated_text;
|
||||
res->tokens = slot.generated_tokens;
|
||||
res->content = std::move(slot.generated_text);
|
||||
res->tokens = std::move(slot.generated_tokens);
|
||||
res->timings = slot.get_timings();
|
||||
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
|
||||
res->response_fields = slot.params.response_fields;
|
||||
res->response_fields = std::move(slot.params.response_fields);
|
||||
|
||||
res->truncated = slot.truncated;
|
||||
res->n_decoded = slot.n_decoded;
|
||||
@@ -2263,12 +2374,12 @@ struct server_context {
|
||||
res->stop = slot.stop;
|
||||
res->post_sampling_probs = slot.params.post_sampling_probs;
|
||||
|
||||
res->verbose = slot.params.verbose;
|
||||
res->stream = slot.params.stream;
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
|
||||
res->verbose = slot.params.verbose;
|
||||
res->stream = slot.params.stream;
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
|
||||
// populate res.probs_output
|
||||
if (slot.params.sampling.n_probs > 0) {
|
||||
if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
|
||||
@@ -2376,8 +2487,8 @@ struct server_context {
|
||||
|
||||
server_task task(SERVER_TASK_TYPE_CANCEL);
|
||||
task.id_target = id_task;
|
||||
cancel_tasks.push_back(task);
|
||||
queue_results.remove_waiting_task_id(id_task);
|
||||
cancel_tasks.push_back(task);
|
||||
}
|
||||
// push to beginning of the queue, so it has highest priority
|
||||
queue_tasks.post(cancel_tasks, true);
|
||||
@@ -2746,6 +2857,10 @@ struct server_context {
|
||||
// track if given slot can be batched with slots already in the batch
|
||||
server_slot * slot_batched = nullptr;
|
||||
|
||||
auto accept_special_token = [&](server_slot & slot, llama_token token) {
|
||||
return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
|
||||
};
|
||||
|
||||
// frist, add sampled tokens from any ongoing sequences
|
||||
for (auto & slot : slots) {
|
||||
if (slot.state != SLOT_STATE_GENERATING) {
|
||||
@@ -3109,7 +3224,7 @@ struct server_context {
|
||||
|
||||
completion_token_output result;
|
||||
result.tok = id;
|
||||
result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
|
||||
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
|
||||
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
|
||||
|
||||
if (slot.params.sampling.n_probs > 0) {
|
||||
@@ -3198,7 +3313,7 @@ struct server_context {
|
||||
completion_token_output result;
|
||||
|
||||
result.tok = ids[i];
|
||||
result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
|
||||
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
|
||||
result.prob = 1.0f; // set later
|
||||
|
||||
// TODO: set result.probs
|
||||
@@ -3238,6 +3353,8 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
|
||||
return;
|
||||
}
|
||||
|
||||
// reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
|
||||
|
||||
LOG_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
|
||||
|
||||
LOG_DBG("request: %s\n", req.body.c_str());
|
||||
@@ -3324,9 +3441,13 @@ int main(int argc, char ** argv) {
|
||||
message = "Unknown Exception";
|
||||
}
|
||||
|
||||
json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
|
||||
LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
|
||||
res_error(res, formatted_error);
|
||||
try {
|
||||
json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
|
||||
LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
|
||||
res_error(res, formatted_error);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str());
|
||||
}
|
||||
});
|
||||
|
||||
svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
|
||||
@@ -3548,11 +3669,11 @@ int main(int argc, char ** argv) {
|
||||
{"value", (uint64_t) res_metrics->kv_cache_tokens_count}
|
||||
},{
|
||||
{"name", "requests_processing"},
|
||||
{"help", "Number of request processing."},
|
||||
{"help", "Number of requests processing."},
|
||||
{"value", (uint64_t) res_metrics->n_processing_slots}
|
||||
},{
|
||||
{"name", "requests_deferred"},
|
||||
{"help", "Number of request deferred."},
|
||||
{"help", "Number of requests deferred."},
|
||||
{"value", (uint64_t) res_metrics->n_tasks_deferred}
|
||||
}}}
|
||||
};
|
||||
@@ -3695,6 +3816,8 @@ int main(int argc, char ** argv) {
|
||||
{ "total_slots", ctx_server.params_base.n_parallel },
|
||||
{ "model_path", ctx_server.params_base.model },
|
||||
{ "chat_template", ctx_server.chat_templates.template_default->source() },
|
||||
{ "bos_token", ctx_server.chat_templates.template_default->bos_token() },
|
||||
{ "eos_token", ctx_server.chat_templates.template_default->eos_token() },
|
||||
{ "build_info", build_info },
|
||||
};
|
||||
if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) {
|
||||
@@ -3736,7 +3859,9 @@ int main(int argc, char ** argv) {
|
||||
std::vector<server_task> tasks;
|
||||
|
||||
try {
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
|
||||
const auto & prompt = data.at("prompt");
|
||||
LOG_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
|
||||
tasks.reserve(tokenized_prompts.size());
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
server_task task = server_task(type);
|
||||
@@ -3752,8 +3877,8 @@ int main(int argc, char ** argv) {
|
||||
task.id_selected_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
task.params.oaicompat = oaicompat;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
task.params.oaicompat = oaicompat;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
// oaicompat_model is already populated by params_from_json_cmpl
|
||||
|
||||
tasks.push_back(task);
|
||||
@@ -3922,14 +4047,14 @@ int main(int argc, char ** argv) {
|
||||
};
|
||||
|
||||
const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
LOG_DBG("request: %s\n", req.body.c_str());
|
||||
if (ctx_server.params_base.embedding) {
|
||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
auto body = json::parse(req.body);
|
||||
const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
|
||||
json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja);
|
||||
json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates);
|
||||
|
||||
return handle_completions_impl(
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
@@ -3939,6 +4064,13 @@ int main(int argc, char ** argv) {
|
||||
OAICOMPAT_TYPE_CHAT);
|
||||
};
|
||||
|
||||
// same with handle_chat_completions, but without inference part
|
||||
const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||
auto body = json::parse(req.body);
|
||||
json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates);
|
||||
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
|
||||
};
|
||||
|
||||
const auto handle_models = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
|
||||
json models = {
|
||||
{"object", "list"},
|
||||
@@ -4273,6 +4405,7 @@ int main(int argc, char ** argv) {
|
||||
svr->Post("/v1/reranking", handle_rerank);
|
||||
svr->Post("/tokenize", handle_tokenize);
|
||||
svr->Post("/detokenize", handle_detokenize);
|
||||
svr->Post("/apply-template", handle_apply_template);
|
||||
// LoRA adapters hotswap
|
||||
svr->Get ("/lora-adapters", handle_lora_adapters_list);
|
||||
svr->Post("/lora-adapters", handle_lora_adapters_apply);
|
||||
@@ -4338,24 +4471,18 @@ int main(int argc, char ** argv) {
|
||||
|
||||
LOG_INF("%s: model loaded\n", __func__);
|
||||
|
||||
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
|
||||
if (params.chat_template.empty()) {
|
||||
if (!ctx_server.validate_builtin_chat_template(params.use_jinja)) {
|
||||
LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
|
||||
params.chat_template = "chatml";
|
||||
}
|
||||
}
|
||||
|
||||
// print sample chat example to make it clear which template is used
|
||||
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
|
||||
ctx_server.chat_templates.template_default->source().c_str(),
|
||||
common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
|
||||
|
||||
ctx_server.queue_tasks.on_new_task(std::bind(
|
||||
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
|
||||
ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
|
||||
ctx_server.process_single_task(task);
|
||||
});
|
||||
|
||||
ctx_server.queue_tasks.on_update_slots(std::bind(
|
||||
&server_context::update_slots, &ctx_server));
|
||||
ctx_server.queue_tasks.on_update_slots([&ctx_server]() {
|
||||
ctx_server.update_slots();
|
||||
});
|
||||
|
||||
shutdown_handler = [&](int) {
|
||||
ctx_server.queue_tasks.terminate();
|
||||
|
||||
@@ -31,8 +31,9 @@ It's possible to override some scenario steps values with environment variables:
|
||||
| `LLAMA_SERVER_BIN_PATH` | to change the server binary path, default: `../../../build/bin/llama-server` |
|
||||
| `DEBUG` | to enable steps and server verbose mode `--verbose` |
|
||||
| `N_GPU_LAYERS` | number of model layers to offload to VRAM `-ngl --n-gpu-layers` |
|
||||
| `LLAMA_CACHE` | by default server tests re-download models to the `tmp` subfolder. Set this to your cache (e.g. `$HOME/Library/Caches/llama.cpp` on Mac or `$HOME/.cache/llama.cpp` on Unix) to avoid this |
|
||||
|
||||
To run slow tests:
|
||||
To run slow tests (will download many models, make sure to set `LLAMA_CACHE` if needed):
|
||||
|
||||
```shell
|
||||
SLOW_TESTS=1 ./tests.sh
|
||||
@@ -44,10 +45,16 @@ To run with stdout/stderr display in real time (verbose output, but useful for d
|
||||
DEBUG=1 ./tests.sh -s -v -x
|
||||
```
|
||||
|
||||
To run single test unit:
|
||||
To run all the tests in a file:
|
||||
|
||||
```shell
|
||||
./tests.sh unit/test_{name of test case here}.py -v -x
|
||||
./tests.sh unit/test_chat_completion.py.py -v -x
|
||||
```
|
||||
|
||||
To run a single test:
|
||||
|
||||
```shell
|
||||
./tests.sh unit/test_chat_completion.py::test_invalid_chat_completion_req
|
||||
```
|
||||
|
||||
Hint: You can compile and run test in single command, useful for local developement:
|
||||
|
||||
4
examples/server/tests/pytest.ini
Normal file
4
examples/server/tests/pytest.ini
Normal file
@@ -0,0 +1,4 @@
|
||||
[pytest]
|
||||
markers =
|
||||
slow: marks tests as slow (deselect with '-m "not slow"')
|
||||
serial
|
||||
@@ -6,9 +6,18 @@ cd $SCRIPT_DIR
|
||||
|
||||
set -eu
|
||||
|
||||
if [[ "${SLOW_TESTS:-0}" == 1 ]]; then
|
||||
# Slow tests for tool calls need quite a few models ahead of time to avoid timing out.
|
||||
python $SCRIPT_DIR/../../../scripts/fetch_server_test_models.py
|
||||
fi
|
||||
|
||||
if [ $# -lt 1 ]
|
||||
then
|
||||
pytest -v -x
|
||||
if [[ "${SLOW_TESTS:-0}" == 1 ]]; then
|
||||
pytest -v -x
|
||||
else
|
||||
pytest -v -x -m "not slow"
|
||||
fi
|
||||
else
|
||||
pytest "$@"
|
||||
fi
|
||||
|
||||
@@ -2,7 +2,7 @@ import pytest
|
||||
from openai import OpenAI
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
server: ServerProcess
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
@@ -13,9 +13,12 @@ def create_server():
|
||||
@pytest.mark.parametrize(
|
||||
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
|
||||
[
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False, None),
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True, None),
|
||||
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
|
||||
(None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None),
|
||||
(None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None),
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None),
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
|
||||
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
|
||||
]
|
||||
@@ -121,6 +124,21 @@ def test_chat_template():
|
||||
assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
|
||||
|
||||
def test_apply_chat_template():
|
||||
global server
|
||||
server.chat_template = "command-r"
|
||||
server.start()
|
||||
res = server.make_request("POST", "/apply-template", data={
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a test."},
|
||||
{"role": "user", "content":"Hi there"},
|
||||
]
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "prompt" in res.body
|
||||
assert res.body["prompt"] == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
|
||||
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
|
||||
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
|
||||
|
||||
@@ -87,7 +87,7 @@ def test_completion_stream_vs_non_stream():
|
||||
assert content_stream == res_non_stream.body["content"]
|
||||
|
||||
|
||||
def test_completion_stream_with_openai_library():
|
||||
def test_completion_with_openai_library():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
@@ -102,7 +102,7 @@ def test_completion_stream_with_openai_library():
|
||||
assert match_regex("(going|bed)+", res.choices[0].text)
|
||||
|
||||
|
||||
def test_completion_with_openai_library():
|
||||
def test_completion_stream_with_openai_library():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
|
||||
418
examples/server/tests/unit/test_tool_call.py
Normal file
418
examples/server/tests/unit/test_tool_call.py
Normal file
@@ -0,0 +1,418 @@
|
||||
import pytest
|
||||
from utils import *
|
||||
|
||||
server: ServerProcess
|
||||
|
||||
TIMEOUT_SERVER_START = 15*60
|
||||
TIMEOUT_HTTP_REQUEST = 60
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.model_alias = "tinyllama-2-tool-call"
|
||||
server.server_port = 8081
|
||||
|
||||
|
||||
TEST_TOOL = {
|
||||
"type":"function",
|
||||
"function": {
|
||||
"name": "test",
|
||||
"description": "",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"success": {"type": "boolean", "const": True},
|
||||
},
|
||||
"required": ["success"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PYTHON_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "python",
|
||||
"description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "The code to run in the ipython interpreter."
|
||||
}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
WEATHER_TOOL = {
|
||||
"type":"function",
|
||||
"function":{
|
||||
"name":"get_current_weather",
|
||||
"description":"Get the current weather in a given location",
|
||||
"parameters":{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"location":{
|
||||
"type":"string",
|
||||
"description":"The city and country/state, e.g. 'San Francisco, CA', or 'Paris, France'"
|
||||
}
|
||||
},
|
||||
"required":["location"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
|
||||
global server
|
||||
n_predict = 512
|
||||
# server = ServerPreset.stories15m_moe()
|
||||
server.jinja = True
|
||||
server.n_predict = n_predict
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
{"role": "user", "content": "Write an example"},
|
||||
],
|
||||
"tool_choice": "required",
|
||||
"tools": [tool],
|
||||
"parallel_tool_calls": False,
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"top_p": 1.0,
|
||||
})
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
|
||||
assert expected_function_name == tool_call["function"]["name"]
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
assert isinstance(actual_arguments, str)
|
||||
if argument_key is not None:
|
||||
actual_arguments = json.loads(actual_arguments)
|
||||
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("template_name,tool,argument_key", [
|
||||
("google-gemma-2-2b-it", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
|
||||
])
|
||||
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
|
||||
do_test_completion_with_required_tool_tiny(template_name, tool, argument_key)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("template_name,tool,argument_key", [
|
||||
("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
|
||||
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
|
||||
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
|
||||
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
|
||||
("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
|
||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
|
||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
|
||||
("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
|
||||
("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
|
||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
|
||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
|
||||
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
|
||||
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
|
||||
("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
|
||||
("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
|
||||
])
|
||||
def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None):
|
||||
do_test_completion_with_required_tool_tiny(template_name, tool, argument_key)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [
|
||||
(TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
|
||||
(TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
|
||||
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
|
||||
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
# TODO: fix these
|
||||
# (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
# (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
])
|
||||
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
global server
|
||||
n_predict = 512
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192
|
||||
server.n_predict = n_predict
|
||||
server.model_hf_repo = hf_repo
|
||||
server.model_hf_file = None
|
||||
if isinstance(template_override, tuple):
|
||||
(template_hf_repo, template_variant) = template_override
|
||||
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
{"role": "user", "content": "Write an example"},
|
||||
],
|
||||
"tool_choice": "required",
|
||||
"tools": [tool],
|
||||
"parallel_tool_calls": False,
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"top_p": 1.0,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
|
||||
assert expected_function_name == tool_call["function"]["name"]
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
assert isinstance(actual_arguments, str)
|
||||
if argument_key is not None:
|
||||
actual_arguments = json.loads(actual_arguments)
|
||||
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
|
||||
|
||||
|
||||
def do_test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
|
||||
global server
|
||||
server.jinja = True
|
||||
server.n_predict = n_predict
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
{"role": "user", "content": "say hello world with python"},
|
||||
],
|
||||
"tools": tools if tools else None,
|
||||
"tool_choice": tool_choice,
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"top_p": 1.0,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
|
||||
|
||||
|
||||
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
|
||||
("meta-llama-Llama-3.3-70B-Instruct", 128, [], None),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'),
|
||||
])
|
||||
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
|
||||
do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
|
||||
("meetkai-functionary-medium-v3.2", 256, [], None),
|
||||
("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None),
|
||||
("meetkai-functionary-medium-v3.2", 256, [PYTHON_TOOL], 'none'),
|
||||
("meetkai-functionary-medium-v3.1", 256, [], None),
|
||||
("meetkai-functionary-medium-v3.1", 256, [TEST_TOOL], None),
|
||||
("meetkai-functionary-medium-v3.1", 256, [PYTHON_TOOL], 'none'),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", 256, [], None),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'),
|
||||
])
|
||||
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
|
||||
do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("hf_repo,template_override", [
|
||||
("bartowski/c4ai-command-r7b-12-2024-GGUF:Q4_K_M", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")),
|
||||
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
|
||||
("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
|
||||
|
||||
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
|
||||
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
|
||||
|
||||
# ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
# ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
])
|
||||
def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None):
|
||||
global server
|
||||
n_predict = 512
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192
|
||||
server.n_predict = n_predict
|
||||
server.model_hf_repo = hf_repo
|
||||
server.model_hf_file = None
|
||||
if isinstance(template_override, tuple):
|
||||
(template_hf_repo, template_variant) = template_override
|
||||
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the weather in Istanbul?"},
|
||||
],
|
||||
"tools": [WEATHER_TOOL],
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"]
|
||||
actual_arguments = json.loads(tool_call["function"]["arguments"])
|
||||
assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
|
||||
location = actual_arguments["location"]
|
||||
assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}"
|
||||
assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [
|
||||
(None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)),
|
||||
(None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
|
||||
|
||||
(None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
('{"code":"print("}', "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
(None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
|
||||
(None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
(None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
|
||||
(None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
|
||||
|
||||
# (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
])
|
||||
def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
global server
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192
|
||||
server.n_predict = 128
|
||||
server.model_hf_repo = hf_repo
|
||||
server.model_hf_file = None
|
||||
if isinstance(template_override, tuple):
|
||||
(template_hf_repo, template_variant) = template_override
|
||||
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 256,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
{"role": "user", "content": "say hello world with python"},
|
||||
],
|
||||
"tools": [PYTHON_TOOL],
|
||||
# Note: without these greedy params, Functionary v3.2 writes `def hello_world():\n print("Hello, World!")\nhello_world()` which is correct but a pain to test.
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"top_p": 1.0,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
if expected_arguments_override is not None:
|
||||
assert actual_arguments == expected_arguments_override
|
||||
else:
|
||||
actual_arguments = json.loads(actual_arguments)
|
||||
assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
|
||||
code = actual_arguments["code"]
|
||||
assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
|
||||
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
|
||||
@@ -26,7 +26,7 @@ from re import RegexFlag
|
||||
import wget
|
||||
|
||||
|
||||
DEFAULT_HTTP_TIMEOUT = 10 if "LLAMA_SANITIZE" not in os.environ else 30
|
||||
DEFAULT_HTTP_TIMEOUT = 12 if "LLAMA_SANITIZE" not in os.environ else 30
|
||||
|
||||
|
||||
class ServerResponse:
|
||||
@@ -41,7 +41,7 @@ class ServerProcess:
|
||||
server_port: int = 8080
|
||||
server_host: str = "127.0.0.1"
|
||||
model_hf_repo: str = "ggml-org/models"
|
||||
model_hf_file: str = "tinyllamas/stories260K.gguf"
|
||||
model_hf_file: str | None = "tinyllamas/stories260K.gguf"
|
||||
model_alias: str = "tinyllama-2"
|
||||
temperature: float = 0.8
|
||||
seed: int = 42
|
||||
@@ -191,7 +191,7 @@ class ServerProcess:
|
||||
creationflags=flags,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stdout,
|
||||
env={**os.environ, "LLAMA_CACHE": "tmp"},
|
||||
env={**os.environ, "LLAMA_CACHE": "tmp"} if "LLAMA_CACHE" not in os.environ else None,
|
||||
)
|
||||
server_instances.add(self)
|
||||
|
||||
|
||||
@@ -5,10 +5,6 @@
|
||||
#include "llama.h"
|
||||
#include "common/base64.hpp"
|
||||
|
||||
#ifndef NDEBUG
|
||||
// crash the server in debug mode, otherwise send an http 500 error
|
||||
#define CPPHTTPLIB_NO_EXCEPTIONS 1
|
||||
#endif
|
||||
// increase max payload length to allow use of larger context size
|
||||
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
|
||||
#include "httplib.h"
|
||||
@@ -17,6 +13,7 @@
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "json.hpp"
|
||||
#include "minja.hpp"
|
||||
#include "chat.hpp"
|
||||
#include "chat-template.hpp"
|
||||
|
||||
#include <random>
|
||||
@@ -376,7 +373,7 @@ inline std::string format_chat(const common_chat_template & tmpl, const std::vec
|
||||
throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
|
||||
}
|
||||
|
||||
chat.push_back({role, content});
|
||||
chat.push_back({role, content, /* tool_calls= */ {}});
|
||||
}
|
||||
|
||||
const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
|
||||
@@ -580,21 +577,30 @@ static json oaicompat_completion_params_parse(const json & body) {
|
||||
|
||||
static json oaicompat_completion_params_parse(
|
||||
const json & body, /* openai api json semantics */
|
||||
const common_chat_template & tmpl,
|
||||
bool use_jinja)
|
||||
bool use_jinja,
|
||||
const common_chat_templates & chat_templates)
|
||||
{
|
||||
json llama_params;
|
||||
const auto & tmpl = body.contains("tools") && chat_templates.template_tool_use
|
||||
? *chat_templates.template_tool_use
|
||||
: *chat_templates.template_default;
|
||||
|
||||
auto tools = json_value(body, "tools", json());
|
||||
auto has_tools = tools.is_array() && !tools.empty();
|
||||
auto stream = json_value(body, "stream", false);
|
||||
|
||||
if (has_tools) {
|
||||
if (use_jinja) {
|
||||
LOG_WRN("tools param is not fully supported yet\n");
|
||||
} else {
|
||||
if (tools.is_array() && !tools.empty()) {
|
||||
if (stream) {
|
||||
throw std::runtime_error("Cannot use tools with stream");
|
||||
}
|
||||
if (!use_jinja) {
|
||||
throw std::runtime_error("tools param requires --jinja flag");
|
||||
}
|
||||
}
|
||||
if (!use_jinja) {
|
||||
if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) {
|
||||
throw std::runtime_error("Unsupported param: tool_choice");
|
||||
}
|
||||
}
|
||||
|
||||
// Handle "stop" field
|
||||
if (body.contains("stop") && body.at("stop").is_string()) {
|
||||
@@ -619,7 +625,43 @@ static json oaicompat_completion_params_parse(
|
||||
|
||||
// Apply chat template to the list of messages
|
||||
if (use_jinja) {
|
||||
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
|
||||
auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
|
||||
if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
|
||||
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
|
||||
}
|
||||
if (tool_choice != "none" && llama_params.contains("grammar")) {
|
||||
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
|
||||
}
|
||||
common_chat_inputs inputs;
|
||||
inputs.messages = body.at("messages");
|
||||
inputs.tools = tools;
|
||||
inputs.tool_choice = tool_choice;
|
||||
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
|
||||
if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
|
||||
LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
|
||||
inputs.parallel_tool_calls = false;
|
||||
}
|
||||
inputs.stream = stream;
|
||||
// TODO: support mixing schema w/ tools beyond generic format.
|
||||
inputs.json_schema = json_value(llama_params, "json_schema", json());
|
||||
auto chat_params = common_chat_params_init(tmpl, inputs);
|
||||
|
||||
llama_params["chat_format"] = static_cast<int>(chat_params.format);
|
||||
llama_params["prompt"] = chat_params.prompt;
|
||||
llama_params["grammar"] = chat_params.grammar;
|
||||
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
|
||||
auto grammar_triggers = json::array();
|
||||
for (const auto & trigger : chat_params.grammar_triggers) {
|
||||
grammar_triggers.push_back({
|
||||
{"word", trigger.word},
|
||||
{"at_start", trigger.at_start},
|
||||
});
|
||||
}
|
||||
llama_params["grammar_triggers"] = grammar_triggers;
|
||||
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
|
||||
for (const auto & stop : chat_params.additional_stops) {
|
||||
llama_params["stop"].push_back(stop);
|
||||
}
|
||||
} else {
|
||||
llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
|
||||
}
|
||||
@@ -638,14 +680,6 @@ static json oaicompat_completion_params_parse(
|
||||
throw std::runtime_error("top_logprobs requires logprobs to be set to true");
|
||||
}
|
||||
|
||||
// Params supported by OAI but unsupported by llama.cpp
|
||||
static const std::vector<std::string> unsupported_params { "tool_choice" };
|
||||
for (const auto & param : unsupported_params) {
|
||||
if (body.contains(param)) {
|
||||
throw std::runtime_error("Unsupported param: " + param);
|
||||
}
|
||||
}
|
||||
|
||||
// Copy remaining properties to llama_params
|
||||
// This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint.
|
||||
// See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
|
||||
|
||||
@@ -141,6 +141,7 @@
|
||||
:msg="pendingMsg"
|
||||
:key="pendingMsg.id"
|
||||
:is-generating="isGenerating"
|
||||
:show-thought-in-progress="config.showThoughtInProgress"
|
||||
:edit-user-msg-and-regenerate="() => {}"
|
||||
:regenerate-msg="() => {}"></message-bubble>
|
||||
</div>
|
||||
@@ -153,8 +154,6 @@
|
||||
placeholder="Type a message (Shift+Enter to add a new line)"
|
||||
v-model="inputMsg"
|
||||
@keydown.enter.exact.prevent="sendMessage"
|
||||
@keydown.enter.shift.exact.prevent="inputMsg += '\n'"
|
||||
:disabled="isGenerating"
|
||||
id="msg-input"
|
||||
dir="auto"
|
||||
></textarea>
|
||||
@@ -202,6 +201,20 @@
|
||||
</template>
|
||||
</div>
|
||||
</details>
|
||||
<!-- Section: Reasoning models -->
|
||||
<details class="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
|
||||
<summary class="collapse-title font-bold">Reasoning models</summary>
|
||||
<div class="collapse-content">
|
||||
<div class="flex flex-row items-center mb-2">
|
||||
<input type="checkbox" class="checkbox" v-model="config.showThoughtInProgress" />
|
||||
<span class="ml-4">Expand though process by default for generating message</span>
|
||||
</div>
|
||||
<div class="flex flex-row items-center mb-2">
|
||||
<input type="checkbox" class="checkbox" v-model="config.excludeThoughtOnReq" />
|
||||
<span class="ml-4">Exclude thought process when sending request to API (Recommended for DeepSeek-R1)</span>
|
||||
</div>
|
||||
</div>
|
||||
</details>
|
||||
<!-- Section: Advanced config -->
|
||||
<details class="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
|
||||
<summary class="collapse-title font-bold">Advanced config</summary>
|
||||
@@ -261,7 +274,17 @@
|
||||
<span v-if="msg.content === null" class="loading loading-dots loading-md"></span>
|
||||
<!-- render message as markdown -->
|
||||
<div v-else dir="auto">
|
||||
<vue-markdown :source="msg.content"></vue-markdown>
|
||||
<details v-if="msg.role === 'assistant' && splitMsgContent.cot" class="collapse bg-base-200 collapse-arrow mb-4" :open="splitMsgContent.isThinking && showThoughtInProgress">
|
||||
<summary class="collapse-title">
|
||||
<span v-if="splitMsgContent.isThinking">
|
||||
<span v-if="isGenerating" class="loading loading-spinner loading-md mr-2" style="vertical-align: middle;"></span>
|
||||
<b>Thinking</b>
|
||||
</span>
|
||||
<b v-else>Thought Process</b>
|
||||
</summary>
|
||||
<vue-markdown :source="splitMsgContent.cot" dir="auto" class="collapse-content"></vue-markdown>
|
||||
</details>
|
||||
<vue-markdown :source="splitMsgContent.content"></vue-markdown>
|
||||
</div>
|
||||
<!-- render timings if enabled -->
|
||||
<div class="dropdown dropdown-hover dropdown-top mt-2" v-if="timings && config.showTokensPerSecond">
|
||||
|
||||
@@ -17,6 +17,11 @@ import { asyncIterator } from '@sec-ant/readable-stream/ponyfill/asyncIterator';
|
||||
|
||||
const isDev = import.meta.env.MODE === 'development';
|
||||
|
||||
// types
|
||||
/** @typedef {{ id: number, role: 'user' | 'assistant', content: string, timings: any }} Message */
|
||||
/** @typedef {{ role: 'user' | 'assistant', content: string }} APIMessage */
|
||||
/** @typedef {{ id: string, lastModified: number, messages: Array<Message> }} Conversation */
|
||||
|
||||
// utility functions
|
||||
const isString = (x) => !!x.toLowerCase;
|
||||
const isBoolean = (x) => x === true || x === false;
|
||||
@@ -50,6 +55,8 @@ const CONFIG_DEFAULT = {
|
||||
apiKey: '',
|
||||
systemMessage: 'You are a helpful assistant.',
|
||||
showTokensPerSecond: false,
|
||||
showThoughtInProgress: false,
|
||||
excludeThoughtOnReq: true,
|
||||
// make sure these default values are in sync with `common.h`
|
||||
samplers: 'edkypmxt',
|
||||
temperature: 0.8,
|
||||
@@ -172,6 +179,7 @@ const MessageBubble = defineComponent({
|
||||
config: Object,
|
||||
msg: Object,
|
||||
isGenerating: Boolean,
|
||||
showThoughtInProgress: Boolean,
|
||||
editUserMsgAndRegenerate: Function,
|
||||
regenerateMsg: Function,
|
||||
},
|
||||
@@ -188,7 +196,31 @@ const MessageBubble = defineComponent({
|
||||
prompt_per_second: this.msg.timings.prompt_n / (this.msg.timings.prompt_ms / 1000),
|
||||
predicted_per_second: this.msg.timings.predicted_n / (this.msg.timings.predicted_ms / 1000),
|
||||
};
|
||||
}
|
||||
},
|
||||
splitMsgContent() {
|
||||
const content = this.msg.content;
|
||||
if (this.msg.role !== 'assistant') {
|
||||
return { content };
|
||||
}
|
||||
let actualContent = '';
|
||||
let cot = '';
|
||||
let isThinking = false;
|
||||
let thinkSplit = content.split('<think>', 2);
|
||||
actualContent += thinkSplit[0];
|
||||
while (thinkSplit[1] !== undefined) {
|
||||
// <think> tag found
|
||||
thinkSplit = thinkSplit[1].split('</think>', 2);
|
||||
cot += thinkSplit[0];
|
||||
isThinking = true;
|
||||
if (thinkSplit[1] !== undefined) {
|
||||
// </think> closing tag found
|
||||
isThinking = false;
|
||||
thinkSplit = thinkSplit[1].split('<think>', 2);
|
||||
actualContent += thinkSplit[0];
|
||||
}
|
||||
}
|
||||
return { content: actualContent, cot, isThinking };
|
||||
},
|
||||
},
|
||||
methods: {
|
||||
copyMsg() {
|
||||
@@ -208,7 +240,10 @@ const MessageBubble = defineComponent({
|
||||
// format: { [convId]: { id: string, lastModified: number, messages: [...] } }
|
||||
// convId is a string prefixed with 'conv-'
|
||||
const StorageUtils = {
|
||||
// manage conversations
|
||||
/**
|
||||
* manage conversations
|
||||
* @returns {Array<Conversation>}
|
||||
*/
|
||||
getAllConversations() {
|
||||
const res = [];
|
||||
for (const key in localStorage) {
|
||||
@@ -219,11 +254,19 @@ const StorageUtils = {
|
||||
res.sort((a, b) => b.lastModified - a.lastModified);
|
||||
return res;
|
||||
},
|
||||
// can return null if convId does not exist
|
||||
/**
|
||||
* can return null if convId does not exist
|
||||
* @param {string} convId
|
||||
* @returns {Conversation | null}
|
||||
*/
|
||||
getOneConversation(convId) {
|
||||
return JSON.parse(localStorage.getItem(convId) || 'null');
|
||||
},
|
||||
// if convId does not exist, create one
|
||||
/**
|
||||
* if convId does not exist, create one
|
||||
* @param {string} convId
|
||||
* @param {Message} msg
|
||||
*/
|
||||
appendMsg(convId, msg) {
|
||||
if (msg.content === null) return;
|
||||
const conv = StorageUtils.getOneConversation(convId) || {
|
||||
@@ -235,12 +278,24 @@ const StorageUtils = {
|
||||
conv.lastModified = Date.now();
|
||||
localStorage.setItem(convId, JSON.stringify(conv));
|
||||
},
|
||||
/**
|
||||
* Get new conversation id
|
||||
* @returns {string}
|
||||
*/
|
||||
getNewConvId() {
|
||||
return `conv-${Date.now()}`;
|
||||
},
|
||||
/**
|
||||
* remove conversation by id
|
||||
* @param {string} convId
|
||||
*/
|
||||
remove(convId) {
|
||||
localStorage.removeItem(convId);
|
||||
},
|
||||
/**
|
||||
* remove all conversations
|
||||
* @param {string} convId
|
||||
*/
|
||||
filterAndKeepMsgs(convId, predicate) {
|
||||
const conv = StorageUtils.getOneConversation(convId);
|
||||
if (!conv) return;
|
||||
@@ -248,6 +303,11 @@ const StorageUtils = {
|
||||
conv.lastModified = Date.now();
|
||||
localStorage.setItem(convId, JSON.stringify(conv));
|
||||
},
|
||||
/**
|
||||
* remove last message from conversation
|
||||
* @param {string} convId
|
||||
* @returns {Message | undefined}
|
||||
*/
|
||||
popMsg(convId) {
|
||||
const conv = StorageUtils.getOneConversation(convId);
|
||||
if (!conv) return;
|
||||
@@ -322,10 +382,12 @@ const mainApp = createApp({
|
||||
data() {
|
||||
return {
|
||||
conversations: StorageUtils.getAllConversations(),
|
||||
messages: [], // { id: number, role: 'user' | 'assistant', content: string }
|
||||
/** @type {Array<Message>} */
|
||||
messages: [],
|
||||
viewingConvId: StorageUtils.getNewConvId(),
|
||||
inputMsg: '',
|
||||
isGenerating: false,
|
||||
/** @type {Array<Message> | null} */
|
||||
pendingMsg: null, // the on-going message from assistant
|
||||
stopGeneration: () => {},
|
||||
selectedTheme: StorageUtils.getTheme(),
|
||||
@@ -333,6 +395,7 @@ const mainApp = createApp({
|
||||
showConfigDialog: false,
|
||||
// const
|
||||
themes: THEMES,
|
||||
/** @type {CONFIG_DEFAULT} */
|
||||
configDefault: {...CONFIG_DEFAULT},
|
||||
configInfo: {...CONFIG_INFO},
|
||||
isDev,
|
||||
@@ -405,7 +468,10 @@ const mainApp = createApp({
|
||||
URL.revokeObjectURL(url);
|
||||
},
|
||||
async sendMessage() {
|
||||
if (!this.inputMsg) return;
|
||||
// prevent sending empty message
|
||||
// also allow typing the message while generating, but does not allow sending it (to match UX/UI behavior of other chat apps)
|
||||
if (!this.inputMsg || this.isGenerating) return;
|
||||
|
||||
const currConvId = this.viewingConvId;
|
||||
|
||||
StorageUtils.appendMsg(currConvId, {
|
||||
@@ -425,42 +491,50 @@ const mainApp = createApp({
|
||||
this.isGenerating = true;
|
||||
|
||||
try {
|
||||
/** @type {CONFIG_DEFAULT} */
|
||||
const config = this.config;
|
||||
const abortController = new AbortController();
|
||||
this.stopGeneration = () => abortController.abort();
|
||||
/** @type {Array<APIMessage>} */
|
||||
let messages = [
|
||||
{ role: 'system', content: config.systemMessage },
|
||||
...normalizeMsgsForAPI(this.messages),
|
||||
];
|
||||
if (config.excludeThoughtOnReq) {
|
||||
messages = filterThoughtFromMsgs(messages);
|
||||
}
|
||||
if (isDev) console.log({messages});
|
||||
const params = {
|
||||
messages: [
|
||||
{ role: 'system', content: this.config.systemMessage },
|
||||
...this.messages,
|
||||
],
|
||||
messages,
|
||||
stream: true,
|
||||
cache_prompt: true,
|
||||
samplers: this.config.samplers,
|
||||
temperature: this.config.temperature,
|
||||
dynatemp_range: this.config.dynatemp_range,
|
||||
dynatemp_exponent: this.config.dynatemp_exponent,
|
||||
top_k: this.config.top_k,
|
||||
top_p: this.config.top_p,
|
||||
min_p: this.config.min_p,
|
||||
typical_p: this.config.typical_p,
|
||||
xtc_probability: this.config.xtc_probability,
|
||||
xtc_threshold: this.config.xtc_threshold,
|
||||
repeat_last_n: this.config.repeat_last_n,
|
||||
repeat_penalty: this.config.repeat_penalty,
|
||||
presence_penalty: this.config.presence_penalty,
|
||||
frequency_penalty: this.config.frequency_penalty,
|
||||
dry_multiplier: this.config.dry_multiplier,
|
||||
dry_base: this.config.dry_base,
|
||||
dry_allowed_length: this.config.dry_allowed_length,
|
||||
dry_penalty_last_n: this.config.dry_penalty_last_n,
|
||||
max_tokens: this.config.max_tokens,
|
||||
timings_per_token: !!this.config.showTokensPerSecond,
|
||||
...(this.config.custom.length ? JSON.parse(this.config.custom) : {}),
|
||||
samplers: config.samplers,
|
||||
temperature: config.temperature,
|
||||
dynatemp_range: config.dynatemp_range,
|
||||
dynatemp_exponent: config.dynatemp_exponent,
|
||||
top_k: config.top_k,
|
||||
top_p: config.top_p,
|
||||
min_p: config.min_p,
|
||||
typical_p: config.typical_p,
|
||||
xtc_probability: config.xtc_probability,
|
||||
xtc_threshold: config.xtc_threshold,
|
||||
repeat_last_n: config.repeat_last_n,
|
||||
repeat_penalty: config.repeat_penalty,
|
||||
presence_penalty: config.presence_penalty,
|
||||
frequency_penalty: config.frequency_penalty,
|
||||
dry_multiplier: config.dry_multiplier,
|
||||
dry_base: config.dry_base,
|
||||
dry_allowed_length: config.dry_allowed_length,
|
||||
dry_penalty_last_n: config.dry_penalty_last_n,
|
||||
max_tokens: config.max_tokens,
|
||||
timings_per_token: !!config.showTokensPerSecond,
|
||||
...(config.custom.length ? JSON.parse(config.custom) : {}),
|
||||
};
|
||||
const chunks = sendSSEPostRequest(`${BASE_URL}/v1/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...(this.config.apiKey ? {'Authorization': `Bearer ${this.config.apiKey}`} : {})
|
||||
...(config.apiKey ? {'Authorization': `Bearer ${config.apiKey}`} : {})
|
||||
},
|
||||
body: JSON.stringify(params),
|
||||
signal: abortController.signal,
|
||||
@@ -477,7 +551,7 @@ const mainApp = createApp({
|
||||
};
|
||||
}
|
||||
const timings = chunk.timings;
|
||||
if (timings && this.config.showTokensPerSecond) {
|
||||
if (timings && config.showTokensPerSecond) {
|
||||
// only extract what's really needed, to save some space
|
||||
this.pendingMsg.timings = {
|
||||
prompt_n: timings.prompt_n,
|
||||
@@ -598,3 +672,33 @@ try {
|
||||
<button class="btn" onClick="localStorage.clear(); window.location.reload();">Clear localStorage</button>
|
||||
</div>`;
|
||||
}
|
||||
|
||||
/**
|
||||
* filter out redundant fields upon sending to API
|
||||
* @param {Array<APIMessage>} messages
|
||||
* @returns {Array<APIMessage>}
|
||||
*/
|
||||
function normalizeMsgsForAPI(messages) {
|
||||
return messages.map((msg) => {
|
||||
return {
|
||||
role: msg.role,
|
||||
content: msg.content,
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* recommended for DeepsSeek-R1, filter out content between <think> and </think> tags
|
||||
* @param {Array<APIMessage>} messages
|
||||
* @returns {Array<APIMessage>}
|
||||
*/
|
||||
function filterThoughtFromMsgs(messages) {
|
||||
return messages.map((msg) => {
|
||||
return {
|
||||
role: msg.role,
|
||||
content: msg.role === 'assistant'
|
||||
? msg.content.split('</think>').at(-1).trim()
|
||||
: msg.content,
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
11
examples/simple-cmake-pkg/CMakeLists.txt
Normal file
11
examples/simple-cmake-pkg/CMakeLists.txt
Normal file
@@ -0,0 +1,11 @@
|
||||
cmake_minimum_required(VERSION 3.12)
|
||||
project(llama-simple-cmake-pkg)
|
||||
|
||||
set(TARGET llama-simple-cmake-pkg)
|
||||
|
||||
find_package(Llama REQUIRED)
|
||||
|
||||
add_executable(${TARGET} ${CMAKE_CURRENT_LIST_DIR}/../simple/simple.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE llama ggml::all ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
34
examples/simple-cmake-pkg/README.md
Normal file
34
examples/simple-cmake-pkg/README.md
Normal file
@@ -0,0 +1,34 @@
|
||||
# llama.cpp/example/simple-cmake-pkg
|
||||
|
||||
This program builds [simple](../simple) using a relocatable CMake package. It serves as an example of using the `find_package()` CMake command to conveniently include [llama.cpp](https://github.com/ggerganov/llama.cpp) in projects which live outside of the source tree.
|
||||
|
||||
## Building
|
||||
|
||||
Because this example is "outside of the source tree", it is important to first build/install llama.cpp using CMake. An example is provided here, but please see the [llama.cpp build instructions](../..) for more detailed build instructions.
|
||||
|
||||
### Considerations
|
||||
|
||||
When hardware acceleration libraries are used (e.g. CUDA, Metal, Vulkan, etc.), the appropriate dependencies will be searched for automatically. So, for example, when finding a package
|
||||
|
||||
### Build llama.cpp and install to llama.cpp/inst
|
||||
|
||||
```sh
|
||||
git clone https://github.com/ggerganov/llama.cpp
|
||||
cd llama.cpp
|
||||
cmake -S . -B build
|
||||
cmake --build build
|
||||
cmake --install build --prefix inst
|
||||
|
||||
### Build simple-cmake-pkg
|
||||
|
||||
```sh
|
||||
cd examples/simple-cmake-pkg
|
||||
cmake -S . -B build -DCMAKE_PREFIX_PATH=../../inst/lib/cmake
|
||||
cmake --build build
|
||||
```
|
||||
|
||||
### Run simple-cmake-pkg
|
||||
|
||||
```sh
|
||||
./build/llama-simple-cmake-pkg -m ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is"
|
||||
```
|
||||
@@ -58,7 +58,8 @@ else()
|
||||
set(GGML_BLAS_VENDOR_DEFAULT "Generic")
|
||||
endif()
|
||||
|
||||
if (CMAKE_CROSSCOMPILING)
|
||||
if (CMAKE_CROSSCOMPILING OR DEFINED ENV{SOURCE_DATE_EPOCH})
|
||||
message(STATUS "Setting GGML_NATIVE_DEFAULT to OFF")
|
||||
set(GGML_NATIVE_DEFAULT OFF)
|
||||
else()
|
||||
set(GGML_NATIVE_DEFAULT ON)
|
||||
@@ -153,6 +154,8 @@ option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashA
|
||||
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
|
||||
|
||||
option(GGML_HIP "ggml: use HIP" OFF)
|
||||
option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
|
||||
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
|
||||
option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
|
||||
option(GGML_VULKAN "ggml: use Vulkan" OFF)
|
||||
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
|
||||
@@ -264,3 +267,77 @@ if (GGML_STANDALONE)
|
||||
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc
|
||||
DESTINATION share/pkgconfig)
|
||||
endif()
|
||||
|
||||
#
|
||||
# Create CMake package
|
||||
#
|
||||
|
||||
# Generate version info based on git commit.
|
||||
|
||||
if(NOT DEFINED GGML_BUILD_NUMBER)
|
||||
find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH)
|
||||
execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
OUTPUT_VARIABLE GGML_BUILD_NUMBER
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
|
||||
if(GGML_BUILD_NUMBER EQUAL 1)
|
||||
message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.")
|
||||
endif()
|
||||
|
||||
execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
OUTPUT_VARIABLE GGML_BUILD_COMMIT
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
# Capture variables prefixed with GGML_.
|
||||
|
||||
set(variable_set_statements
|
||||
"
|
||||
####### Expanded from @GGML_VARIABLES_EXPANED@ by configure_package_config_file() #######
|
||||
####### Any changes to this file will be overwritten by the next CMake run #######
|
||||
|
||||
")
|
||||
|
||||
set(GGML_SHARED_LIB ${BUILD_SHARED_LIBS})
|
||||
|
||||
get_cmake_property(all_variables VARIABLES)
|
||||
foreach(variable_name IN LISTS all_variables)
|
||||
if(variable_name MATCHES "^GGML_")
|
||||
string(REPLACE ";" "\\;"
|
||||
variable_value "${${variable_name}}")
|
||||
|
||||
set(variable_set_statements
|
||||
"${variable_set_statements}set(${variable_name} \"${variable_value}\")\n")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
set(GGML_VARIABLES_EXPANDED ${variable_set_statements})
|
||||
|
||||
# Create the CMake package and set install location.
|
||||
|
||||
set(GGML_INSTALL_VERSION 0.0.${GGML_BUILD_NUMBER})
|
||||
set(GGML_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files")
|
||||
set(GGML_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files")
|
||||
set(GGML_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files")
|
||||
|
||||
configure_package_config_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cmake/ggml-config.cmake.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
|
||||
INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml
|
||||
PATH_VARS GGML_INCLUDE_INSTALL_DIR
|
||||
GGML_LIB_INSTALL_DIR
|
||||
GGML_BIN_INSTALL_DIR)
|
||||
|
||||
write_basic_package_version_file(
|
||||
${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
|
||||
VERSION ${GGML_INSTALL_VERSION}
|
||||
COMPATIBILITY SameMajorVersion)
|
||||
|
||||
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
|
||||
${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml)
|
||||
|
||||
147
ggml/cmake/ggml-config.cmake.in
Normal file
147
ggml/cmake/ggml-config.cmake.in
Normal file
@@ -0,0 +1,147 @@
|
||||
|
||||
@GGML_VARIABLES_EXPANDED@
|
||||
|
||||
@PACKAGE_INIT@
|
||||
|
||||
set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@")
|
||||
set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@")
|
||||
set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@")
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
find_library(GGML_LIBRARY ggml
|
||||
REQUIRED
|
||||
HINTS ${GGML_LIB_DIR}
|
||||
NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
add_library(ggml::ggml UNKNOWN IMPORTED)
|
||||
set_target_properties(ggml::ggml
|
||||
PROPERTIES
|
||||
IMPORTED_LOCATION "${GGML_LIBRARY}")
|
||||
|
||||
find_library(GGML_BASE_LIBRARY ggml-base
|
||||
REQUIRED
|
||||
HINTS ${GGML_LIB_DIR}
|
||||
NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
add_library(ggml::ggml-base UNKNOWN IMPORTED)
|
||||
set_target_properties(ggml::ggml-base
|
||||
PROPERTIES
|
||||
IMPORTED_LOCATION "${GGML_BASE_LIBRARY}")
|
||||
|
||||
if (NOT GGML_SHARED_LIB)
|
||||
if (APPLE AND GGML_ACCELERATE)
|
||||
find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED)
|
||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${ACCELERATE_FRAMEWORK})
|
||||
endif()
|
||||
|
||||
if (GGML_OPENMP)
|
||||
find_package(OpenMP REQUIRED)
|
||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
|
||||
endif()
|
||||
|
||||
if (GGML_CPU_HBM)
|
||||
find_library(memkind memkind REQUIRED)
|
||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES memkind)
|
||||
endif()
|
||||
|
||||
if (GGML_BLAS)
|
||||
find_package(BLAS REQUIRED)
|
||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES})
|
||||
list(APPEND GGML_CPU_INTERFACE_LINK_OPTIONS ${BLAS_LINKER_FLAGS})
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA)
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
endif()
|
||||
|
||||
if (GGML_METAL)
|
||||
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
|
||||
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
||||
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
||||
|
||||
list(APPEND GGML_METAL_INTERFACE_LINK_LIBRARIES
|
||||
${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK})
|
||||
endif()
|
||||
|
||||
if (GGML_VULKAN)
|
||||
find_package(Vulkan REQUIRED)
|
||||
list(APPEND GGML_VULKAN_INTERFACE_LINK_LIBRARIES Vulkan::Vulkan)
|
||||
endif()
|
||||
|
||||
if (GGML_HIP)
|
||||
find_package(hip REQUIRED)
|
||||
find_package(hipblas REQUIRED)
|
||||
find_package(rocblas REQUIRED)
|
||||
list(APPEND GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas)
|
||||
endif()
|
||||
|
||||
if (GGML_SYCL)
|
||||
find_package(DNNL)
|
||||
if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
|
||||
list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES DNNL::dnnl)
|
||||
endif()
|
||||
if (WIN32)
|
||||
find_package(IntelSYCL REQUIRED)
|
||||
find_package(MKL REQUIRED)
|
||||
list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(_ggml_all_targets "")
|
||||
foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
|
||||
string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}")
|
||||
string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx)
|
||||
|
||||
find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend}
|
||||
REQUIRED
|
||||
HINTS ${GGML_LIB_DIR}
|
||||
NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}")
|
||||
|
||||
add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED)
|
||||
set_target_properties(ggml::${_ggml_backend}
|
||||
PROPERTIES
|
||||
INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}"
|
||||
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
|
||||
IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}"
|
||||
INTERFACE_COMPILE_FEATURES c_std_90
|
||||
POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}")
|
||||
if(is_cpu_variant)
|
||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml" "ggml::ggml-base")
|
||||
set_target_properties(ggml::${_ggml_backend}
|
||||
PROPERTIES
|
||||
INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}")
|
||||
|
||||
if(GGML_CPU_INTERFACE_LINK_OPTIONS)
|
||||
set_target_properties(ggml::${_ggml_backend}
|
||||
PROPERTIES
|
||||
INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}")
|
||||
endif()
|
||||
|
||||
else()
|
||||
list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml" "ggml::ggml-base")
|
||||
set_target_properties(ggml::${_ggml_backend}
|
||||
PROPERTIES
|
||||
INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}")
|
||||
|
||||
if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS)
|
||||
set_target_properties(ggml::${_ggml_backend}
|
||||
PROPERTIES
|
||||
INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
list(APPEND _ggml_all_targets ggml::${_ggml_backend})
|
||||
endforeach()
|
||||
|
||||
add_library(ggml::all INTERFACE IMPORTED)
|
||||
set_target_properties(ggml::all
|
||||
PROPERTIES
|
||||
INTERFACE_LINK_LIBRARIES "${_ggml_all_targets}")
|
||||
|
||||
check_required_components(ggml)
|
||||
@@ -1775,7 +1775,7 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
int k);
|
||||
|
||||
#define GGML_KQ_MASK_PAD 32
|
||||
#define GGML_KQ_MASK_PAD 64
|
||||
|
||||
// q: [n_embd, n_batch, n_head, 1]
|
||||
// k: [n_embd, n_kv, n_head_kv, 1]
|
||||
|
||||
@@ -93,12 +93,18 @@ endif()
|
||||
|
||||
if (GGML_CCACHE)
|
||||
find_program(GGML_CCACHE_FOUND ccache)
|
||||
find_program(GGML_SCCACHE_FOUND sccache)
|
||||
|
||||
if (GGML_CCACHE_FOUND)
|
||||
if (GGML_CCACHE_FOUND OR GGML_SCCACHE_FOUND)
|
||||
if(GGML_CCACHE_FOUND)
|
||||
set(GGML_CCACHE_VARIANT ccache)
|
||||
else()
|
||||
set(GGML_CCACHE_VARIANT sccache)
|
||||
endif()
|
||||
# TODO: should not be set globally
|
||||
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache)
|
||||
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${GGML_CCACHE_VARIANT}")
|
||||
set(ENV{CCACHE_SLOPPINESS} time_macros)
|
||||
message(STATUS "ccache found, compilation results will be cached. Disable with GGML_CCACHE=OFF.")
|
||||
message(STATUS "${GGML_CCACHE_VARIANT} found, compilation results will be cached. Disable with GGML_CCACHE=OFF.")
|
||||
else()
|
||||
message(STATUS "Warning: ccache not found - consider installing it for faster compilation or disable this warning with GGML_CCACHE=OFF")
|
||||
endif ()
|
||||
@@ -250,6 +256,17 @@ function(ggml_add_backend_library backend)
|
||||
target_compile_definitions(${backend} PRIVATE GGML_BACKEND_BUILD)
|
||||
target_compile_definitions(${backend} PUBLIC GGML_BACKEND_SHARED)
|
||||
endif()
|
||||
|
||||
if(NOT GGML_AVAILABLE_BACKENDS)
|
||||
set(GGML_AVAILABLE_BACKENDS "${backend}"
|
||||
CACHE INTERNAL "List of backends for cmake package")
|
||||
else()
|
||||
list(FIND GGML_AVAILABLE_BACKENDS "${backend}" has_backend)
|
||||
if(has_backend EQUAL -1)
|
||||
set(GGML_AVAILABLE_BACKENDS "${GGML_AVAILABLE_BACKENDS};${backend}"
|
||||
CACHE INTERNAL "List of backends for cmake package")
|
||||
endif()
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
function(ggml_add_backend backend)
|
||||
@@ -297,7 +314,7 @@ if (GGML_CPU_ALL_VARIANTS)
|
||||
# MSVC doesn't support AMX
|
||||
ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
|
||||
endif()
|
||||
else ()
|
||||
elseif (GGML_CPU)
|
||||
ggml_add_cpu_backend_variant_impl("")
|
||||
endif()
|
||||
|
||||
|
||||
@@ -1302,7 +1302,7 @@ struct ggml_threadpool {
|
||||
// these are atomic as an annotation for thread-sanitizer
|
||||
atomic_bool stop; // Used for stopping the threadpool altogether
|
||||
atomic_bool pause; // Used for pausing the threadpool or individual threads
|
||||
atomic_bool abort; // Used for aborting processing of a graph
|
||||
atomic_int abort; // Used for aborting processing of a graph
|
||||
|
||||
struct ggml_compute_state * workers; // per thread state
|
||||
int n_threads_max; // number of threads in the pool
|
||||
@@ -7883,7 +7883,7 @@ static void ggml_compute_forward_out_prod_f32(
|
||||
|
||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
|
||||
ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
|
||||
}
|
||||
@@ -7892,7 +7892,7 @@ static void ggml_compute_forward_out_prod_f32(
|
||||
|
||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
|
||||
ggml_vec_mad_f32(ne0, d, s0, *s1);
|
||||
}
|
||||
@@ -13851,14 +13851,14 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
/*.threadpool=*/ tp,
|
||||
};
|
||||
|
||||
for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) {
|
||||
for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
|
||||
struct ggml_tensor * node = cgraph->nodes[node_n];
|
||||
|
||||
ggml_compute_forward(¶ms, node);
|
||||
|
||||
if (state->ith == 0 && cplan->abort_callback &&
|
||||
cplan->abort_callback(cplan->abort_callback_data)) {
|
||||
tp->abort = true;
|
||||
atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);
|
||||
tp->ec = GGML_STATUS_ABORTED;
|
||||
}
|
||||
|
||||
@@ -14031,7 +14031,7 @@ static struct ggml_threadpool * ggml_threadpool_new_impl(
|
||||
threadpool->current_chunk = 0;
|
||||
threadpool->stop = false;
|
||||
threadpool->pause = tpp->paused;
|
||||
threadpool->abort = false;
|
||||
threadpool->abort = -1;
|
||||
threadpool->workers = NULL;
|
||||
threadpool->n_threads_max = tpp->n_threads;
|
||||
threadpool->n_threads_cur = tpp->n_threads;
|
||||
@@ -14110,7 +14110,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
|
||||
threadpool->cgraph = cgraph;
|
||||
threadpool->cplan = cplan;
|
||||
threadpool->current_chunk = 0;
|
||||
threadpool->abort = false;
|
||||
threadpool->abort = -1;
|
||||
threadpool->ec = GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -416,7 +416,8 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
|
||||
case GGML_OP_IM2COL_BACK:
|
||||
return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
|
||||
case GGML_OP_OUT_PROD:
|
||||
return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32;
|
||||
return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
|
||||
src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ if (CUDAToolkit_FOUND)
|
||||
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
|
||||
|
||||
file(GLOB GGML_SOURCES_CUDA "*.cu")
|
||||
file(GLOB SRCS "template-instances/fattn-wmma*.cu")
|
||||
file(GLOB SRCS "template-instances/fattn-mma*.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/mmq*.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
|
||||
@@ -93,26 +93,31 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
|
||||
|
||||
template <typename T>
|
||||
static __global__ void k_repeat_back(
|
||||
const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2) {
|
||||
const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const size_t s00, const size_t s01, const size_t s02, const size_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
|
||||
|
||||
const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
|
||||
const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
|
||||
const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
|
||||
const int64_t tid0 = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
|
||||
const int64_t tid1 = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
|
||||
const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
|
||||
const int64_t tid2 = tid23 % ne2;
|
||||
const int64_t tid3 = tid23 / ne2;
|
||||
|
||||
if (tid0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
T sum = 0;
|
||||
for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
|
||||
for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
|
||||
for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
|
||||
sum += src[i2*ne01*ne00 + i1*ne00 + i0];
|
||||
for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
|
||||
for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
|
||||
for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
|
||||
for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
|
||||
sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
|
||||
dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
|
||||
}
|
||||
|
||||
template<float (*bin_op)(const float, const float)>
|
||||
@@ -274,12 +279,14 @@ struct bin_bcast_cuda {
|
||||
|
||||
template <typename T>
|
||||
static void repeat_back_cuda(
|
||||
const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
|
||||
const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const size_t s00, const size_t s01, const size_t s02, const size_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
|
||||
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
|
||||
k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
|
||||
const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
|
||||
k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
|
||||
(src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
|
||||
}
|
||||
|
||||
template<class op>
|
||||
@@ -326,27 +333,26 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(src0->type == dst->type);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
GGML_ASSERT(ggml_can_repeat(dst, src0));
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
GGML_ASSERT(src0->ne[3] == 1);
|
||||
GGML_TENSOR_UNARY_OP_LOCALS;
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
const int64_t ne1 = dst->ne[1];
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
GGML_ASSERT(dst->ne[3] == 1);
|
||||
GGML_ASSERT(ne2*ne3 <= (1 << 15));
|
||||
|
||||
const size_t ts = ggml_type_size(src0->type);
|
||||
const size_t s00 = nb00 / ts;
|
||||
const size_t s01 = nb01 / ts;
|
||||
const size_t s02 = nb02 / ts;
|
||||
const size_t s03 = nb03 / ts;
|
||||
|
||||
switch (dst->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
repeat_back_cuda<float>(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
|
||||
repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ASSERT(false);
|
||||
|
||||
@@ -46,20 +46,27 @@
|
||||
#define GGML_CUDA_CC_VOLTA 700
|
||||
#define GGML_CUDA_CC_TURING 750
|
||||
#define GGML_CUDA_CC_AMPERE 800
|
||||
#define GGML_CUDA_CC_OFFSET_AMD 1000000
|
||||
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
|
||||
|
||||
// GCN/CNDA, wave size is 64
|
||||
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 803) // Tonga, Fiji, Polaris, minimum for fast fp16
|
||||
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 900) // Vega56/64, minimum for fp16 dual issue
|
||||
#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 906) // MI50/Radeon VII, minimum for dp4a
|
||||
#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 908) // MI100, minimum for MFMA, acc registers
|
||||
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 910) // MI210, minimum acc register renameing
|
||||
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 942) // MI300
|
||||
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
|
||||
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
|
||||
#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
|
||||
#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
|
||||
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing
|
||||
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300
|
||||
|
||||
// RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32
|
||||
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 1010) // RX 5000
|
||||
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
|
||||
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA
|
||||
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
|
||||
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
||||
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
||||
|
||||
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
|
||||
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
|
||||
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
|
||||
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
|
||||
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
|
||||
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
|
||||
|
||||
#define GGML_CUDA_CC_QY1 210
|
||||
#define GGML_CUDA_CC_QY2 220
|
||||
@@ -131,6 +138,10 @@ typedef float dfloat; // dequantize float
|
||||
typedef float2 dfloat2;
|
||||
#endif // GGML_CUDA_F16
|
||||
|
||||
#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
|
||||
#define GGML_USE_VMM
|
||||
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
|
||||
|
||||
#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
||||
#define FP16_AVAILABLE
|
||||
#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
||||
@@ -144,7 +155,7 @@ typedef float2 dfloat2;
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
||||
#define INT8_MMA_AVAILABLE
|
||||
#define NEW_MMA_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
||||
|
||||
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
@@ -155,14 +166,24 @@ static constexpr bool fast_fp16_available(const int cc) {
|
||||
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
|
||||
}
|
||||
|
||||
// Any FP16 tensor cores are available.
|
||||
static constexpr bool fp16_mma_available(const int cc) {
|
||||
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
|
||||
}
|
||||
|
||||
static constexpr bool int8_mma_available(const int cc) {
|
||||
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
|
||||
static constexpr bool new_mma_available(const int cc) {
|
||||
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
|
||||
}
|
||||
|
||||
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
return __AMDGCN_WAVEFRONT_SIZE;
|
||||
#else
|
||||
return 32;
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
}
|
||||
|
||||
[[noreturn]]
|
||||
static __device__ void no_device_code(
|
||||
const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
|
||||
@@ -186,53 +207,46 @@ static __device__ void no_device_code(
|
||||
#define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
|
||||
#endif // __CUDA_ARCH__
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ int warp_reduce_sum(int x) {
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
return __reduce_add_sync(0xffffffff, x);
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||
x += __shfl_xor_sync(0xffffffff, x, offset, 32);
|
||||
for (int offset = width/2; offset > 0; offset >>= 1) {
|
||||
x += __shfl_xor_sync(0xffffffff, x, offset, width);
|
||||
}
|
||||
return x;
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ float warp_reduce_sum(float x) {
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||
x += __shfl_xor_sync(0xffffffff, x, offset, 32);
|
||||
for (int offset = width/2; offset > 0; offset >>= 1) {
|
||||
x += __shfl_xor_sync(0xffffffff, x, offset, width);
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||
a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32);
|
||||
a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32);
|
||||
for (int offset = width/2; offset > 0; offset >>= 1) {
|
||||
a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width);
|
||||
a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||
#ifdef FP16_AVAILABLE
|
||||
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||
const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
|
||||
reinterpret_cast<half&>(a.x) += __low2half(a_other);
|
||||
reinterpret_cast<half&>(a.y) += __high2half(a_other);
|
||||
for (int offset = width/2; offset > 0; offset >>= 1) {
|
||||
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width));
|
||||
}
|
||||
return a;
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
|
||||
}
|
||||
return a;
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
@@ -240,10 +254,11 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||
#endif // FP16_AVAILABLE
|
||||
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ float warp_reduce_max(float x) {
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
|
||||
for (int offset = width/2; offset > 0; offset >>= 1) {
|
||||
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width));
|
||||
}
|
||||
return x;
|
||||
}
|
||||
@@ -265,35 +280,34 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
|
||||
#if CUDART_VERSION >= CUDART_HMAX
|
||||
#if defined(GGML_USE_HIP) && HIP_VERSION >= 50700000
|
||||
return half2(__hmax(a.x, b.x), __hmax(a.y, b.y));
|
||||
#elif !defined(GGML_USE_HIP) && CUDART_VERSION >= CUDART_HMAX
|
||||
return __hmax2(a, b);
|
||||
#else
|
||||
#elif !defined(GGML_USE_HIP)
|
||||
half2 ret;
|
||||
reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a), __low2float(b)));
|
||||
reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
|
||||
return ret;
|
||||
#endif // CUDART_VERSION >= CUDART_HMAX
|
||||
|
||||
#else
|
||||
GGML_UNUSED(a);
|
||||
GGML_UNUSED(b);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
#endif
|
||||
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||
x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
|
||||
for (int offset = width/2; offset > 0; offset >>= 1) {
|
||||
x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
|
||||
}
|
||||
return x;
|
||||
#else
|
||||
GGML_UNUSED(x);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
|
||||
}
|
||||
|
||||
#if CUDART_VERSION < CUDART_HMASK
|
||||
@@ -516,6 +530,7 @@ struct ggml_cuda_device_info {
|
||||
bool vmm; // virtual memory support
|
||||
size_t vmm_granularity; // granularity of virtual memory
|
||||
size_t total_vram;
|
||||
int warp_size; // Number of threads in a dispatch
|
||||
};
|
||||
|
||||
cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
|
||||
@@ -588,7 +603,7 @@ struct ggml_tensor_extra_gpu {
|
||||
};
|
||||
|
||||
|
||||
#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)
|
||||
#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
|
||||
#define USE_CUDA_GRAPH
|
||||
#endif
|
||||
|
||||
|
||||
@@ -516,6 +516,114 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
|
||||
nullptr;
|
||||
}
|
||||
|
||||
// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wpass-failed"
|
||||
#endif // __clang__
|
||||
|
||||
template<int D, int ncols, int KQ_stride> // D == head size
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void flash_attn_stream_k_fixup(
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
|
||||
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
|
||||
|
||||
const int iter_k = ne11 / KQ_stride;
|
||||
const int iter_j = (ne01 + (ncols - 1)) / ncols;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
|
||||
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
|
||||
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
|
||||
|
||||
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
||||
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
||||
const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
|
||||
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int channel = kbc0 / (iter_k*iter_j);
|
||||
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
|
||||
|
||||
dst += jt*ncols*ne02*D + channel*D;
|
||||
|
||||
// Load the partial result that needs a fixup:
|
||||
float dst_val[ncols] = {0.0f};
|
||||
float max_val[ncols] = {0.0f};
|
||||
float rowsum[ncols] = {0.0f};
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (jt*ncols + j >= ne01) {
|
||||
break;
|
||||
}
|
||||
dst_val[j] = dst[j*ne02*D + threadIdx.x];
|
||||
|
||||
const float2 tmp = dst_fixup[bidx0*ncols + j];
|
||||
max_val[j] = tmp.x;
|
||||
rowsum[j] = tmp.y;
|
||||
}
|
||||
|
||||
// Iterate over previous blocks and compute the combined results.
|
||||
// All CUDA blocks that get here must have a previous block that needs a fixup.
|
||||
int bidx = bidx0 - 1;
|
||||
int kbc_stop = kbc0;
|
||||
while(true) {
|
||||
const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
|
||||
if (kbc == kbc_stop) { // Did not have any data.
|
||||
bidx--;
|
||||
kbc_stop = kbc;
|
||||
continue;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (jt*ncols + j >= ne01) {
|
||||
break;
|
||||
}
|
||||
const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
|
||||
|
||||
const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
|
||||
|
||||
// Scale the current and new value accumulators depending on the max. values.
|
||||
const float max_val_new = fmaxf(max_val[j], tmp.x);
|
||||
|
||||
const float diff_val = max_val[j] - max_val_new;
|
||||
const float diff_add = tmp.x - max_val_new;
|
||||
|
||||
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
|
||||
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
|
||||
|
||||
dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
|
||||
rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y;
|
||||
|
||||
max_val[j] = max_val_new;
|
||||
}
|
||||
|
||||
// If this block started in a previous tile we are done and don't need to combine additional partial results.
|
||||
if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
|
||||
break;
|
||||
}
|
||||
bidx--;
|
||||
kbc_stop = kbc;
|
||||
}
|
||||
|
||||
// Write back final result:
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (jt*ncols + j >= ne01) {
|
||||
return;
|
||||
}
|
||||
dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic pop
|
||||
#endif // __clang__
|
||||
|
||||
template<int D, int parallel_blocks> // D == head size
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
@@ -581,10 +689,11 @@ static void on_no_fattn_vec_case(const int D) {
|
||||
}
|
||||
}
|
||||
|
||||
template <int D, int parallel_blocks>
|
||||
// parallel_blocks == 0 is stream-k decomposition
|
||||
template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
|
||||
void launch_fattn(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
|
||||
const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
|
||||
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
|
||||
) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
@@ -603,20 +712,23 @@ void launch_fattn(
|
||||
|
||||
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
||||
|
||||
GGML_ASSERT(Q->ne[3] == 1);
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t main_stream = ctx.stream();
|
||||
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
||||
|
||||
ggml_cuda_pool_alloc<half> K_f16(pool);
|
||||
ggml_cuda_pool_alloc<half> V_f16(pool);
|
||||
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
||||
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
||||
|
||||
char * K_data = (char *) K->data;
|
||||
const char * K_data = (const char *) K->data;
|
||||
size_t nb11 = K->nb[1];
|
||||
size_t nb12 = K->nb[2];
|
||||
size_t nb13 = K->nb[3];
|
||||
|
||||
char * V_data = (char *) V->data;
|
||||
const char * V_data = (const char *) V->data;
|
||||
size_t nb21 = V->nb[1];
|
||||
size_t nb22 = V->nb[2];
|
||||
size_t nb23 = V->nb[3];
|
||||
@@ -649,39 +761,60 @@ void launch_fattn(
|
||||
nb23 = nb23*bs*sizeof(half)/ts;
|
||||
}
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
||||
}
|
||||
const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
|
||||
const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
|
||||
|
||||
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
||||
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
|
||||
const int shmem = 0;
|
||||
dim3 blocks_num;
|
||||
if (parallel_blocks == 0) {
|
||||
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
||||
const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
|
||||
const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
|
||||
const bool short_context = K->ne[1] < 4096;
|
||||
|
||||
const int nblocks_stream_k = 2*nsm;
|
||||
|
||||
blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
|
||||
blocks_num.y = 1;
|
||||
blocks_num.z = 1;
|
||||
|
||||
dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
|
||||
} else {
|
||||
blocks_num.x = parallel_blocks*ntiles_x;
|
||||
blocks_num.y = Q->ne[2];
|
||||
blocks_num.z = Q->ne[3];
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
|
||||
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
|
||||
memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
|
||||
memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
scale /= logit_softcap;
|
||||
}
|
||||
|
||||
const uint32_t n_head = Q->ne[2];
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
||||
const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
|
||||
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
|
||||
(const char *) Q->data,
|
||||
K_data,
|
||||
V_data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
||||
(parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
@@ -693,16 +826,22 @@ void launch_fattn(
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
if ((parallel_blocks) == 1) {
|
||||
return;
|
||||
if constexpr (parallel_blocks == 0) {
|
||||
if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 blocks_num_combine = blocks_num;
|
||||
|
||||
flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
||||
}
|
||||
} else if constexpr (parallel_blocks > 1) {
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
|
||||
|
||||
flash_attn_combine_results<D, parallel_blocks>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
||||
}
|
||||
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
|
||||
const int shmem_combine = 0;
|
||||
|
||||
flash_attn_combine_results<D, parallel_blocks>
|
||||
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
|
||||
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
637
ggml/src/ggml-cuda/fattn-mma-f16.cuh
Normal file
637
ggml/src/ggml-cuda/fattn-mma-f16.cuh
Normal file
@@ -0,0 +1,637 @@
|
||||
#include "common.cuh"
|
||||
#include "mma.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const float2 * const __restrict__ Q_f2,
|
||||
const half2 * const __restrict__ K_h2,
|
||||
const half2 * const __restrict__ V_h2,
|
||||
const half * const __restrict__ maskh,
|
||||
float2 * const __restrict__ dstk,
|
||||
float2 * const __restrict__ dstk_fixup,
|
||||
const float scale,
|
||||
const float slope,
|
||||
const float logit_softcap,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int ne03,
|
||||
const int ne10,
|
||||
const int ne11,
|
||||
const int ne12,
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int nb31,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int nb21,
|
||||
const int nb22,
|
||||
const int nb23,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3,
|
||||
const int jt,
|
||||
const int kb0_start,
|
||||
const int kb0_stop) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
typedef mma_A_I16K8<half2> mma_A;
|
||||
typedef mma_B_J8K8<half2> mma_B;
|
||||
typedef mma_C_I16J8<float> mma_C_KQ;
|
||||
typedef mma_C_I16J8<half2> mma_C_VKQ;
|
||||
|
||||
static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps");
|
||||
constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column.
|
||||
|
||||
static_assert(D % nwarps == 0, "bad D");
|
||||
static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");
|
||||
|
||||
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
|
||||
extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements.
|
||||
|
||||
const int stride_Q = nb01 / sizeof(float2);
|
||||
const int stride_KV = nb11 / sizeof(half2);
|
||||
const int stride_mask = nb31 / sizeof(half);
|
||||
|
||||
mma_B Q_B[D/(2*mma_B::K)];
|
||||
mma_C_VKQ VKQ_C[D/mma_C_VKQ::I];
|
||||
|
||||
float2 KQ_rowsum = {0.0f, 0.0f};
|
||||
float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
|
||||
float2 KQ_max_scale = {0.0f, 0.0f};
|
||||
|
||||
// Temporarily load Q data into tile_KV, will be loaded into registers afterwards.
|
||||
// The loading is done with decreasing granularity for D for better memory bandwidth.
|
||||
const half2 scale_h2 = make_half2(scale, scale);
|
||||
#pragma unroll
|
||||
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
|
||||
const int k0_stop = D/2 - (D/2) % (1*stride_k);
|
||||
const int stride_j = WARP_SIZE / stride_k;
|
||||
|
||||
if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) {
|
||||
const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
||||
|
||||
if (jt*ncols + j < ne01) {
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
|
||||
tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
{
|
||||
const int j0 = (threadIdx.y / np) * mma_B::J;
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
|
||||
Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Iterate over ne11 == previous tokens:
|
||||
for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) {
|
||||
const int k_VKQ_0 = kb0*KQ_stride;
|
||||
mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)];
|
||||
|
||||
// Load K data into tile with decreasing granularity for D for better memory bandwidth:
|
||||
static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
|
||||
#pragma unroll
|
||||
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
|
||||
const int k0_stop = D/2 - (D/2) % (1*stride_k);
|
||||
const int stride_i = WARP_SIZE / stride_k;
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) {
|
||||
const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Calculate tile of KQ:
|
||||
#pragma unroll
|
||||
for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) {
|
||||
const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I;
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) {
|
||||
mma_A K_A;
|
||||
K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
|
||||
KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (use_logit_softcap) {
|
||||
static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
|
||||
#pragma unroll
|
||||
for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C_KQ::ne; ++l) {
|
||||
KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (maskh) {
|
||||
static_assert(KQ_stride % (np *mma_C_KQ::I) == 0, "bad loop size");
|
||||
static_assert(ncols % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size");
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) {
|
||||
const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I;
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C_KQ::ne; ++l) {
|
||||
const int i = i0 + mma_C_KQ::get_i(l);
|
||||
const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l);
|
||||
|
||||
KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate softmax for each KQ column using the current max. value.
|
||||
// The divisor is stored in KQ_rowsum and will be applied at the end.
|
||||
float2 KQ_max_new = KQ_max;
|
||||
static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
|
||||
#pragma unroll
|
||||
for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) {
|
||||
KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
|
||||
KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
|
||||
}
|
||||
}
|
||||
|
||||
// Values per KQ column are spread across 8 threads, does not need full warp reduce:
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 2; offset >>= 1) {
|
||||
KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
|
||||
KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
|
||||
}
|
||||
|
||||
{
|
||||
const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
|
||||
KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
|
||||
if (diff.x <= SOFTMAX_FTZ_THRESHOLD) {
|
||||
KQ_max_scale.x = 0.0f;
|
||||
}
|
||||
if (diff.y <= SOFTMAX_FTZ_THRESHOLD) {
|
||||
KQ_max_scale.y = 0.0f;
|
||||
}
|
||||
KQ_max = KQ_max_new;
|
||||
}
|
||||
|
||||
float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
|
||||
static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
|
||||
#pragma unroll
|
||||
for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C_KQ::ne; ++l) {
|
||||
const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y;
|
||||
const float diff = KQ_C[k].x[l] - KQ_max_l;
|
||||
KQ_C[k].x[l] = expf(diff);
|
||||
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
|
||||
KQ_C[k].x[l] = 0.0f;
|
||||
}
|
||||
|
||||
if (l % 2 == 0) {
|
||||
KQ_rowsum_add.x += KQ_C[k].x[l];
|
||||
} else {
|
||||
KQ_rowsum_add.y += KQ_C[k].x[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
||||
KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
|
||||
KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y;
|
||||
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < D/mma_C_VKQ::I; ++i) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C_VKQ::ne; ++l) {
|
||||
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
||||
}
|
||||
}
|
||||
|
||||
// Convert KQ C tiles into B tiles for VKQ calculation:
|
||||
mma_B B[KQ_stride/(np*2*mma_B::K)];
|
||||
static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size");
|
||||
#pragma unroll
|
||||
for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) {
|
||||
B[k] = KQ_C[k].to_mma_B();
|
||||
}
|
||||
|
||||
// Load V data into tile with decreasing granularity for D for better memory bandwidth:
|
||||
static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
|
||||
#pragma unroll
|
||||
for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
||||
const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i);
|
||||
const int i0_stop = D/2 - (D/2) % (1*stride_i);
|
||||
const int stride_k = WARP_SIZE / stride_i;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) {
|
||||
const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i);
|
||||
|
||||
#pragma unroll
|
||||
for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) {
|
||||
const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i);
|
||||
|
||||
tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Calculate VKQ tile:
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) {
|
||||
static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size");
|
||||
#pragma unroll
|
||||
for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) {
|
||||
const int k0 = k00 + (threadIdx.y % np)*mma_A::K;
|
||||
|
||||
mma_A A;
|
||||
A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
|
||||
VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Finally, sum up partial KQ rowsums.
|
||||
// The partial sums are spread across 8 threads each, does not need full reduce.
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 2; offset >>= 1) {
|
||||
KQ_rowsum.x += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.x, offset, WARP_SIZE);
|
||||
KQ_rowsum.y += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.y, offset, WARP_SIZE);
|
||||
}
|
||||
|
||||
// Write VKQ accumulators to shared memory in column-major format.
|
||||
// It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
||||
// Also for np > 1 the combination is done via these values in shared memory.
|
||||
const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
|
||||
const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format.
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_B::ne; ++l) {
|
||||
const int k = k0 + mma_B::get_k(l);
|
||||
|
||||
tile_KV[j_cwd*D2_padded + k] = B.x[l];
|
||||
}
|
||||
}
|
||||
|
||||
const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset
|
||||
const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
|
||||
const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum
|
||||
|
||||
if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) {
|
||||
// Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
|
||||
((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
static_assert(np == 1 || np == 2 || np == 4, "bad np");
|
||||
if (np == 1) {
|
||||
// No combination is needed, the meta data can be directly written from registers to VRAM.
|
||||
if (needs_fixup && threadIdx.x < mma_B::J) {
|
||||
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
|
||||
dstk_fixup_meta[j_cwm] = KQ_cmr;
|
||||
}
|
||||
if (is_fixup && threadIdx.x < mma_B::J) {
|
||||
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
|
||||
dstk_fixup_meta[j_cwm] = KQ_cmr;
|
||||
}
|
||||
} else if (threadIdx.y % np == 0) {
|
||||
// Combine the meta data for parallel warps via shared memory.
|
||||
// Warps with threadIdx.y % np != 0 must NOT return early.
|
||||
// All threads must return simultaneously to avoid race conditions with work on the next tile.
|
||||
|
||||
float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2;
|
||||
|
||||
float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp.
|
||||
if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
|
||||
KQ_cm = meta_j[0];
|
||||
}
|
||||
|
||||
float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps.
|
||||
#pragma unroll
|
||||
for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
|
||||
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
|
||||
}
|
||||
|
||||
const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp.
|
||||
float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps.
|
||||
if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
|
||||
KQ_crs = KQ_cms*meta_j[1];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
|
||||
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
|
||||
}
|
||||
|
||||
// Write back combined meta data:
|
||||
if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
|
||||
meta_j[0] = KQ_cmn; // Combined max. KQ values.
|
||||
meta_j[1] = KQ_crs; // Combined KQ rowsums.
|
||||
meta_j[2] = KQ_cms; // KQ max scales per parallel warp.
|
||||
}
|
||||
if (needs_fixup && threadIdx.x < mma_B::J) {
|
||||
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
|
||||
dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
|
||||
}
|
||||
if (is_fixup && threadIdx.x < mma_B::J) {
|
||||
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
|
||||
dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
|
||||
}
|
||||
}
|
||||
|
||||
if (np > 1) {
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (np == 1 || threadIdx.y % np == 0) {
|
||||
// The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
|
||||
// The values after that are for the partial results of the individual blocks.
|
||||
float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2));
|
||||
|
||||
#pragma unroll
|
||||
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
|
||||
const int k0_stop = D/2 - (D/2) % (1*stride_k);
|
||||
const int stride_j = WARP_SIZE / stride_k;
|
||||
|
||||
if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
|
||||
const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
||||
const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J;
|
||||
|
||||
if (!is_fixup && jt*ncols + j_dst >= ne01) {
|
||||
continue;
|
||||
}
|
||||
const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2;
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
float2 dstk_val = make_float2(0.0f, 0.0f);
|
||||
#pragma unroll
|
||||
for (int ip = 0; ip < np; ++ip) {
|
||||
const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2];
|
||||
const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]);
|
||||
dstk_val.x += dstk_val_add.x*KQ_crs;
|
||||
dstk_val.y += dstk_val_add.y*KQ_crs;
|
||||
}
|
||||
|
||||
if (!needs_fixup && !is_fixup) {
|
||||
const float KQ_rowsum_j = meta_j[1];
|
||||
dstk_val.x /= KQ_rowsum_j;
|
||||
dstk_val.y /= KQ_rowsum_j;
|
||||
}
|
||||
|
||||
if (is_fixup) {
|
||||
dstk_fixup_data[j_dst*(D/2) + k] = dstk_val;
|
||||
} else {
|
||||
dstk[(jt*ncols + j_dst)*ne02*(D/2) + k] = dstk_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (np > 1) {
|
||||
__syncthreads();
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap>
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 2)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void flash_attn_ext_f16(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int ne03,
|
||||
const int ne10,
|
||||
const int ne11,
|
||||
const int ne12,
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int nb31,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int nb21,
|
||||
const int nb22,
|
||||
const int nb23,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
static_assert(FATTN_KQ_STRIDE % KQ_stride == 0, "bad KQ_stride");
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
|
||||
const int iter_k = ne11 / KQ_stride;
|
||||
const int iter_j = (ne01 + (ncols - 1)) / ncols;
|
||||
|
||||
// kbc == k block continuous, current index in continuous ijk space.
|
||||
int kbc = (blockIdx.x + 0)*iter_k*iter_j*ne02 / gridDim.x;
|
||||
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*ne02 / gridDim.x;
|
||||
|
||||
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
||||
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
||||
// In the most general case >2 seams can fall into the same tile.
|
||||
|
||||
// kb0 == k start index when in the output tile.
|
||||
int kb0_start = kbc % iter_k;
|
||||
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
|
||||
while (kbc < kbc_stop && kb0_stop == iter_k) {
|
||||
const int channel = kbc / (iter_k*iter_j);
|
||||
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + channel*(D/2);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
|
||||
|
||||
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
||||
if (kb0_start == 0) {
|
||||
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
||||
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
|
||||
jt, kb0_start, kb0_stop);
|
||||
} else {
|
||||
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
||||
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
|
||||
jt, kb0_start, kb0_stop);
|
||||
}
|
||||
|
||||
kbc += iter_k;
|
||||
kbc -= kbc % iter_k;
|
||||
|
||||
kb0_start = 0;
|
||||
kb0_stop = min(iter_k, kbc_stop - kbc);
|
||||
}
|
||||
|
||||
if (kbc >= kbc_stop) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int channel = kbc / (iter_k*iter_j);
|
||||
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + channel*(D/2);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
|
||||
|
||||
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||
constexpr bool needs_fixup = false;
|
||||
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
|
||||
jt, kb0_start, kb0_stop);
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block>
|
||||
void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
typedef mma_A_I16K8<half2> mma_A;
|
||||
typedef mma_B_J8K8<half2> mma_B;
|
||||
|
||||
static_assert(D % mma_B::K == 0, "bad D");
|
||||
static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block");
|
||||
|
||||
const ggml_tensor * KQV = dst;
|
||||
|
||||
constexpr int KQ_stride = D <= 128 ? 64 : 32;
|
||||
constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ?
|
||||
cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8);
|
||||
constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half);
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, cols_per_block, 0, KQ_stride>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
}
|
||||
|
||||
#define DECL_FATTN_MMA_F16_CASE(D, cols_per_block) \
|
||||
template void ggml_cuda_flash_attn_ext_mma_f16_case \
|
||||
<D, cols_per_block>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
|
||||
|
||||
extern DECL_FATTN_MMA_F16_CASE( 64, 8);
|
||||
extern DECL_FATTN_MMA_F16_CASE( 80, 8);
|
||||
extern DECL_FATTN_MMA_F16_CASE( 96, 8);
|
||||
extern DECL_FATTN_MMA_F16_CASE(112, 8);
|
||||
extern DECL_FATTN_MMA_F16_CASE(128, 8);
|
||||
extern DECL_FATTN_MMA_F16_CASE(256, 8);
|
||||
|
||||
extern DECL_FATTN_MMA_F16_CASE( 64, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE( 80, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE( 96, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(112, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(128, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(256, 16);
|
||||
|
||||
extern DECL_FATTN_MMA_F16_CASE( 64, 32);
|
||||
extern DECL_FATTN_MMA_F16_CASE( 80, 32);
|
||||
extern DECL_FATTN_MMA_F16_CASE( 96, 32);
|
||||
extern DECL_FATTN_MMA_F16_CASE(112, 32);
|
||||
extern DECL_FATTN_MMA_F16_CASE(128, 32);
|
||||
extern DECL_FATTN_MMA_F16_CASE(256, 32);
|
||||
|
||||
extern DECL_FATTN_MMA_F16_CASE( 64, 64);
|
||||
extern DECL_FATTN_MMA_F16_CASE( 80, 64);
|
||||
extern DECL_FATTN_MMA_F16_CASE( 96, 64);
|
||||
extern DECL_FATTN_MMA_F16_CASE(112, 64);
|
||||
extern DECL_FATTN_MMA_F16_CASE(128, 64);
|
||||
extern DECL_FATTN_MMA_F16_CASE(256, 64);
|
||||
@@ -45,7 +45,17 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifdef FP16_AVAILABLE
|
||||
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FP16_MMA_AVAILABLE
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
@@ -288,16 +298,18 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
switch (Q->ne[0]) {
|
||||
case 64: {
|
||||
constexpr int D = 64;
|
||||
constexpr int nwarps = 8;
|
||||
constexpr int D = 64;
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = 8;
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||
|
||||
@@ -48,7 +48,12 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FP16_MMA_AVAILABLE
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
@@ -287,16 +292,18 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
switch (Q->ne[0]) {
|
||||
case 64: {
|
||||
constexpr int D = 64;
|
||||
constexpr int nwarps = 8;
|
||||
constexpr int D = 64;
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = 8;
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||
|
||||
@@ -42,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifdef FP16_AVAILABLE
|
||||
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
@@ -303,7 +309,8 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
|
||||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
|
||||
}
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
|
||||
@@ -41,6 +41,11 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
@@ -284,7 +289,8 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
|
||||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
|
||||
}
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
|
||||
648
ggml/src/ggml-cuda/fattn-wmma-f16.cu
Normal file
648
ggml/src/ggml-cuda/fattn-wmma-f16.cu
Normal file
@@ -0,0 +1,648 @@
|
||||
// Old and deprecated WMMA FlashAttention implementation.
|
||||
// It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing.
|
||||
// Long-term the WMMA code should be replaced with a dedicated Volta implementation.
|
||||
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
#include "fattn-wmma-f16.cuh"
|
||||
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
#include <mma.h>
|
||||
#endif // FP16_MMA_AVAILABLE
|
||||
|
||||
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
||||
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void flash_attn_ext_f16(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int ne03,
|
||||
const int ne10,
|
||||
const int ne11,
|
||||
const int ne12,
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int nb31,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int nb21,
|
||||
const int nb22,
|
||||
const int nb23,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
|
||||
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||
|
||||
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
|
||||
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
|
||||
constexpr int frag_m = ncols == 8 ? 32 : 16;
|
||||
constexpr int frag_n = ncols == 8 ? 8 : 16;
|
||||
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
|
||||
|
||||
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
|
||||
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
|
||||
static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
|
||||
|
||||
// Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
|
||||
constexpr int D_padded = D + 8;
|
||||
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
|
||||
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
|
||||
const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
|
||||
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
|
||||
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
|
||||
|
||||
const int stride_Q = nb01 / sizeof(float);
|
||||
const int stride_KV = nb11 / sizeof(half);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
const half2 slope2 = make_half2(slopef, slopef);
|
||||
|
||||
const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
|
||||
|
||||
frag_b Q_b[D/16][ncols/frag_n];
|
||||
|
||||
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
|
||||
constexpr int mem_KQ = ncols*kqs_padded*kqar;
|
||||
constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
|
||||
__shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
|
||||
float * KQ_f = (float *) KQ;
|
||||
half2 * KQ2 = (half2 *) KQ;
|
||||
|
||||
float KQ_rowsum_f[ncols/nwarps] = {0.0f};
|
||||
float KQ_max_f[ncols/nwarps];
|
||||
float KQ_max_scale_f[ncols/nwarps] = {0.0f};
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/nwarps; ++j) {
|
||||
KQ_max_f[j] = -FLT_MAX/2.0f;
|
||||
}
|
||||
|
||||
half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
|
||||
half2 KQ_max_h2[ncols/nwarps];
|
||||
half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/nwarps; ++j) {
|
||||
KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
|
||||
}
|
||||
|
||||
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
|
||||
half2 * VKQ2 = (half2 *) VKQ;
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
|
||||
break;
|
||||
}
|
||||
VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
// Convert Q to half and apply scale, temporarily store in KQ:
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
if (i0 + WARP_SIZE > D && i >= D) {
|
||||
break;
|
||||
}
|
||||
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Load Q into tensor core fragments/registers since it will be used frequently:
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += 16) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Iterate over ne11 == previous tokens:
|
||||
for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
|
||||
// Calculate tile of KQ:
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
|
||||
frag_c_KQ KQ_c[ncols/frag_n];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
|
||||
frag_a_K K_a;
|
||||
nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Calculate softmax for each KQ column using the current max. value.
|
||||
// The divisor is stored in KQ_rowsum and will be applied at the end.
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
if (std::is_same<KQ_acc_t, float>::value) {
|
||||
float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
|
||||
const int k = k0 + threadIdx.x;
|
||||
|
||||
KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
|
||||
|
||||
if (use_logit_softcap) {
|
||||
KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
|
||||
}
|
||||
}
|
||||
|
||||
float KQ_max_new = KQ_max_f[j0/nwarps];
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
|
||||
const int k = k0 + threadIdx.x;
|
||||
|
||||
KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
|
||||
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
|
||||
}
|
||||
KQ_max_new = warp_reduce_max(KQ_max_new);
|
||||
|
||||
const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
|
||||
KQ_max_scale_f[j0/nwarps] = expf(diff);
|
||||
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
|
||||
KQ_max_scale_f[j0/nwarps] = 0.0f;
|
||||
}
|
||||
KQ_max_f[j0/nwarps] = KQ_max_new;
|
||||
|
||||
float KQ_rowsum_add = 0.0f;
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
|
||||
const int k = k0 + threadIdx.x;
|
||||
|
||||
const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
|
||||
KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
|
||||
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
|
||||
KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
|
||||
}
|
||||
KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
|
||||
KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
|
||||
}
|
||||
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
|
||||
|
||||
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
||||
KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
|
||||
} else {
|
||||
half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
||||
const int k = k0 + threadIdx.x;
|
||||
|
||||
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
|
||||
|
||||
if (use_logit_softcap) {
|
||||
// There is no dedicated tangens hyperbolicus function for half2.
|
||||
KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
|
||||
KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
|
||||
/(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
|
||||
|
||||
KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
|
||||
}
|
||||
}
|
||||
|
||||
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
||||
const int k = k0 + threadIdx.x;
|
||||
|
||||
KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
|
||||
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
|
||||
}
|
||||
KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
|
||||
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
|
||||
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
|
||||
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
||||
*((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
|
||||
KQ_max_h2[j0/nwarps] = KQ_max_new;
|
||||
|
||||
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
||||
const int k = k0 + threadIdx.x;
|
||||
|
||||
const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
|
||||
KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
|
||||
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
||||
*((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
|
||||
KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
|
||||
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
|
||||
}
|
||||
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
|
||||
|
||||
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
||||
KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
|
||||
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||
nvcuda::wmma::load_matrix_sync(
|
||||
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
|
||||
KQ + j0*(kqar*kqs_padded) + k,
|
||||
kqar*kqs_padded);
|
||||
}
|
||||
}
|
||||
|
||||
frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
|
||||
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||
|
||||
frag_a_V v_a;
|
||||
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
nvcuda::wmma::store_matrix_sync(
|
||||
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
||||
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
|
||||
D_padded, nvcuda::wmma::mem_col_major);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
half2 VKQ_scale;
|
||||
if (std::is_same<KQ_acc_t, float>::value) {
|
||||
VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
|
||||
} else {
|
||||
VKQ_scale = KQ_max_scale_h2[j0/nwarps];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
|
||||
break;
|
||||
}
|
||||
|
||||
half2 VKQ_add = make_half2(0.0f, 0.0f);
|
||||
#pragma unroll
|
||||
for (int l = 0; l < VKQ_ratio; ++l) {
|
||||
VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
|
||||
}
|
||||
VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j_VKQ = j0 + threadIdx.y;
|
||||
if (ic0 + j_VKQ >= ne01) {
|
||||
return;
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||
|
||||
float KQ_rowsum_j;
|
||||
if (std::is_same<KQ_acc_t, float>::value) {
|
||||
KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
|
||||
} else {
|
||||
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
if (i0 + WARP_SIZE > D && i >= D) {
|
||||
break;
|
||||
}
|
||||
float dst_val = VKQ[j_VKQ*D_padded + i];
|
||||
if (parallel_blocks == 1) {
|
||||
dst_val /= KQ_rowsum_j;
|
||||
}
|
||||
dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
|
||||
}
|
||||
|
||||
if (parallel_blocks == 1 || threadIdx.x != 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
float2 dst_meta_val;
|
||||
if (std::is_same<KQ_acc_t, float>::value) {
|
||||
dst_meta_val.x = KQ_max_f[j0/nwarps];
|
||||
} else {
|
||||
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
|
||||
}
|
||||
dst_meta_val.y = KQ_rowsum_j;
|
||||
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
}
|
||||
|
||||
constexpr int get_max_power_of_2(int x) {
|
||||
return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
|
||||
}
|
||||
|
||||
static_assert(get_max_power_of_2(1) == 1, "Test failed.");
|
||||
static_assert(get_max_power_of_2(2) == 2, "Test failed.");
|
||||
static_assert(get_max_power_of_2(4) == 4, "Test failed.");
|
||||
static_assert(get_max_power_of_2(6) == 2, "Test failed.");
|
||||
|
||||
// Number of VKQ rows calculated in parallel:
|
||||
constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
|
||||
return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
|
||||
}
|
||||
|
||||
static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
|
||||
static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
|
||||
static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
|
||||
|
||||
template <int D, int cols_per_block, typename KQ_acc_t>
|
||||
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
constexpr int nwarps = 4;
|
||||
|
||||
constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
|
||||
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
|
||||
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (4*blocks_num_pb1 < 2*nsm) {
|
||||
constexpr int parallel_blocks = 4;
|
||||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
||||
return;
|
||||
}
|
||||
if (2*blocks_num_pb1 < 2*nsm) {
|
||||
constexpr int parallel_blocks = 2;
|
||||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
||||
return;
|
||||
}
|
||||
constexpr int parallel_blocks = 1;
|
||||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||
|
||||
if (prec != GGML_PREC_DEFAULT) {
|
||||
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
||||
constexpr int cols_per_block = 16;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
constexpr int cols_per_block = 32;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
// case 256:
|
||||
// ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
|
||||
// break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
|
||||
constexpr int cols_per_block = 8;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32) {
|
||||
constexpr int cols_per_block = 16;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 32;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -1,543 +1,3 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
#include <mma.h>
|
||||
#endif // FP16_MMA_AVAILABLE
|
||||
|
||||
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
||||
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void flash_attn_ext_f16(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int ne03,
|
||||
const int ne10,
|
||||
const int ne11,
|
||||
const int ne12,
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int nb31,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int nb21,
|
||||
const int nb22,
|
||||
const int nb23,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
|
||||
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||
|
||||
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
|
||||
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
|
||||
constexpr int frag_m = ncols == 8 ? 32 : 16;
|
||||
constexpr int frag_n = ncols == 8 ? 8 : 16;
|
||||
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
|
||||
|
||||
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
|
||||
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
|
||||
static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
|
||||
|
||||
// Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
|
||||
constexpr int D_padded = D + 8;
|
||||
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
|
||||
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
|
||||
const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
|
||||
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
|
||||
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
|
||||
|
||||
const int stride_Q = nb01 / sizeof(float);
|
||||
const int stride_KV = nb11 / sizeof(half);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
const half2 slope2 = make_half2(slopef, slopef);
|
||||
|
||||
const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
|
||||
|
||||
frag_b Q_b[D/16][ncols/frag_n];
|
||||
|
||||
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
|
||||
constexpr int mem_KQ = ncols*kqs_padded*kqar;
|
||||
constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
|
||||
__shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
|
||||
float * KQ_f = (float *) KQ;
|
||||
half2 * KQ2 = (half2 *) KQ;
|
||||
|
||||
float KQ_rowsum_f[ncols/nwarps] = {0.0f};
|
||||
float KQ_max_f[ncols/nwarps];
|
||||
float KQ_max_scale_f[ncols/nwarps] = {0.0f};
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/nwarps; ++j) {
|
||||
KQ_max_f[j] = -FLT_MAX/2.0f;
|
||||
}
|
||||
|
||||
half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
|
||||
half2 KQ_max_h2[ncols/nwarps];
|
||||
half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/nwarps; ++j) {
|
||||
KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
|
||||
}
|
||||
|
||||
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
|
||||
half2 * VKQ2 = (half2 *) VKQ;
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
|
||||
break;
|
||||
}
|
||||
VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
// Convert Q to half and apply scale, temporarily store in KQ:
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
if (i0 + WARP_SIZE > D && i >= D) {
|
||||
break;
|
||||
}
|
||||
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Load Q into tensor core fragments/registers since it will be used frequently:
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += 16) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Iterate over ne11 == previous tokens:
|
||||
for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
|
||||
// Calculate tile of KQ:
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
|
||||
frag_c_KQ KQ_c[ncols/frag_n];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
|
||||
frag_a_K K_a;
|
||||
nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Calculate softmax for each KQ column using the current max. value.
|
||||
// The divisor is stored in KQ_rowsum and will be applied at the end.
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
if (std::is_same<KQ_acc_t, float>::value) {
|
||||
float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
|
||||
const int k = k0 + threadIdx.x;
|
||||
|
||||
KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
|
||||
|
||||
if (use_logit_softcap) {
|
||||
KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
|
||||
}
|
||||
}
|
||||
|
||||
float KQ_max_new = KQ_max_f[j0/nwarps];
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
|
||||
const int k = k0 + threadIdx.x;
|
||||
|
||||
KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
|
||||
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
|
||||
}
|
||||
KQ_max_new = warp_reduce_max(KQ_max_new);
|
||||
|
||||
const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
|
||||
KQ_max_scale_f[j0/nwarps] = expf(diff);
|
||||
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
|
||||
KQ_max_scale_f[j0/nwarps] = 0.0f;
|
||||
}
|
||||
KQ_max_f[j0/nwarps] = KQ_max_new;
|
||||
|
||||
float KQ_rowsum_add = 0.0f;
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
|
||||
const int k = k0 + threadIdx.x;
|
||||
|
||||
const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
|
||||
KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
|
||||
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
|
||||
KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
|
||||
}
|
||||
KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
|
||||
KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
|
||||
}
|
||||
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
|
||||
|
||||
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
||||
KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
|
||||
} else {
|
||||
half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
||||
const int k = k0 + threadIdx.x;
|
||||
|
||||
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
|
||||
|
||||
if (use_logit_softcap) {
|
||||
// There is no dedicated tangens hyperbolicus function for half2.
|
||||
KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
|
||||
KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
|
||||
/(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
|
||||
|
||||
KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
|
||||
}
|
||||
}
|
||||
|
||||
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
||||
const int k = k0 + threadIdx.x;
|
||||
|
||||
KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
|
||||
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
|
||||
}
|
||||
KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
|
||||
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
|
||||
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
|
||||
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
||||
*((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
|
||||
KQ_max_h2[j0/nwarps] = KQ_max_new;
|
||||
|
||||
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
||||
const int k = k0 + threadIdx.x;
|
||||
|
||||
const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
|
||||
KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
|
||||
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
||||
*((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
|
||||
KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
|
||||
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
|
||||
}
|
||||
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
|
||||
|
||||
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
||||
KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
|
||||
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||
nvcuda::wmma::load_matrix_sync(
|
||||
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
|
||||
KQ + j0*(kqar*kqs_padded) + k,
|
||||
kqar*kqs_padded);
|
||||
}
|
||||
}
|
||||
|
||||
frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
|
||||
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||
|
||||
frag_a_V v_a;
|
||||
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
nvcuda::wmma::store_matrix_sync(
|
||||
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
||||
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
|
||||
D_padded, nvcuda::wmma::mem_col_major);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
half2 VKQ_scale;
|
||||
if (std::is_same<KQ_acc_t, float>::value) {
|
||||
VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
|
||||
} else {
|
||||
VKQ_scale = KQ_max_scale_h2[j0/nwarps];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
|
||||
break;
|
||||
}
|
||||
|
||||
half2 VKQ_add = make_half2(0.0f, 0.0f);
|
||||
#pragma unroll
|
||||
for (int l = 0; l < VKQ_ratio; ++l) {
|
||||
VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
|
||||
}
|
||||
VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j_VKQ = j0 + threadIdx.y;
|
||||
if (ic0 + j_VKQ >= ne01) {
|
||||
return;
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||
|
||||
float KQ_rowsum_j;
|
||||
if (std::is_same<KQ_acc_t, float>::value) {
|
||||
KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
|
||||
} else {
|
||||
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
if (i0 + WARP_SIZE > D && i >= D) {
|
||||
break;
|
||||
}
|
||||
float dst_val = VKQ[j_VKQ*D_padded + i];
|
||||
if (parallel_blocks == 1) {
|
||||
dst_val /= KQ_rowsum_j;
|
||||
}
|
||||
dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
|
||||
}
|
||||
|
||||
if (parallel_blocks == 1 || threadIdx.x != 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
float2 dst_meta_val;
|
||||
if (std::is_same<KQ_acc_t, float>::value) {
|
||||
dst_meta_val.x = KQ_max_f[j0/nwarps];
|
||||
} else {
|
||||
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
|
||||
}
|
||||
dst_meta_val.y = KQ_rowsum_j;
|
||||
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FP16_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
constexpr int get_max_power_of_2(int x) {
|
||||
return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
|
||||
}
|
||||
|
||||
static_assert(get_max_power_of_2(1) == 1, "Test failed.");
|
||||
static_assert(get_max_power_of_2(2) == 2, "Test failed.");
|
||||
static_assert(get_max_power_of_2(4) == 4, "Test failed.");
|
||||
static_assert(get_max_power_of_2(6) == 2, "Test failed.");
|
||||
|
||||
// Number of VKQ rows calculated in parallel:
|
||||
constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
|
||||
return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
|
||||
}
|
||||
|
||||
static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
|
||||
static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
|
||||
static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
|
||||
|
||||
template <int D, int cols_per_block, typename KQ_acc_t>
|
||||
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
constexpr int nwarps = 4;
|
||||
|
||||
constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
|
||||
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
|
||||
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (4*blocks_num_pb1 < 2*nsm) {
|
||||
constexpr int parallel_blocks = 4;
|
||||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
return;
|
||||
}
|
||||
if (2*blocks_num_pb1 < 2*nsm) {
|
||||
constexpr int parallel_blocks = 2;
|
||||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
return;
|
||||
}
|
||||
constexpr int parallel_blocks = 1;
|
||||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
}
|
||||
|
||||
#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \
|
||||
template void ggml_cuda_flash_attn_ext_wmma_f16_case \
|
||||
<D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
|
||||
|
||||
extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float);
|
||||
extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float);
|
||||
extern DECL_FATTN_WMMA_F16_CASE( 96, 16, float);
|
||||
extern DECL_FATTN_WMMA_F16_CASE(112, 16, float);
|
||||
extern DECL_FATTN_WMMA_F16_CASE(128, 16, float);
|
||||
extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
|
||||
|
||||
extern DECL_FATTN_WMMA_F16_CASE( 64, 32, float);
|
||||
extern DECL_FATTN_WMMA_F16_CASE( 80, 32, float);
|
||||
extern DECL_FATTN_WMMA_F16_CASE( 96, 32, float);
|
||||
extern DECL_FATTN_WMMA_F16_CASE(112, 32, float);
|
||||
extern DECL_FATTN_WMMA_F16_CASE(128, 32, float);
|
||||
// extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
|
||||
|
||||
extern DECL_FATTN_WMMA_F16_CASE( 64, 8, half);
|
||||
extern DECL_FATTN_WMMA_F16_CASE( 96, 8, half);
|
||||
extern DECL_FATTN_WMMA_F16_CASE(128, 8, half);
|
||||
extern DECL_FATTN_WMMA_F16_CASE(256, 8, half);
|
||||
|
||||
extern DECL_FATTN_WMMA_F16_CASE( 64, 16, half);
|
||||
extern DECL_FATTN_WMMA_F16_CASE( 80, 16, half);
|
||||
extern DECL_FATTN_WMMA_F16_CASE( 96, 16, half);
|
||||
extern DECL_FATTN_WMMA_F16_CASE(112, 16, half);
|
||||
extern DECL_FATTN_WMMA_F16_CASE(128, 16, half);
|
||||
extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
|
||||
|
||||
extern DECL_FATTN_WMMA_F16_CASE( 64, 32, half);
|
||||
extern DECL_FATTN_WMMA_F16_CASE( 80, 32, half);
|
||||
extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half);
|
||||
extern DECL_FATTN_WMMA_F16_CASE(112, 32, half);
|
||||
extern DECL_FATTN_WMMA_F16_CASE(128, 32, half);
|
||||
extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
|
||||
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
#include "fattn-mma-f16.cuh"
|
||||
#include "fattn-tile-f16.cuh"
|
||||
#include "fattn-tile-f32.cuh"
|
||||
#include "fattn-vec-f16.cuh"
|
||||
@@ -7,144 +8,56 @@
|
||||
#include "fattn-wmma-f16.cuh"
|
||||
#include "fattn.cuh"
|
||||
|
||||
#include <cstdint>
|
||||
template <int cols_per_block>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||
|
||||
if (prec != GGML_PREC_DEFAULT) {
|
||||
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
||||
constexpr int cols_per_block = 16;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
constexpr int cols_per_block = 32;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
// case 256:
|
||||
// ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
||||
// break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
|
||||
constexpr int cols_per_block = 8;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32) {
|
||||
constexpr int cols_per_block = 16;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 32;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
if (Q->ne[1] <= 8) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 16) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst);
|
||||
}
|
||||
|
||||
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
||||
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
|
||||
@@ -323,10 +236,18 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
|
||||
if (!fp16_mma_available(cc)) {
|
||||
if (Q->ne[1] <= 8) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
if (prec == GGML_PREC_DEFAULT) {
|
||||
if (Q->ne[1] <= 8) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
|
||||
}
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
|
||||
if (Q->ne[1] <= 8) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -341,5 +262,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
||||
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
|
||||
if (cc == GGML_CUDA_CC_VOLTA) {
|
||||
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|
||||
}
|
||||
|
||||
@@ -38,10 +38,12 @@
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/wkv6.cuh"
|
||||
#include "ggml-cuda/gla.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <charconv>
|
||||
#include <cinttypes>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
@@ -62,7 +64,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||
[[noreturn]]
|
||||
void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
|
||||
int id = -1; // in case cudaGetDevice fails
|
||||
cudaGetDevice(&id);
|
||||
(void)cudaGetDevice(&id);
|
||||
|
||||
GGML_LOG_ERROR(GGML_CUDA_NAME " error: %s\n", msg);
|
||||
GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line);
|
||||
@@ -119,12 +121,78 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
static int ggml_cuda_parse_id(char devName[]) {
|
||||
// A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp
|
||||
// these values are not stable so this is susceptible to breakage
|
||||
// https://github.com/ROCm/clr/blob/amd-staging/rocclr/device/device.cpp
|
||||
int archMajor = 0x0;
|
||||
int archMinor = 0x0;
|
||||
int archNum = GGML_CUDA_CC_OFFSET_AMD;
|
||||
int archLen = strlen(devName);
|
||||
char archName[archLen + 1];
|
||||
|
||||
// strip leading 'gfx' while copying into our buffer
|
||||
if (archLen > 3) {
|
||||
strcpy(archName, &devName[3]);
|
||||
archLen -= 3;
|
||||
}
|
||||
|
||||
// trim trailing :xnack- or :sramecc- statuses
|
||||
archLen = strcspn(archName, ":");
|
||||
archName[archLen] = '\0';
|
||||
|
||||
// tease out the version information
|
||||
if (archLen > 8) {
|
||||
// versions labeled generic use '-' as delimiter
|
||||
// strip the trailing "-generic" then iterate through what remains
|
||||
if ((strstr(archName, "-generic"))) {
|
||||
archName[archLen - 8] = '\0';
|
||||
char * pch;
|
||||
if ((pch = strtok(archName, "-"))) {
|
||||
archMajor = (int)strtoul(pch, 0, 16);
|
||||
if ((pch = strtok(NULL, "-"))) {
|
||||
archMinor = 0x10 * (int)strtoul(pch, 0, 16);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (archLen >= 3) {
|
||||
// last two digits should be the minor * 0x10 + stepping
|
||||
archMinor = (int)strtoul(&archName[archLen - 2], 0, 16);
|
||||
archName[archLen - 2] = '\0';
|
||||
|
||||
// only the major version remains
|
||||
archMajor = (int)strtoul(archName, 0, 16);
|
||||
}
|
||||
archNum += archMajor * 0x100;
|
||||
archNum += archMinor;
|
||||
return archNum;
|
||||
}
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
|
||||
static ggml_cuda_device_info ggml_cuda_init() {
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
// Workaround for a rocBLAS bug when using multiple graphics cards:
|
||||
// https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
|
||||
rocblas_initialize();
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
{
|
||||
int major_version = 0;
|
||||
size_t version_length = 0;
|
||||
if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) {
|
||||
std::string version(version_length, '\0');
|
||||
if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) {
|
||||
version.resize(::strlen(version.c_str()));
|
||||
int parsed_value = 0;
|
||||
if (std::from_chars(version.c_str(), version.c_str() + version.length(), parsed_value).ec == std::errc()) {
|
||||
major_version = parsed_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (major_version < 4) {
|
||||
GGML_LOG_DEBUG(GGML_CUDA_NAME " calling rocblas_initialize as a workaround for a rocBLAS bug\n");
|
||||
rocblas_initialize();
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
ggml_cuda_device_info info = {};
|
||||
@@ -152,7 +220,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
for (int id = 0; id < info.device_count; ++id) {
|
||||
int device_vmm = 0;
|
||||
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
|
||||
#if defined(GGML_USE_VMM)
|
||||
CUdevice device;
|
||||
CU_CHECK(cuDeviceGet(&device, id));
|
||||
CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
|
||||
@@ -164,24 +232,40 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
alloc_prop.location.id = id;
|
||||
CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
|
||||
}
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
|
||||
#endif // defined(GGML_USE_VMM)
|
||||
info.devices[id].vmm = !!device_vmm;
|
||||
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
|
||||
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
|
||||
|
||||
info.default_tensor_split[id] = total_vram;
|
||||
total_vram += prop.totalGlobalMem;
|
||||
|
||||
info.devices[id].nsm = prop.multiProcessorCount;
|
||||
info.devices[id].smpb = prop.sharedMemPerBlock;
|
||||
info.devices[id].nsm = prop.multiProcessorCount;
|
||||
info.devices[id].smpb = prop.sharedMemPerBlock;
|
||||
info.devices[id].warp_size = prop.warpSize;
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
info.devices[id].smpbo = prop.sharedMemPerBlock;
|
||||
info.devices[id].cc = 100*prop.major + 10*prop.minor + GGML_CUDA_CC_OFFSET_AMD;
|
||||
|
||||
info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName);
|
||||
if ((info.devices[id].cc & 0xff00) == 0x0) {
|
||||
GGML_LOG_WARN("invalid architecture ID received for device %d %s: %s cc %d.%d\n",
|
||||
id, prop.name, prop.gcnArchName, prop.major, prop.minor);
|
||||
|
||||
// Fallback to prop.major and prop.minor
|
||||
if (prop.major > 0) {
|
||||
info.devices[id].cc = GGML_CUDA_CC_OFFSET_AMD + prop.major * 0x100;
|
||||
info.devices[id].cc += prop.minor * 0x10;
|
||||
}
|
||||
}
|
||||
GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
|
||||
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
|
||||
device_vmm ? "yes" : "no", prop.warpSize);
|
||||
#else
|
||||
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
|
||||
info.devices[id].cc = 100*prop.major + 10*prop.minor;
|
||||
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
|
||||
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
}
|
||||
|
||||
@@ -300,7 +384,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
|
||||
};
|
||||
|
||||
// pool with virtual memory
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
|
||||
#if defined(GGML_USE_VMM)
|
||||
struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
|
||||
|
||||
@@ -309,6 +393,9 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
size_t pool_used = 0;
|
||||
size_t pool_size = 0;
|
||||
size_t granularity;
|
||||
#if defined(GGML_USE_HIP)
|
||||
std::vector<std::pair<CUdeviceptr, size_t>> mappings;
|
||||
#endif
|
||||
|
||||
explicit ggml_cuda_pool_vmm(int device) :
|
||||
device(device),
|
||||
@@ -317,7 +404,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
|
||||
~ggml_cuda_pool_vmm() {
|
||||
if (pool_addr != 0) {
|
||||
#if defined(GGML_USE_HIP)
|
||||
// Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285
|
||||
for (std::pair<CUdeviceptr, size_t> & mapping : mappings) {
|
||||
CU_CHECK(cuMemUnmap(mapping.first, mapping.second));
|
||||
}
|
||||
#else
|
||||
CU_CHECK(cuMemUnmap(pool_addr, pool_size));
|
||||
#endif
|
||||
CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE));
|
||||
}
|
||||
}
|
||||
@@ -350,7 +444,11 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
}
|
||||
|
||||
// map at the end of the pool
|
||||
CU_CHECK(cuMemMap(pool_addr + pool_size, reserve_size, 0, handle, 0));
|
||||
CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size);
|
||||
CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, handle, 0));
|
||||
#if defined(GGML_USE_HIP)
|
||||
mappings.push_back({start_ptr, reserve_size});
|
||||
#endif
|
||||
|
||||
// the memory allocation handle is no longer needed after mapping
|
||||
CU_CHECK(cuMemRelease(handle));
|
||||
@@ -360,7 +458,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
access.location.id = device;
|
||||
access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
|
||||
CU_CHECK(cuMemSetAccess(pool_addr + pool_size, reserve_size, &access, 1));
|
||||
CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1));
|
||||
|
||||
// add to the pool
|
||||
pool_size += reserve_size;
|
||||
@@ -372,7 +470,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
|
||||
GGML_ASSERT(pool_addr != 0);
|
||||
|
||||
void * ptr = (void *) (pool_addr + pool_used);
|
||||
void * ptr = (void *) ((CUdeviceptr)((char *)(pool_addr) + pool_used));
|
||||
*actual_size = size;
|
||||
pool_used += size;
|
||||
|
||||
@@ -391,17 +489,17 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
pool_used -= size;
|
||||
|
||||
// all deallocations must be in reverse order of the allocations
|
||||
GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
|
||||
GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used));
|
||||
}
|
||||
};
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
|
||||
#endif // defined(GGML_USE_VMM)
|
||||
|
||||
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
|
||||
#if defined(GGML_USE_VMM)
|
||||
if (ggml_cuda_info().devices[device].vmm) {
|
||||
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
|
||||
}
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
|
||||
#endif // defined(GGML_USE_VMM)
|
||||
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
|
||||
}
|
||||
|
||||
@@ -547,7 +645,7 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
|
||||
cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
|
||||
if (err != cudaSuccess) {
|
||||
// clear the error
|
||||
cudaGetLastError();
|
||||
(void)cudaGetLastError();
|
||||
GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err));
|
||||
return nullptr;
|
||||
}
|
||||
@@ -962,7 +1060,7 @@ static void * ggml_cuda_host_malloc(size_t size) {
|
||||
cudaError_t err = cudaMallocHost((void **) &ptr, size);
|
||||
if (err != cudaSuccess) {
|
||||
// clear the error
|
||||
cudaGetLastError();
|
||||
(void)cudaGetLastError();
|
||||
GGML_LOG_DEBUG("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
|
||||
size / 1024.0 / 1024.0, cudaGetErrorString(err));
|
||||
return nullptr;
|
||||
@@ -1082,7 +1180,9 @@ static void ggml_cuda_op_mul_mat_cublas(
|
||||
|
||||
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
||||
|
||||
if (compute_capability >= GGML_CUDA_CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
|
||||
|
||||
if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) {
|
||||
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
|
||||
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
|
||||
if (src0->type != GGML_TYPE_F16) {
|
||||
@@ -1103,28 +1203,38 @@ static void ggml_cuda_op_mul_mat_cublas(
|
||||
to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
|
||||
}
|
||||
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
|
||||
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
|
||||
|
||||
const half alpha_f16 = 1.0f;
|
||||
const half beta_f16 = 0.0f;
|
||||
|
||||
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
|
||||
if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
|
||||
cu_compute_type = CUBLAS_COMPUTE_32F;
|
||||
}
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
row_diff, src1_ncols, ne10,
|
||||
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
|
||||
src1_ptr, CUDA_R_16F, ne10,
|
||||
&beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
|
||||
cu_compute_type,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
|
||||
if (GGML_CUDA_CC_IS_CDNA(compute_capability)) {
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
row_diff, src1_ncols, ne10,
|
||||
&alpha, src0_ptr, CUDA_R_16F, ne00,
|
||||
src1_ptr, CUDA_R_16F, ne10,
|
||||
&beta, dst_dd_i, CUDA_R_32F, ldc,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
} else {
|
||||
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
|
||||
|
||||
const half alpha_f16 = 1.0f;
|
||||
const half beta_f16 = 0.0f;
|
||||
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
row_diff, src1_ncols, ne10,
|
||||
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
|
||||
src1_ptr, CUDA_R_16F, ne10,
|
||||
&beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
|
||||
CUBLAS_COMPUTE_16F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
|
||||
}
|
||||
} else {
|
||||
ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
|
||||
ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));
|
||||
@@ -1197,7 +1307,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
|
||||
CUDA_CHECK(err);
|
||||
} else {
|
||||
// reset the error
|
||||
cudaGetLastError();
|
||||
(void)cudaGetLastError();
|
||||
}
|
||||
} else {
|
||||
cudaError_t err = cudaDeviceDisablePeerAccess(id_other);
|
||||
@@ -1205,7 +1315,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
|
||||
CUDA_CHECK(err);
|
||||
} else {
|
||||
// reset the error
|
||||
cudaGetLastError();
|
||||
(void)cudaGetLastError();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1256,8 +1366,6 @@ static void ggml_cuda_op_mul_mat(
|
||||
const int64_t ne13 = src1->ne[3];
|
||||
const int64_t nrows1 = ggml_nrows(src1);
|
||||
|
||||
GGML_ASSERT(ne03 == ne13);
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
const int64_t ne1 = dst->ne[1];
|
||||
|
||||
@@ -1271,9 +1379,11 @@ static void ggml_cuda_op_mul_mat(
|
||||
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
|
||||
|
||||
GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
|
||||
GGML_ASSERT(ne12 % ne02 == 0);
|
||||
GGML_ASSERT(ne13 % ne03 == 0);
|
||||
|
||||
const int64_t i02_divisor = ne12 / ne02;
|
||||
const int64_t i03_divisor = ne13 / ne03;
|
||||
|
||||
const size_t src0_ts = ggml_type_size(src0->type);
|
||||
const size_t src0_bs = ggml_blck_size(src0->type);
|
||||
@@ -1289,6 +1399,7 @@ static void ggml_cuda_op_mul_mat(
|
||||
GGML_ASSERT(!(split && ne02 > 1));
|
||||
GGML_ASSERT(!(split && ne03 > 1));
|
||||
GGML_ASSERT(!(split && ne02 < ne12));
|
||||
GGML_ASSERT(!(split && ne03 < ne13));
|
||||
|
||||
ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr;
|
||||
|
||||
@@ -1452,7 +1563,8 @@ static void ggml_cuda_op_mul_mat(
|
||||
}
|
||||
|
||||
// for split tensors the data begins at i0 == i0_offset_low
|
||||
char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
|
||||
const size_t nbytes_src0_matrix = ne01*ne00*src0_ts / src0_bs;
|
||||
char * src0_dd_i = dev[id].src0_dd + ((i03/i03_divisor)*ne02 + (i02/i02_divisor)) * nbytes_src0_matrix;
|
||||
float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
|
||||
char * src1_ddq_i = dev[id].src1_ddq + src1_ddq_i_offset;
|
||||
float * dst_dd_i = dev[id].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
|
||||
@@ -1496,8 +1608,9 @@ static void ggml_cuda_op_mul_mat(
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
|
||||
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
|
||||
if (src1_col_0 == 0 && !src0_is_contiguous && i03 % i03_divisor == 0 && i02 % i02_divisor == 0) {
|
||||
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
|
||||
src0_dd_i, src0, i03/i03_divisor, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
|
||||
}
|
||||
|
||||
// do the computation
|
||||
@@ -1613,10 +1726,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
||||
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
|
||||
cudaDataType_t cu_data_type = CUDA_R_16F;
|
||||
|
||||
if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
|
||||
cu_compute_type = CUBLAS_COMPUTE_32F;
|
||||
}
|
||||
|
||||
// dst strides
|
||||
size_t nbd2 = dst->nb[2];
|
||||
size_t nbd3 = dst->nb[3];
|
||||
@@ -1645,6 +1754,12 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
||||
beta = &beta_f32;
|
||||
}
|
||||
|
||||
if (GGML_CUDA_CC_IS_CDNA(ggml_cuda_info().devices[ctx.device].cc)) {
|
||||
cu_compute_type = CUBLAS_COMPUTE_32F;
|
||||
alpha = &alpha_f32;
|
||||
beta = &beta_f32;
|
||||
}
|
||||
|
||||
GGML_ASSERT(ne12 % ne02 == 0);
|
||||
GGML_ASSERT(ne13 % ne03 == 0);
|
||||
|
||||
@@ -1770,7 +1885,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
|
||||
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
|
||||
|
||||
if (!split && use_mul_mat_vec && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
|
||||
if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
|
||||
// the custom F16 vector kernel can be used over batched cuBLAS GEMM
|
||||
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
|
||||
ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
|
||||
@@ -2104,12 +2219,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
ggml_cuda_op_rms_norm_back(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
|
||||
GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
|
||||
return false;
|
||||
} else {
|
||||
ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
|
||||
}
|
||||
ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
|
||||
break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
ggml_cuda_mul_mat_id(ctx, dst);
|
||||
@@ -2438,7 +2548,7 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto
|
||||
if (stat == cudaErrorInvalidDeviceFunction) {
|
||||
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
|
||||
// We don't need to update blas nodes, so clear error and move on.
|
||||
cudaGetLastError();
|
||||
(void)cudaGetLastError();
|
||||
} else {
|
||||
GGML_ASSERT(stat == cudaSuccess);
|
||||
}
|
||||
@@ -2493,14 +2603,20 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
|
||||
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
||||
|
||||
cudaGraphExecUpdateResultInfo result_info;
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
hipGraphNode_t errorNode;
|
||||
hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
|
||||
#else
|
||||
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
|
||||
#endif
|
||||
if (stat == cudaErrorGraphExecUpdateFailure) {
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
|
||||
#endif
|
||||
|
||||
// The pre-existing graph exec cannot be updated due to violated constraints
|
||||
// so instead clear error and re-instantiate
|
||||
cudaGetLastError();
|
||||
(void)cudaGetLastError();
|
||||
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
|
||||
cuda_ctx->cuda_graph->instance = nullptr;
|
||||
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
||||
@@ -2728,7 +2844,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
|
||||
cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
|
||||
if (err != cudaSuccess) {
|
||||
// clear the error
|
||||
cudaGetLastError();
|
||||
(void)cudaGetLastError();
|
||||
|
||||
GGML_LOG_DEBUG("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__,
|
||||
size / 1024.0 / 1024.0, cudaGetErrorString(err));
|
||||
@@ -2748,7 +2864,7 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
|
||||
cudaError_t err = cudaHostUnregister(buffer);
|
||||
if (err != cudaSuccess) {
|
||||
// clear the error
|
||||
cudaGetLastError();
|
||||
(void)cudaGetLastError();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2880,9 +2996,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
|
||||
return false;
|
||||
}
|
||||
#ifdef GGML_USE_MUSA
|
||||
if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
|
||||
!ggml_is_transposed(a) && !ggml_is_transposed(b)) {
|
||||
@@ -3002,7 +3115,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
||||
} break;
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
|
||||
return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15);
|
||||
case GGML_OP_CONCAT:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
@@ -3022,6 +3135,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
break;
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
return true;
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
|
||||
break;
|
||||
@@ -3064,7 +3178,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_ACC:
|
||||
return true;
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ARANGE:
|
||||
@@ -3216,7 +3332,7 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
|
||||
features.push_back({ "FORCE_CUBLAS", "1" });
|
||||
#endif
|
||||
|
||||
#ifdef GGML_CUDA_NO_VMM
|
||||
#ifndef GGML_USE_VMM
|
||||
features.push_back({ "NO_VMM", "1" });
|
||||
#endif
|
||||
|
||||
|
||||
@@ -1,11 +1,67 @@
|
||||
// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
|
||||
// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
|
||||
// The documentation for the PTX instructions can be found under:
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
|
||||
//
|
||||
// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
|
||||
// A is a row-major matrix with shape I x K.
|
||||
// B is a column-major matrix with shape K x J.
|
||||
// C is a column-major matrix with shape I x J.
|
||||
// Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements.
|
||||
// The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile.
|
||||
// All matrix tiles have ne physical 32 bit elements per warp.
|
||||
//
|
||||
// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
struct mma_int_A_I16K4 {
|
||||
|
||||
#if CUDART_VERSION >= 11800
|
||||
|
||||
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
|
||||
int ret = 0;
|
||||
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
|
||||
: "+r"(ret) : "r"(x));
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(NEW_MMA_AVAILABLE)
|
||||
return ret;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
|
||||
// Imagine transposing row-major matrix to column-major matrix.
|
||||
const int src_i_low = 2 * (threadIdx.x % 4);
|
||||
const int src_i_high = src_i_low + 1;
|
||||
const int src_j = threadIdx.x / 4;
|
||||
|
||||
const int src_laneid_low = src_i_low * 4 + src_j / 2;
|
||||
const int src_laneid_high = src_i_high * 4 + src_j / 2;
|
||||
|
||||
const int shift_low = ((src_j + 0) % 2) * 16;
|
||||
const int shift_high = ((src_j + 1) % 2) * 16;
|
||||
|
||||
const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF;
|
||||
const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
|
||||
|
||||
return ret_low | ret_high;
|
||||
}
|
||||
|
||||
#endif // CUDART_VERSION >= 11800
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct mma_A_I16K4 {
|
||||
static_assert(sizeof(T) == 4, "bad type size");
|
||||
|
||||
static constexpr int I = 16;
|
||||
static constexpr int K = 4;
|
||||
static constexpr int ne = 2;
|
||||
|
||||
int x[ne] = {0};
|
||||
T x[ne];
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
const int ret = (l%2) * (I/2) + threadIdx.x / K;
|
||||
@@ -21,27 +77,35 @@ struct mma_int_A_I16K4 {
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
|
||||
#if defined(INT8_MMA_AVAILABLE)
|
||||
const int * xs = xs0 + (threadIdx.x%I)*stride;
|
||||
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
|
||||
: "+r"(x[0]), "+r"(x[1])
|
||||
: "l"(xs));
|
||||
#else
|
||||
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
x[l] = xs0[get_i(l)*stride + get_k(l)];
|
||||
}
|
||||
#endif // defined(INT8_MMA_AVAILABLE)
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * xi = (int *) x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride;
|
||||
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
|
||||
: "+r"(xi[0]), "+r"(xi[1])
|
||||
: "l"(xs));
|
||||
#else
|
||||
load_generic(xs0, stride);
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
};
|
||||
|
||||
struct mma_int_A_I16K8 {
|
||||
template <typename T>
|
||||
struct mma_A_I16K8 {
|
||||
static_assert(sizeof(T) == 4, "bad type size");
|
||||
|
||||
static constexpr int I = 16;
|
||||
static constexpr int K = 8;
|
||||
static constexpr int ne = 4;
|
||||
|
||||
int x[ne] = {0};
|
||||
T x[ne];
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
|
||||
@@ -57,31 +121,62 @@ struct mma_int_A_I16K8 {
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
|
||||
#if defined(INT8_MMA_AVAILABLE)
|
||||
const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
|
||||
asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
|
||||
: "l"(xs));
|
||||
#else
|
||||
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
x[l] = xs0[get_i(l)*stride + get_k(l)];
|
||||
}
|
||||
#endif // defined(INT8_MMA_AVAILABLE)
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
|
||||
((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
|
||||
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * xi = (int * ) x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
|
||||
asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
|
||||
: "l"(xs));
|
||||
#else
|
||||
GGML_UNUSED(xs0);
|
||||
GGML_UNUSED(stride);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * xi = (int * ) x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
|
||||
asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3])
|
||||
: "l"(xs));
|
||||
#else
|
||||
GGML_UNUSED(xs0);
|
||||
GGML_UNUSED(stride);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void transpose() {
|
||||
int * xi = (int *) x;
|
||||
xi[0] = ggml_cuda_movmatrix(xi[0]);
|
||||
|
||||
const int tmp = ggml_cuda_movmatrix(xi[1]);
|
||||
xi[1] = ggml_cuda_movmatrix(xi[2]);
|
||||
xi[2] = tmp;
|
||||
|
||||
xi[3] = ggml_cuda_movmatrix(xi[3]);
|
||||
}
|
||||
};
|
||||
|
||||
struct mma_int_B_J8K4 {
|
||||
template <typename T>
|
||||
struct mma_B_J8K4 {
|
||||
static_assert(sizeof(T) == 4, "bad type size");
|
||||
|
||||
static constexpr int J = 8;
|
||||
static constexpr int K = 4;
|
||||
static constexpr int ne = 1;
|
||||
|
||||
int x[ne] = {0};
|
||||
T x[ne];
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int /* l */) {
|
||||
const int ret = threadIdx.x / K;
|
||||
@@ -97,27 +192,34 @@ struct mma_int_B_J8K4 {
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
|
||||
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
|
||||
const int * xs = xs0 + (threadIdx.x%J)*stride;
|
||||
asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
|
||||
: "+r"(x[0])
|
||||
: "l"(xs));
|
||||
#else
|
||||
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
x[l] = xs0[get_j(l)*stride + get_k(l)];
|
||||
}
|
||||
#endif // defined(INT8_MMA_AVAILABLE)
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * xi = (int *) x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride;
|
||||
asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
|
||||
: "+r"(xi[0]) : "l"(xs));
|
||||
#else
|
||||
load_generic(xs0, stride);
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
};
|
||||
|
||||
struct mma_int_B_J8K8 {
|
||||
template <typename T>
|
||||
struct mma_B_J8K8 {
|
||||
static_assert(sizeof(T) == 4, "bad type size");
|
||||
|
||||
static constexpr int J = 8;
|
||||
static constexpr int K = 8;
|
||||
static constexpr int ne = 2;
|
||||
|
||||
int x[ne] = {0};
|
||||
T x[ne];
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int /* l */) {
|
||||
const int ret = threadIdx.x / (K/2);
|
||||
@@ -133,22 +235,31 @@ struct mma_int_B_J8K8 {
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
|
||||
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
|
||||
const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
|
||||
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
|
||||
: "+r"(x[0]), "+r"(x[1])
|
||||
: "l"(xs));
|
||||
#else
|
||||
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
x[l] = xs0[get_j(l)*stride + get_k(l)];
|
||||
}
|
||||
#endif // defined(INT8_MMA_AVAILABLE)
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * xi = (int *) x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
|
||||
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
|
||||
: "+r"(xi[0]), "+r"(xi[1])
|
||||
: "l"(xs));
|
||||
#else
|
||||
load_generic(xs0, stride);
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
};
|
||||
|
||||
struct mma_int_C_I16J8 {
|
||||
template <typename T>
|
||||
struct mma_C_I16J8 {};
|
||||
|
||||
template <>
|
||||
struct mma_C_I16J8<int> {
|
||||
static constexpr int I = 16;
|
||||
static constexpr int J = 8;
|
||||
static constexpr int ne = 4;
|
||||
@@ -169,8 +280,8 @@ struct mma_int_C_I16J8 {
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
__device__ __forceinline__ void mma(const mma_A_I16K4<int> & mma_A, const mma_B_J8K4<int> & mma_B) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
|
||||
@@ -188,11 +299,11 @@ struct mma_int_C_I16J8 {
|
||||
GGML_UNUSED(mma_A);
|
||||
GGML_UNUSED(mma_B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
__device__ __forceinline__ void mma(const mma_A_I16K8<int> & mma_A, const mma_B_J8K8<int> & mma_B) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
||||
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
|
||||
@@ -216,6 +327,132 @@ struct mma_int_C_I16J8 {
|
||||
GGML_UNUSED(mma_A);
|
||||
GGML_UNUSED(mma_B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mma_C_I16J8<half2> {
|
||||
static constexpr int I = 16;
|
||||
static constexpr int J = 4;
|
||||
static constexpr int ne = 2;
|
||||
|
||||
half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}};
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
const int ret = l * (I/2) + threadIdx.x / J;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < I);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int /* l */) {
|
||||
const int ret = threadIdx.x % J;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < J);
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * Axi = (int *) mma_A.x;
|
||||
int * Bxi = (int *) mma_B.x;
|
||||
int * xi = (int *) x;
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
|
||||
: "+r"(xi[0]), "+r"(xi[1])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
|
||||
#else
|
||||
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
||||
: "+r"(xi[0]), "+r"(xi[1])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
||||
: "+r"(xi[0]), "+r"(xi[1])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(mma_A);
|
||||
GGML_UNUSED(mma_B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
__device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
|
||||
mma_B_J8K8<half2> mma_B;
|
||||
|
||||
int * xi = (int *) x;
|
||||
int * Bxi = (int *) mma_B.x;
|
||||
Bxi[0] = ggml_cuda_movmatrix(xi[0]);
|
||||
Bxi[1] = ggml_cuda_movmatrix(xi[1]);
|
||||
|
||||
return mma_B;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mma_C_I16J8<float> {
|
||||
static constexpr int I = 16;
|
||||
static constexpr int J = 8;
|
||||
static constexpr int ne = 4;
|
||||
|
||||
float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < I);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < J);
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * Axi = (int *) mma_A.x;
|
||||
int * Bxi = (int *) mma_B.x;
|
||||
int * xi = (int *) x;
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
||||
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
|
||||
#else
|
||||
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(mma_A);
|
||||
GGML_UNUSED(mma_B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
__device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
|
||||
mma_B_J8K8<half2> mma_B;
|
||||
mma_B.x[0] = make_half2(x[0], x[1]);
|
||||
mma_B.x[1] = make_half2(x[2], x[3]);
|
||||
|
||||
int * Bxi = (int *) mma_B.x;
|
||||
Bxi[0] = ggml_cuda_movmatrix(Bxi[0]);
|
||||
Bxi[1] = ggml_cuda_movmatrix(Bxi[1]);
|
||||
|
||||
return mma_B;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
x[l] = xs0[get_j(l)*stride + get_i(l)];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -132,7 +132,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (int8_mma_available(cc)) {
|
||||
if (new_mma_available(cc)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -148,5 +148,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
return (cc < GGML_CUDA_CC_RDNA3 && cc != GGML_CUDA_CC_CDNA && cc != GGML_CUDA_CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc) && !GGML_CUDA_CC_IS_GCN(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,25 +1,29 @@
|
||||
#include "ggml.h"
|
||||
#include "common.cuh"
|
||||
#include "mmv.cuh"
|
||||
|
||||
template <typename T, typename type_acc, int block_size>
|
||||
static __global__ void mul_mat_vec(
|
||||
const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
|
||||
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
|
||||
const int64_t row = blockIdx.x;
|
||||
const int64_t channel = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
||||
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
|
||||
const int64_t row = blockIdx.x;
|
||||
const int64_t channel = blockIdx.y;
|
||||
const int64_t sample = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
|
||||
x += (channel/channel_ratio)*stride_channel_x + row*stride_row;
|
||||
y += channel *stride_channel_y;
|
||||
dst += channel *stride_channel_dst;
|
||||
x += (sample/sample_ratio)*stride_sample_x + (channel/channel_ratio)*stride_channel_x + row*stride_row;
|
||||
y += sample *stride_sample_y + channel *stride_channel_y;
|
||||
dst += sample *stride_sample_dst + channel *stride_channel_dst;
|
||||
|
||||
const float2 * y2 = (const float2 *) y;
|
||||
|
||||
extern __shared__ char data_mmv[];
|
||||
float * buf_iw = (float *) data_mmv;
|
||||
|
||||
if (block_size > WARP_SIZE) {
|
||||
if (tid < WARP_SIZE) {
|
||||
if (block_size > warp_size) {
|
||||
if (tid < warp_size) {
|
||||
buf_iw[tid] = 0.0f;
|
||||
}
|
||||
__syncthreads();
|
||||
@@ -67,16 +71,16 @@ static __global__ void mul_mat_vec(
|
||||
static_assert(std::is_same<T, void>::value, "unsupported type");
|
||||
}
|
||||
|
||||
sumf = warp_reduce_sum(sumf);
|
||||
sumf = warp_reduce_sum<warp_size>(sumf);
|
||||
|
||||
if (block_size > WARP_SIZE) {
|
||||
buf_iw[tid/WARP_SIZE] = sumf;
|
||||
if (block_size > warp_size) {
|
||||
buf_iw[tid/warp_size] = sumf;
|
||||
__syncthreads();
|
||||
if (tid >= WARP_SIZE) {
|
||||
if (tid >= warp_size) {
|
||||
return;
|
||||
}
|
||||
sumf = buf_iw[tid];
|
||||
sumf = warp_reduce_sum(sumf);
|
||||
sumf = warp_reduce_sum<warp_size>(sumf);
|
||||
}
|
||||
|
||||
if (tid != 0) {
|
||||
@@ -90,16 +94,28 @@ template <typename T, typename type_acc>
|
||||
static void launch_mul_mat_vec_cuda(
|
||||
const T * x, const float * y, float * dst,
|
||||
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % 2 == 0);
|
||||
GGML_ASSERT(stride_row % 2 == 0);
|
||||
GGML_ASSERT(nchannels_y % nchannels_x == 0);
|
||||
GGML_ASSERT(nsamples_y % nsamples_x == 0);
|
||||
const int64_t channel_ratio = nchannels_y / nchannels_x;
|
||||
const int64_t sample_ratio = nsamples_y / nsamples_x;
|
||||
int device;
|
||||
int warp_size;
|
||||
|
||||
int64_t block_size_best = WARP_SIZE;
|
||||
int64_t niter_best = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
|
||||
for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
|
||||
CUDA_CHECK(cudaGetDevice(&device));
|
||||
warp_size = ggml_cuda_info().devices[device].warp_size;
|
||||
|
||||
int64_t block_size_best = warp_size;
|
||||
int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
|
||||
int64_t max_block_size = 256;
|
||||
if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
|
||||
max_block_size = 128;
|
||||
}
|
||||
for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
|
||||
const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
|
||||
if (niter < niter_best) {
|
||||
niter_best = niter;
|
||||
@@ -107,41 +123,49 @@ static void launch_mul_mat_vec_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
const int smem = WARP_SIZE*sizeof(float);
|
||||
const dim3 block_nums(nrows, 1, nchannels_y);
|
||||
const int smem = warp_size*sizeof(float);
|
||||
const dim3 block_nums(nrows, nchannels_y, nsamples_y);
|
||||
const dim3 block_dims(block_size_best, 1, 1);
|
||||
switch (block_size_best) {
|
||||
case 32: {
|
||||
mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 64: {
|
||||
mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 96: {
|
||||
mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 128: {
|
||||
mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 160: {
|
||||
mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 192: {
|
||||
mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 224: {
|
||||
mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 256: {
|
||||
mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -153,16 +177,19 @@ template<typename T>
|
||||
static void mul_mat_vec_cuda(
|
||||
const T * x, const float * y, float * dst,
|
||||
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
enum ggml_prec prec, cudaStream_t stream) {
|
||||
switch (prec) {
|
||||
case GGML_PREC_DEFAULT: {
|
||||
launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
|
||||
stride_channel_x, stride_channel_y, stride_channel_dst, stream);
|
||||
launch_mul_mat_vec_cuda<T, half>
|
||||
(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case GGML_PREC_F32: {
|
||||
launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
|
||||
stride_channel_x, stride_channel_y, stride_channel_dst, stream);
|
||||
launch_mul_mat_vec_cuda<T, float>
|
||||
(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
}
|
||||
}
|
||||
@@ -171,10 +198,19 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
GGML_TENSOR_BINARY_OP_LOCALS;
|
||||
|
||||
GGML_ASSERT(src1->ne[1] == 1);
|
||||
const size_t ts_src0 = ggml_type_size(src0->type);
|
||||
const size_t ts_src1 = ggml_type_size(src1->type);
|
||||
const size_t ts_dst = ggml_type_size(dst->type);
|
||||
|
||||
GGML_ASSERT(ne11 == 1);
|
||||
GGML_ASSERT(ne12 == ne2);
|
||||
GGML_ASSERT(ne13 == ne3);
|
||||
|
||||
GGML_ASSERT(nb00 == ts_src0);
|
||||
GGML_ASSERT(nb10 == ts_src1);
|
||||
GGML_ASSERT(nb0 == ts_dst);
|
||||
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
|
||||
@@ -182,29 +218,22 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne12 = src1->ne[2];
|
||||
GGML_ASSERT(dst->ne[2] == ne12);
|
||||
|
||||
GGML_ASSERT(src0->ne[3] == 1);
|
||||
GGML_ASSERT(src1->ne[3] == 1);
|
||||
GGML_ASSERT( dst->ne[3] == 1);
|
||||
|
||||
const int64_t stride_row = src0->nb[1] / ggml_type_size(src0->type);
|
||||
const int64_t channel_stride_x = src0->nb[2] / ggml_type_size(src0->type);
|
||||
const int64_t channel_stride_y = src1->nb[2] / ggml_type_size(src1->type);
|
||||
const int64_t channel_stride_dst = dst->nb[2] / ggml_type_size( dst->type);
|
||||
const int64_t s01 = src0->nb[1] / ts_src0;
|
||||
const int64_t s02 = src0->nb[2] / ts_src0;
|
||||
const int64_t s12 = src1->nb[2] / ts_src1;
|
||||
const int64_t s2 = dst->nb[2] / ts_dst;
|
||||
const int64_t s03 = src0->nb[3] / ts_src0;
|
||||
const int64_t s13 = src1->nb[3] / ts_src1;
|
||||
const int64_t s3 = dst->nb[3] / ts_dst;
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F16: {
|
||||
const half * src0_d = (const half *) src0->data;
|
||||
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
|
||||
channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
|
||||
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
|
||||
} break;
|
||||
case GGML_TYPE_BF16: {
|
||||
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
|
||||
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
|
||||
channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
|
||||
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
|
||||
@@ -233,20 +262,27 @@ void ggml_cuda_op_mul_mat_vec(
|
||||
const int64_t stride_row = ne00;
|
||||
const int64_t nchannels_x = 1;
|
||||
const int64_t nchannels_y = 1;
|
||||
const int64_t channel_stride_x = 0;
|
||||
const int64_t channel_stride_y = 0;
|
||||
const int64_t channel_stride_dst = 0;
|
||||
const int64_t stride_channel_x = 0;
|
||||
const int64_t stride_channel_y = 0;
|
||||
const int64_t stride_channel_dst = 0;
|
||||
const int64_t nsamples_x = 1;
|
||||
const int64_t nsamples_y = 1;
|
||||
const int64_t stride_sample_x = 0;
|
||||
const int64_t stride_sample_y = 0;
|
||||
const int64_t stride_sample_dst = 0;
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F16: {
|
||||
const half * src0_d = (const half *) src0_dd_i;
|
||||
mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
|
||||
nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
|
||||
nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||
} break;
|
||||
case GGML_TYPE_BF16: {
|
||||
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
|
||||
mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
|
||||
nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
|
||||
nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
|
||||
|
||||
@@ -142,7 +142,7 @@ static void mul_mat_vec_q_cuda(
|
||||
int64_t nwarps = 1;
|
||||
int64_t rows_per_cuda_block = 1;
|
||||
|
||||
if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_CDNA || ggml_cuda_info().devices[id].cc == GGML_CUDA_CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA
|
||||
if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
|
||||
switch(ncols_y) {
|
||||
case 1:
|
||||
nwarps = 4;
|
||||
@@ -166,6 +166,7 @@ static void mul_mat_vec_q_cuda(
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
|
||||
const dim3 block_nums(nblocks, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
#include "norm.cuh"
|
||||
#include <cstdint>
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int tid = threadIdx.x;
|
||||
static __global__ void norm_f32(
|
||||
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
||||
const int64_t stride_sample, const float eps) {
|
||||
const int nrows = gridDim.x;
|
||||
const int nchannels = gridDim.y;
|
||||
|
||||
x += int64_t(row)*ncols;
|
||||
dst += int64_t(row)*ncols;
|
||||
const int row = blockIdx.x;
|
||||
const int channel = blockIdx.y;
|
||||
const int sample = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
x += sample*stride_sample + channel*stride_channel + row*stride_row;
|
||||
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
|
||||
|
||||
float2 mean_var = make_float2(0.0f, 0.0f);
|
||||
|
||||
@@ -97,12 +105,19 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
|
||||
}
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int tid = threadIdx.x;
|
||||
static __global__ void rms_norm_f32(
|
||||
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
||||
const int64_t stride_sample, const float eps) {
|
||||
const int nrows = gridDim.x;
|
||||
const int nchannels = gridDim.y;
|
||||
|
||||
x += int64_t(row)*ncols;
|
||||
dst += int64_t(row)*ncols;
|
||||
const int row = blockIdx.x;
|
||||
const int channel = blockIdx.y;
|
||||
const int sample = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
x += sample*stride_sample + channel*stride_channel + row*stride_row;
|
||||
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
@@ -186,13 +201,16 @@ static __global__ void rms_norm_back_f32(
|
||||
}
|
||||
}
|
||||
|
||||
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||
static void norm_f32_cuda(
|
||||
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
||||
const dim3 blocks_num(nrows, nchannels, nsamples);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -207,13 +225,16 @@ static void group_norm_f32_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||
static void rms_norm_f32_cuda(
|
||||
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
||||
const dim3 blocks_num(nrows, nchannels, nsamples);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -229,23 +250,26 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
|
||||
|
||||
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
GGML_TENSOR_UNARY_OP_LOCALS;
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
|
||||
const size_t ts0 = ggml_type_size(src0->type);
|
||||
GGML_ASSERT(nb00 == ts0);
|
||||
const int64_t s01 = nb01 / ts0;
|
||||
const int64_t s02 = nb02 / ts0;
|
||||
const int64_t s03 = nb03 / ts0;
|
||||
|
||||
norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
@@ -254,8 +278,6 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||
float * dst_d = (float *)dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
@@ -271,23 +293,26 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||
|
||||
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
GGML_TENSOR_UNARY_OP_LOCALS;
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
|
||||
const size_t ts0 = ggml_type_size(src0->type);
|
||||
GGML_ASSERT(nb00 == ts0);
|
||||
const int64_t s01 = nb01 / ts0;
|
||||
const int64_t s02 = nb02 / ts0;
|
||||
const int64_t s03 = nb03 / ts0;
|
||||
|
||||
rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
@@ -34,6 +34,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(handle, stream));
|
||||
|
||||
const int64_t lda = nb01 / sizeof(float);
|
||||
const int64_t ldc = nb1 / sizeof(float);
|
||||
|
||||
const bool src1_T = ggml_is_transposed(src1);
|
||||
const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
|
||||
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
|
||||
@@ -57,9 +60,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
ne0, ne1, ne01,
|
||||
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00,
|
||||
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
|
||||
src1_d + i3 *s13 + i2 *s12, ldb,
|
||||
&beta, dst_d + i3 *s3 + i2 *s2, ne0));
|
||||
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,12 @@ __device__ float __forceinline__ t2f32<half>(half val) {
|
||||
return __half2float(val);
|
||||
}
|
||||
|
||||
// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
|
||||
// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wpass-failed"
|
||||
#endif // __clang__
|
||||
template <bool use_shared, int ncols_template, int block_size_template, typename T>
|
||||
static __global__ void soft_max_f32(
|
||||
const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
|
||||
@@ -118,6 +124,9 @@ static __global__ void soft_max_f32(
|
||||
dst[col] = vals[col] * inv_sum;
|
||||
}
|
||||
}
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic pop
|
||||
#endif // __clang__
|
||||
|
||||
static __global__ void soft_max_back_f32(
|
||||
const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 16);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 16);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 16);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 16);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 16);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 16);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 32);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 64);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 64);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 64);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 64);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 64);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8);
|
||||
@@ -1,10 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-wmma-f16.cuh"
|
||||
|
||||
DECL_FATTN_WMMA_F16_CASE(64, 16, float);
|
||||
DECL_FATTN_WMMA_F16_CASE(80, 16, float);
|
||||
DECL_FATTN_WMMA_F16_CASE(96, 16, float);
|
||||
DECL_FATTN_WMMA_F16_CASE(112, 16, float);
|
||||
DECL_FATTN_WMMA_F16_CASE(128, 16, float);
|
||||
DECL_FATTN_WMMA_F16_CASE(256, 16, float);
|
||||
@@ -1,9 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-wmma-f16.cuh"
|
||||
|
||||
DECL_FATTN_WMMA_F16_CASE(64, 32, float);
|
||||
DECL_FATTN_WMMA_F16_CASE(80, 32, float);
|
||||
DECL_FATTN_WMMA_F16_CASE(96, 32, float);
|
||||
DECL_FATTN_WMMA_F16_CASE(112, 32, float);
|
||||
DECL_FATTN_WMMA_F16_CASE(128, 32, float);
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user