mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-09 16:17:31 +03:00
Compare commits
93 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
009a113326 | ||
|
|
c8ac02fa1b | ||
|
|
4ef9301e4d | ||
|
|
ddf03c6d9a | ||
|
|
26229755c5 | ||
|
|
057dba336e | ||
|
|
501aeed18f | ||
|
|
0ec191e1d7 | ||
|
|
243532e556 | ||
|
|
5e9c635463 | ||
|
|
9949ad08f6 | ||
|
|
3ee9da0e4f | ||
|
|
75511a8d7e | ||
|
|
b54cb2e3d0 | ||
|
|
8a65a7a8ee | ||
|
|
8a132faaa0 | ||
|
|
4293919068 | ||
|
|
d12cc3d1ca | ||
|
|
2dcb7f74ed | ||
|
|
660600081f | ||
|
|
d9a12c82f0 | ||
|
|
4a05e0c566 | ||
|
|
e9fd96283d | ||
|
|
3ba12fed0a | ||
|
|
5473949070 | ||
|
|
dcdcbad42a | ||
|
|
5764d7c6a6 | ||
|
|
87f4744a80 | ||
|
|
85d482e6b6 | ||
|
|
ae65fbdf33 | ||
|
|
3bd9aa1f92 | ||
|
|
ece522f98c | ||
|
|
09343c0198 | ||
|
|
97508acb17 | ||
|
|
5c4aae66e1 | ||
|
|
c5ce4bc227 | ||
|
|
66c4f9ded0 | ||
|
|
93bdc61563 | ||
|
|
4eb19514dd | ||
|
|
957d717ce5 | ||
|
|
de1aa6fa73 | ||
|
|
69c28f1547 | ||
|
|
0d049d6a92 | ||
|
|
a8ec0df461 | ||
|
|
e8f5082697 | ||
|
|
22fc79134e | ||
|
|
2a619f6fbc | ||
|
|
edd4d9bca5 | ||
|
|
482192f12d | ||
|
|
71a81f6fcc | ||
|
|
ecce0087da | ||
|
|
d1f82e382d | ||
|
|
0988accf82 | ||
|
|
0033f53a07 | ||
|
|
d0a6dfeb28 | ||
|
|
2e1f0a889e | ||
|
|
506200cf8b | ||
|
|
15f786e658 | ||
|
|
94ca829b60 | ||
|
|
4aa962e2b0 | ||
|
|
941146b3f1 | ||
|
|
482d862bcb | ||
|
|
3979f2bb08 | ||
|
|
400ac8e194 | ||
|
|
f51fd36d79 | ||
|
|
25eec6f327 | ||
|
|
58190cc84d | ||
|
|
af76639f72 | ||
|
|
761797ffdf | ||
|
|
5d3a4a7da5 | ||
|
|
c08d28d088 | ||
|
|
661e9acb36 | ||
|
|
b8635075ff | ||
|
|
9c699074c9 | ||
|
|
d01f6274c0 | ||
|
|
650bf14eb9 | ||
|
|
b7ad48ebda | ||
|
|
d006858316 | ||
|
|
e439700992 | ||
|
|
50e0ad08fb | ||
|
|
f1f793ad06 | ||
|
|
af5c13841f | ||
|
|
277ff5fff7 | ||
|
|
384c0076bc | ||
|
|
1f34806c44 | ||
|
|
887535c33f | ||
|
|
d3416a4aa9 | ||
|
|
43a4ee4a2c | ||
|
|
f851fa5ab0 | ||
|
|
f1ac84119c | ||
|
|
b069b10ab4 | ||
|
|
0c58ba3365 | ||
|
|
57ace0d612 |
@@ -1,97 +0,0 @@
|
||||
ARG UBUNTU_VERSION=24.04
|
||||
# This needs to generally match the container host's environment.
|
||||
ARG CUDA_VERSION=13.1.1
|
||||
# Target the CUDA build image
|
||||
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||
|
||||
ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
|
||||
|
||||
FROM ${BASE_CUDA_DEV_CONTAINER} AS build
|
||||
|
||||
# CUDA architecture to build for (defaults to all supported archs)
|
||||
ARG CUDA_DOCKER_ARCH=default
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y gcc-14 g++-14 build-essential cmake python3 python3-pip git libssl-dev libgomp1
|
||||
|
||||
ENV CC=gcc-14 CXX=g++-14 CUDAHOSTCXX=g++-14
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \
|
||||
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \
|
||||
fi && \
|
||||
cmake -B build -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_BUILD_TESTS=OFF ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
|
||||
cmake --build build --config Release -j$(nproc)
|
||||
|
||||
RUN mkdir -p /app/lib && \
|
||||
find build -name "*.so*" -exec cp -P {} /app/lib \;
|
||||
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
&& cp *.py /app/full \
|
||||
&& cp -r gguf-py /app/full \
|
||||
&& cp -r requirements /app/full \
|
||||
&& cp requirements.txt /app/full \
|
||||
&& cp .devops/tools.sh /app/full/tools.sh
|
||||
|
||||
## Base image
|
||||
FROM ${BASE_CUDA_RUN_CONTAINER} AS base
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
&& find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \
|
||||
&& find /var/cache -type f -delete
|
||||
|
||||
COPY --from=build /app/lib/ /app
|
||||
|
||||
### Full
|
||||
FROM base AS full
|
||||
|
||||
COPY --from=build /app/full /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y \
|
||||
git \
|
||||
python3 \
|
||||
python3-pip \
|
||||
python3-wheel \
|
||||
&& pip install --break-system-packages --upgrade setuptools \
|
||||
&& pip install --break-system-packages -r requirements.txt \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
&& find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \
|
||||
&& find /var/cache -type f -delete
|
||||
|
||||
|
||||
ENTRYPOINT ["/app/tools.sh"]
|
||||
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
ENTRYPOINT [ "/app/llama-cli" ]
|
||||
|
||||
### Server, Server only
|
||||
FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ]
|
||||
|
||||
ENTRYPOINT [ "/app/llama-server" ]
|
||||
@@ -16,7 +16,7 @@
|
||||
rocmPackages,
|
||||
vulkan-headers,
|
||||
vulkan-loader,
|
||||
curl,
|
||||
openssl,
|
||||
shaderc,
|
||||
useBlas ?
|
||||
builtins.all (x: !x) [
|
||||
@@ -160,7 +160,8 @@ effectiveStdenv.mkDerivation (finalAttrs: {
|
||||
++ optionals useMpi [ mpi ]
|
||||
++ optionals useRocm rocmBuildInputs
|
||||
++ optionals useBlas [ blas ]
|
||||
++ optionals useVulkan vulkanBuildInputs;
|
||||
++ optionals useVulkan vulkanBuildInputs
|
||||
++ [ openssl ];
|
||||
|
||||
cmakeFlags =
|
||||
[
|
||||
|
||||
8
.github/labeler.yml
vendored
8
.github/labeler.yml
vendored
@@ -73,10 +73,18 @@ android:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- examples/llama.android/**
|
||||
server/webui:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- tools/server/webui/**
|
||||
- tools/server/public/**
|
||||
server:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- tools/server/**
|
||||
|
||||
|
||||
|
||||
ggml:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
|
||||
38
.github/workflows/build-riscv.yml
vendored
38
.github/workflows/build-riscv.yml
vendored
@@ -35,7 +35,7 @@ env:
|
||||
|
||||
jobs:
|
||||
ubuntu-riscv64-native-sanitizer:
|
||||
runs-on: RISCV64
|
||||
runs-on: ubuntu-24.04-riscv
|
||||
|
||||
continue-on-error: true
|
||||
|
||||
@@ -50,17 +50,18 @@ jobs:
|
||||
sudo apt-get update
|
||||
|
||||
# Install necessary packages
|
||||
sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache git-lfs
|
||||
sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 cmake build-essential wget git-lfs
|
||||
|
||||
# Set gcc-14 and g++-14 as the default compilers
|
||||
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100
|
||||
sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100
|
||||
sudo ln -sf /usr/bin/gcc-14 /usr/bin/gcc
|
||||
sudo ln -sf /usr/bin/g++-14 /usr/bin/g++
|
||||
|
||||
# Install Rust stable version
|
||||
rustup install stable
|
||||
rustup default stable
|
||||
if ! which rustc; then
|
||||
# Install Rust stable version
|
||||
sudo apt-get install -y rustup
|
||||
rustup install stable
|
||||
rustup default stable
|
||||
fi
|
||||
|
||||
git lfs install
|
||||
|
||||
@@ -73,23 +74,12 @@ jobs:
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup ccache
|
||||
run: |
|
||||
# Unique cache directory per matrix combination
|
||||
export CCACHE_DIR="$HOME/.ccache/sanitizer-${{ matrix.sanitizer }}-${{ matrix.build_type }}"
|
||||
mkdir -p "$CCACHE_DIR"
|
||||
|
||||
# Configure ccache
|
||||
ccache --set-config=max_size=5G
|
||||
ccache --set-config=compression=true
|
||||
ccache --set-config=compression_level=6
|
||||
ccache --set-config=cache_dir="$CCACHE_DIR"
|
||||
ccache --set-config=sloppiness=file_macro,time_macros,include_file_mtime,include_file_ctime
|
||||
ccache --set-config=hash_dir=false
|
||||
|
||||
# Export for subsequent steps
|
||||
echo "CCACHE_DIR=$CCACHE_DIR" >> $GITHUB_ENV
|
||||
echo "PATH=/usr/lib/ccache:$PATH" >> $GITHUB_ENV
|
||||
# FIXME: Enable when ggml-org/ccache-action works on riscv64
|
||||
# - name: ccache
|
||||
# uses: ggml-org/ccache-action@v1.2.21
|
||||
# with:
|
||||
# key: ubuntu-riscv64-native-sanitizer-${{ matrix.sanytizer }}-${{ matrix.build_type }}
|
||||
# save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
|
||||
21
.github/workflows/build-self-hosted.yml
vendored
21
.github/workflows/build-self-hosted.yml
vendored
@@ -213,6 +213,27 @@ jobs:
|
||||
vulkaninfo --summary
|
||||
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
|
||||
|
||||
ggml-ci-win-intel-vulkan:
|
||||
runs-on: [self-hosted, Windows, X64, Intel]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
shell: C:\msys64\usr\bin\bash.exe --noprofile --norc -eo pipefail "{0}"
|
||||
env:
|
||||
MSYSTEM: UCRT64
|
||||
CHERE_INVOKING: 1
|
||||
PATH: C:\msys64\ucrt64\bin;C:\msys64\usr\bin;C:\Windows\System32;${{ env.PATH }}
|
||||
run: |
|
||||
vulkaninfo --summary
|
||||
# Skip python related tests with GG_BUILD_LOW_PERF=1 since Windows MSYS2 UCRT64 currently fails to create
|
||||
# a valid python environment for testing
|
||||
LLAMA_FATAL_WARNINGS=OFF GG_BUILD_NINJA=1 GG_BUILD_VULKAN=1 GG_BUILD_LOW_PERF=1 ./ci/run.sh ./results/llama.cpp ./mnt/llama.cpp
|
||||
|
||||
ggml-ci-intel-openvino-gpu-low-perf:
|
||||
runs-on: [self-hosted, Linux, Intel, OpenVINO]
|
||||
|
||||
|
||||
2
.github/workflows/build-vulkan.yml
vendored
2
.github/workflows/build-vulkan.yml
vendored
@@ -72,7 +72,7 @@ jobs:
|
||||
|
||||
- name: Setup Vulkan SDK
|
||||
if: steps.cache-sdk.outputs.cache-hit != 'true'
|
||||
uses: ./.github/actions/linux-setup-vulkan-llvmpipe
|
||||
uses: ./.github/actions/linux-setup-vulkan
|
||||
with:
|
||||
path: ./vulkan_sdk
|
||||
version: ${{ env.VULKAN_SDK_VERSION }}
|
||||
|
||||
49
.github/workflows/build.yml
vendored
49
.github/workflows/build.yml
vendored
@@ -472,6 +472,7 @@ jobs:
|
||||
cmake -B build -S . \
|
||||
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
|
||||
-DGGML_HIP_ROCWMMA_FATTN=ON \
|
||||
-DGPU_TARGETS="gfx1030" \
|
||||
-DGGML_HIP=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
@@ -990,11 +991,12 @@ jobs:
|
||||
-DROCM_DIR="${env:HIP_PATH}" `
|
||||
-DGGML_HIP=ON `
|
||||
-DGGML_HIP_ROCWMMA_FATTN=ON `
|
||||
-DGPU_TARGETS="gfx1100" `
|
||||
-DGGML_RPC=ON
|
||||
cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
|
||||
|
||||
ubuntu-cpu-riscv64-native:
|
||||
runs-on: RISCV64
|
||||
runs-on: ubuntu-24.04-riscv
|
||||
|
||||
steps:
|
||||
- name: Install dependencies
|
||||
@@ -1002,24 +1004,21 @@ jobs:
|
||||
sudo apt-get update
|
||||
|
||||
# Install necessary packages
|
||||
sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache git-lfs
|
||||
sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 cmake build-essential libssl-dev wget git-lfs
|
||||
|
||||
# Set gcc-14 and g++-14 as the default compilers
|
||||
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100
|
||||
sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100
|
||||
sudo ln -sf /usr/bin/gcc-14 /usr/bin/gcc
|
||||
sudo ln -sf /usr/bin/g++-14 /usr/bin/g++
|
||||
|
||||
# Install Rust stable version
|
||||
rustup install stable
|
||||
rustup default stable
|
||||
if ! which rustc; then
|
||||
# Install Rust stable version
|
||||
sudo apt-get install -y rustup
|
||||
rustup install stable
|
||||
rustup default stable
|
||||
fi
|
||||
|
||||
git lfs install
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Check environment
|
||||
run: |
|
||||
uname -a
|
||||
@@ -1029,25 +1028,17 @@ jobs:
|
||||
cmake --version
|
||||
rustc --version
|
||||
|
||||
- name: Setup ccache
|
||||
run: |
|
||||
# Set unique cache directory for this job
|
||||
export CCACHE_DIR="$HOME/.ccache/cpu-cmake-rv64-native"
|
||||
mkdir -p "$CCACHE_DIR"
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
|
||||
# Configure ccache for optimal performance
|
||||
ccache --set-config=max_size=5G
|
||||
ccache --set-config=compression=true
|
||||
ccache --set-config=compression_level=6
|
||||
ccache --set-config=cache_dir="$CCACHE_DIR"
|
||||
|
||||
# Enable more aggressive caching
|
||||
ccache --set-config=sloppiness=file_macro,time_macros,include_file_mtime,include_file_ctime
|
||||
ccache --set-config=hash_dir=false
|
||||
|
||||
# Export for subsequent steps
|
||||
echo "CCACHE_DIR=$CCACHE_DIR" >> $GITHUB_ENV
|
||||
echo "PATH=/usr/lib/ccache:$PATH" >> $GITHUB_ENV
|
||||
# FIXME: Enable when ggml-org/ccache-action works on riscv64
|
||||
# - name: ccache
|
||||
# uses: ggml-org/ccache-action@v1.2.21
|
||||
# with:
|
||||
# key: ubuntu-cpu-riscv64-native
|
||||
# evict-old-files: 1d
|
||||
# save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
|
||||
8
.github/workflows/docker.yml
vendored
8
.github/workflows/docker.yml
vendored
@@ -73,10 +73,10 @@ jobs:
|
||||
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-arm" },
|
||||
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x" },
|
||||
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
|
||||
{ "tag": "cuda13", "dockerfile": ".devops/cuda-new.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "cuda13", "dockerfile": ".devops/cuda-new.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
|
||||
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
|
||||
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.1.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.1.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
|
||||
{ "tag": "musa", "dockerfile": ".devops/musa.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "intel", "dockerfile": ".devops/intel.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "vulkan", "dockerfile": ".devops/vulkan.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
|
||||
|
||||
2
.github/workflows/hip-quality-check.yml
vendored
2
.github/workflows/hip-quality-check.yml
vendored
@@ -59,7 +59,7 @@ jobs:
|
||||
run: |
|
||||
cmake -B build -S . \
|
||||
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
|
||||
-DGPU_TARGETS=gfx908 \
|
||||
-DGPU_TARGETS=gfx942 \
|
||||
-DGGML_HIP=ON \
|
||||
-DGGML_HIP_EXPORT_METRICS=Off \
|
||||
-DCMAKE_HIP_FLAGS="-Werror -Wno-tautological-compare" \
|
||||
|
||||
86
.github/workflows/release.yml
vendored
86
.github/workflows/release.yml
vendored
@@ -36,8 +36,26 @@ env:
|
||||
CMAKE_ARGS: "-DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_TOOLS=ON -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON"
|
||||
|
||||
jobs:
|
||||
macOS-arm64:
|
||||
runs-on: macos-14
|
||||
macOS-cpu:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- build: 'arm64'
|
||||
arch: 'arm64'
|
||||
os: macos-14
|
||||
defines: "-DGGML_METAL_USE_BF16=ON -DGGML_METAL_EMBED_LIBRARY=ON"
|
||||
- build: 'arm64-kleidiai'
|
||||
arch: 'arm64'
|
||||
os: macos-14
|
||||
defines: "-DGGML_METAL_USE_BF16=ON -DGGML_METAL_EMBED_LIBRARY=ON -DGGML_CPU_KLEIDIAI=ON"
|
||||
- build: 'x64'
|
||||
arch: 'x64'
|
||||
os: macos-15-intel
|
||||
# Metal is disabled on x64 due to intermittent failures with Github runners not having a GPU:
|
||||
# https://github.com/ggml-org/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313
|
||||
defines: "-DGGML_METAL=OFF -DCMAKE_OSX_DEPLOYMENT_TARGET=13.3"
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -49,7 +67,7 @@ jobs:
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: macOS-latest-arm64
|
||||
key: macOS-latest-${{ matrix.arch }}
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Build
|
||||
@@ -57,13 +75,11 @@ jobs:
|
||||
run: |
|
||||
sysctl -a
|
||||
cmake -B build \
|
||||
${{ matrix.defines }} \
|
||||
-DCMAKE_INSTALL_RPATH='@loader_path' \
|
||||
-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_BUILD_BORINGSSL=ON \
|
||||
-DGGML_METAL_USE_BF16=ON \
|
||||
-DGGML_METAL_EMBED_LIBRARY=ON \
|
||||
-DGGML_RPC=ON \
|
||||
${{ env.CMAKE_ARGS }}
|
||||
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
|
||||
|
||||
@@ -75,61 +91,13 @@ jobs:
|
||||
id: pack_artifacts
|
||||
run: |
|
||||
cp LICENSE ./build/bin/
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-${{ matrix.build }}.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz
|
||||
name: llama-bin-macos-arm64.tar.gz
|
||||
|
||||
macOS-x64:
|
||||
runs-on: macos-15-intel
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: macOS-latest-x64
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
sysctl -a
|
||||
# Metal is disabled due to intermittent failures with Github runners not having a GPU:
|
||||
# https://github.com/ggml-org/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313
|
||||
cmake -B build \
|
||||
-DCMAKE_INSTALL_RPATH='@loader_path' \
|
||||
-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_BUILD_BORINGSSL=ON \
|
||||
-DGGML_METAL=OFF \
|
||||
-DGGML_RPC=ON \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=13.3
|
||||
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
|
||||
|
||||
- name: Determine tag name
|
||||
id: tag
|
||||
uses: ./.github/actions/get-tag-name
|
||||
|
||||
- name: Pack artifacts
|
||||
id: pack_artifacts
|
||||
run: |
|
||||
cp LICENSE ./build/bin/
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz
|
||||
name: llama-bin-macos-x64.tar.gz
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-macos-${{ matrix.build }}.tar.gz
|
||||
name: llama-bin-macos-${{ matrix.build }}.tar.gz
|
||||
|
||||
ubuntu-cpu:
|
||||
strategy:
|
||||
@@ -1003,8 +971,7 @@ jobs:
|
||||
- ubuntu-cpu
|
||||
- ubuntu-vulkan
|
||||
- ubuntu-24-openvino
|
||||
- macOS-arm64
|
||||
- macOS-x64
|
||||
- macOS-cpu
|
||||
- ios-xcode-build
|
||||
- openEuler-cann
|
||||
|
||||
@@ -1079,6 +1046,7 @@ jobs:
|
||||
|
||||
**macOS/iOS:**
|
||||
- [macOS Apple Silicon (arm64)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz)
|
||||
- [macOS Apple Silicon (arm64, KleidiAI enabled)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-arm64-kleidiai.tar.gz)
|
||||
- [macOS Intel (x64)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz)
|
||||
- [iOS XCFramework](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-xcframework.zip)
|
||||
|
||||
|
||||
37
ci/run.sh
37
ci/run.sh
@@ -119,6 +119,11 @@ if [ ! -z ${GG_BUILD_VULKAN} ]; then
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=OFF -DGGML_BLAS=OFF"
|
||||
fi
|
||||
|
||||
# Build shared libs on Windows
|
||||
# to reduce binary size and avoid errors in library loading unit tests
|
||||
if uname -s | grep -qi nt; then
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DBUILD_SHARED_LIBS=ON"
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ ! -z ${GG_BUILD_WEBGPU} ]; then
|
||||
@@ -221,7 +226,7 @@ function gg_run_ctest_debug {
|
||||
|
||||
set -e
|
||||
|
||||
# Check cmake and ctest are installed
|
||||
# Check required binaries are installed
|
||||
gg_check_build_requirements
|
||||
|
||||
(cmake -G "${CMAKE_GENERATOR}" -DCMAKE_BUILD_TYPE=Debug ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log
|
||||
@@ -252,7 +257,7 @@ function gg_run_ctest_release {
|
||||
|
||||
set -e
|
||||
|
||||
# Check cmake and ctest are installed
|
||||
# Check required binaries are installed
|
||||
gg_check_build_requirements
|
||||
|
||||
(cmake -G "${CMAKE_GENERATOR}" -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log
|
||||
@@ -627,10 +632,38 @@ function gg_sum_rerank_tiny {
|
||||
}
|
||||
|
||||
function gg_check_build_requirements {
|
||||
if ! command -v git &> /dev/null; then
|
||||
gg_printf 'git not found, please install'
|
||||
fi
|
||||
|
||||
if ! command -v git-lfs &> /dev/null; then
|
||||
gg_printf 'git-lfs not found, please install'
|
||||
fi
|
||||
|
||||
if ! command -v wget &> /dev/null; then
|
||||
gg_printf 'wget not found, please install'
|
||||
fi
|
||||
|
||||
if ! command -v python3 &> /dev/null; then
|
||||
gg_printf 'python3 not found, please install'
|
||||
fi
|
||||
|
||||
if ! command -v pip3 &> /dev/null; then
|
||||
gg_printf 'pip3 not found, please install'
|
||||
fi
|
||||
|
||||
if ! python3 -m ensurepip --help &> /dev/null; then
|
||||
gg_printf 'ensurepip not found, please install python3-venv package'
|
||||
fi
|
||||
|
||||
if ! command -v cmake &> /dev/null; then
|
||||
gg_printf 'cmake not found, please install'
|
||||
fi
|
||||
|
||||
if ! command -v ccache &> /dev/null; then
|
||||
gg_printf 'ccache not found, please consider installing for faster builds'
|
||||
fi
|
||||
|
||||
if ! command -v ctest &> /dev/null; then
|
||||
gg_printf 'ctest not found, please install'
|
||||
fi
|
||||
|
||||
@@ -1311,6 +1311,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.kv_unified = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
|
||||
add_opt(common_arg(
|
||||
{"--clear-idle"},
|
||||
{"--no-clear-idle"},
|
||||
"save and clear idle slots on new task (default: enabled, requires unified KV and cache-ram)",
|
||||
[](common_params & params, bool value) {
|
||||
params.clear_idle = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_CLEAR_IDLE").set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--context-shift"},
|
||||
{"--no-context-shift"},
|
||||
|
||||
@@ -6,110 +6,13 @@
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "peg-parser.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
namespace {
|
||||
|
||||
// Gemma4-specific PEG builder extending the standard chat builder.
|
||||
// Adds value type parsers that use <|\"|> as string delimiters
|
||||
// instead of JSON's double quotes, and disables json-to-schema
|
||||
// conversion for these types.
|
||||
class common_peg_gemma4_builder {
|
||||
common_chat_peg_builder & p_;
|
||||
static constexpr const char * QUOTE = "<|\"|>";
|
||||
|
||||
public:
|
||||
explicit common_peg_gemma4_builder(common_chat_peg_builder & p) : p_(p) {}
|
||||
|
||||
common_peg_parser gemma4_string() {
|
||||
return p_.rule("gemma4-string", [&]() {
|
||||
return p_.literal(QUOTE) + p_.until(QUOTE) + p_.literal(QUOTE);
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser gemma4_number() {
|
||||
return p_.rule("gemma4-number", [&]() {
|
||||
auto digit1_9 = p_.chars("[1-9]", 1, 1);
|
||||
auto digits = p_.chars("[0-9]");
|
||||
auto int_part = p_.choice({p_.literal("0"), p_.sequence({digit1_9, p_.chars("[0-9]", 0, -1)})});
|
||||
auto frac = p_.sequence({p_.literal("."), digits});
|
||||
auto exp = p_.sequence({p_.choice({p_.literal("e"), p_.literal("E")}),
|
||||
p_.optional(p_.chars("[+-]", 1, 1)), digits});
|
||||
auto not_number_continuation = p_.negate(p_.chars("[0-9.eE+-]", 1, 1));
|
||||
return p_.sequence({p_.optional(p_.literal("-")), int_part, p_.optional(frac),
|
||||
p_.optional(exp), not_number_continuation});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser gemma4_bool() {
|
||||
return p_.rule("gemma4-bool", [&]() {
|
||||
return p_.choice({p_.literal("true"), p_.literal("false")});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser gemma4_null() {
|
||||
return p_.rule("gemma4-null", [&]() {
|
||||
return p_.literal("null");
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser gemma4_dict() {
|
||||
return p_.rule("gemma4-dict", [&]() {
|
||||
auto ws = p_.space();
|
||||
auto key = p_.until(":");
|
||||
auto member = p_.sequence({key, p_.literal(":"), ws, gemma4_value()});
|
||||
auto members = p_.sequence({member, p_.zero_or_more(p_.sequence({p_.literal(","), ws, member}))});
|
||||
return p_.sequence({
|
||||
p_.literal("{"), ws,
|
||||
p_.choice({p_.literal("}"), p_.sequence({members, ws, p_.literal("}")})})
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser gemma4_array() {
|
||||
return p_.rule("gemma4-array", [&]() {
|
||||
auto ws = p_.space();
|
||||
auto elements = p_.sequence({gemma4_value(), p_.zero_or_more(p_.sequence({p_.literal(","), ws, gemma4_value()}))});
|
||||
return p_.sequence({
|
||||
p_.literal("["), ws,
|
||||
p_.choice({p_.literal("]"), p_.sequence({elements, ws, p_.literal("]")})})
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser gemma4_value() {
|
||||
return p_.rule("gemma4-value", [&]() {
|
||||
return p_.choice({gemma4_string(), gemma4_dict(), gemma4_array(),
|
||||
gemma4_number(), gemma4_bool(), gemma4_null()});
|
||||
});
|
||||
}
|
||||
|
||||
// Select the appropriate value parser based on JSON schema type.
|
||||
// Does NOT use schema() - the gemma4 types are pure PEG without
|
||||
// JSON schema metadata, so GBNF is generated directly from the
|
||||
// PEG structure.
|
||||
common_peg_parser gemma4_value_for_type(const json & schema) {
|
||||
if (!schema.contains("type") || !schema.at("type").is_string()) {
|
||||
return gemma4_value();
|
||||
}
|
||||
std::string type = schema.at("type").get<std::string>();
|
||||
if (type == "string") { return gemma4_string(); }
|
||||
if (type == "number") { return gemma4_number(); }
|
||||
if (type == "integer") { return gemma4_number(); }
|
||||
if (type == "boolean") { return gemma4_bool(); }
|
||||
if (type == "object") { return gemma4_dict(); }
|
||||
if (type == "array") { return gemma4_array(); }
|
||||
return gemma4_value();
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// Helper to iterate over tools/functions
|
||||
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
|
||||
for (const auto & tool : tools) {
|
||||
@@ -141,9 +44,7 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
||||
// Create the result structure
|
||||
common_chat_params data;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.format = (autoparser.tools.format.mode == tool_format::TAG_WITH_GEMMA4_DICT)
|
||||
? COMMON_CHAT_FORMAT_PEG_GEMMA4
|
||||
: COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = autoparser.preserved_tokens;
|
||||
|
||||
auto parser = autoparser.build_parser(inputs);
|
||||
@@ -270,8 +171,6 @@ common_peg_parser analyze_tools::build_parser(parser_build_context & ctx) const
|
||||
return build_tool_parser_tag_json(ctx);
|
||||
case tool_format::TAG_WITH_TAGGED:
|
||||
return build_tool_parser_tag_tagged(ctx);
|
||||
case tool_format::TAG_WITH_GEMMA4_DICT:
|
||||
return build_tool_parser_tag_gemma4_dict(ctx);
|
||||
default:
|
||||
LOG_ERR("[ERROR] Template seems to support tool calls, but failed to determine tool format. Tool calling will not work properly. "
|
||||
"Check for a fixed template for your model in the models/templates directory of your llama.cpp installation or "
|
||||
@@ -317,6 +216,44 @@ common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_cont
|
||||
p.end();
|
||||
}
|
||||
|
||||
common_peg_parser analyze_tools::build_func_parser(common_chat_peg_builder & p, const std::string & name,
|
||||
const common_peg_parser & call_id_section, bool have_call_id,
|
||||
const common_peg_parser & args,
|
||||
std::optional<common_peg_parser> atomic_peek) const {
|
||||
auto open = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix);
|
||||
bool matched_atomic = false;
|
||||
common_peg_parser func_parser = p.eps();
|
||||
|
||||
if (!function.name_suffix.empty()) {
|
||||
func_parser = open + call_id_section + p.space() + args;
|
||||
matched_atomic = true;
|
||||
} else if (have_call_id) {
|
||||
func_parser = p.atomic(open + call_id_section) + p.space() + args;
|
||||
matched_atomic = true;
|
||||
} else if (atomic_peek.has_value()) {
|
||||
func_parser = p.atomic(open + call_id_section + p.space() + *atomic_peek) + args;
|
||||
matched_atomic = true;
|
||||
} else {
|
||||
func_parser = open + call_id_section + p.space() + args;
|
||||
}
|
||||
|
||||
if (!function.close.empty()) {
|
||||
func_parser = func_parser + p.space() + p.tool_close(p.literal(function.close));
|
||||
} else if (!format.per_call_end.empty()) {
|
||||
// When there's no func_close but there is a per_call_end marker, use peek() to ensure
|
||||
// we only emit tool_close when we can actually see the closing marker. This prevents
|
||||
// premature closing during partial parsing when we've seen e.g. "</" which could be
|
||||
// either "</tool_call>" (end) or "<arg_key>" prefix that failed to match.
|
||||
func_parser = func_parser + p.tool_close(p.peek(p.literal(format.per_call_end)));
|
||||
} else {
|
||||
func_parser = func_parser + p.tool_close(p.space()); // force this to process tool closing callbacks in mapper
|
||||
}
|
||||
if (!matched_atomic) {
|
||||
func_parser = p.atomic(func_parser);
|
||||
}
|
||||
return func_parser;
|
||||
}
|
||||
|
||||
common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context & ctx) const {
|
||||
auto & p = ctx.p;
|
||||
const auto & inputs = ctx.inputs;
|
||||
@@ -330,17 +267,27 @@ common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context
|
||||
const auto & schema = func.contains("parameters") ? func.at("parameters") : json::object();
|
||||
|
||||
// Build call_id parser based on position (if supported)
|
||||
bool have_call_id = false;
|
||||
common_peg_parser call_id_section = p.eps();
|
||||
if (call_id.pos == call_id_position::BETWEEN_FUNC_AND_ARGS && !call_id.prefix.empty() &&
|
||||
!call_id.suffix.empty()) {
|
||||
call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix))) + call_id.suffix;
|
||||
(!call_id.suffix.empty() || !arguments.start.empty())) {
|
||||
if (!call_id.suffix.empty()) {
|
||||
call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix))) + call_id.suffix;
|
||||
} else {
|
||||
call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(arguments.start)));
|
||||
}
|
||||
have_call_id = true;
|
||||
}
|
||||
auto args_parser = p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema));
|
||||
if (!arguments.start.empty()) {
|
||||
args_parser = p.literal(arguments.start) + args_parser;
|
||||
}
|
||||
if (!arguments.end.empty()) {
|
||||
args_parser = args_parser + p.literal(arguments.end);
|
||||
}
|
||||
|
||||
auto func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
|
||||
call_id_section + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema));
|
||||
if (!function.close.empty()) {
|
||||
func_parser = func_parser + function.close;
|
||||
}
|
||||
auto atomic_peek = !arguments.start.empty() ? std::optional(p.peek(p.literal(arguments.start))) : std::nullopt;
|
||||
auto func_parser = build_func_parser(p, name, call_id_section, have_call_id, args_parser, atomic_peek);
|
||||
tool_choice |= p.rule("tool-" + name, func_parser);
|
||||
});
|
||||
|
||||
@@ -385,36 +332,36 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
||||
const auto & inputs = ctx.inputs;
|
||||
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
auto until_suffix = p.rule("until-suffix", p.until(arguments.value_suffix));
|
||||
|
||||
common_peg_parser tool_choice = p.choice();
|
||||
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & func = tool.at("function");
|
||||
std::string name = func.at("name");
|
||||
const auto & params = func.contains("parameters") ? func.at("parameters") : json::object();
|
||||
auto params = func.contains("parameters") ? func.at("parameters") : json::object();
|
||||
const auto & properties = params.contains("properties") ? params.at("properties") : json::object();
|
||||
|
||||
std::set<std::string> required;
|
||||
if (params.contains("required")) {
|
||||
params.at("required").get_to(required);
|
||||
}
|
||||
|
||||
auto schema_info = common_schema_info();
|
||||
schema_info.resolve_refs(params);
|
||||
|
||||
// Build parser for each argument, separating required and optional
|
||||
std::vector<common_peg_parser> required_parsers;
|
||||
std::vector<common_peg_parser> optional_parsers;
|
||||
for (const auto & [param_name, param_schema] : properties.items()) {
|
||||
bool is_required = required.find(param_name) != required.end();
|
||||
std::string type = "object";
|
||||
auto type_obj = param_schema.contains("type") ? param_schema.at("type") : json::object();
|
||||
if (type_obj.is_string()) {
|
||||
type_obj.get_to(type);
|
||||
} else if (type_obj.is_object()) {
|
||||
if (type_obj.contains("type") && type_obj.at("type").is_string()) {
|
||||
type_obj.at("type").get_to(type);
|
||||
}
|
||||
}
|
||||
bool is_required = required.find(param_name) != required.end();
|
||||
|
||||
auto arg =
|
||||
p.tool_arg(p.tool_arg_open(arguments.name_prefix + p.tool_arg_name(p.literal(param_name)) +
|
||||
arguments.name_suffix) +
|
||||
arguments.value_prefix +
|
||||
(type == "string" ?
|
||||
p.tool_arg_string_value(p.schema(p.until(arguments.value_suffix),
|
||||
(schema_info.resolves_to_string(param_schema) ?
|
||||
p.tool_arg_string_value(p.schema(until_suffix,
|
||||
"tool-" + name + "-arg-" + param_name + "-schema",
|
||||
param_schema, true)) :
|
||||
p.tool_arg_json_value(p.schema(
|
||||
@@ -445,55 +392,34 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
||||
for (const auto & opt : optional_parsers) {
|
||||
any_opt |= opt;
|
||||
}
|
||||
args_seq = args_seq + p.repeat(p.space() + any_opt, 0, (int) optional_parsers.size());
|
||||
args_seq = args_seq + p.repeat(p.space() + any_opt, 0, -1);
|
||||
}
|
||||
|
||||
if (!arguments.start.empty()) {
|
||||
args_seq = p.literal(arguments.start) + args_seq;
|
||||
}
|
||||
if (!arguments.end.empty()) {
|
||||
args_seq = args_seq + p.literal(arguments.end);
|
||||
}
|
||||
|
||||
// Build call_id parser based on position (if supported)
|
||||
common_peg_parser call_id_section = p.eps();
|
||||
bool have_call_id = false;
|
||||
if (call_id.pos == call_id_position::BETWEEN_FUNC_AND_ARGS && !call_id.prefix.empty() &&
|
||||
!call_id.suffix.empty()) {
|
||||
(!call_id.suffix.empty() || !arguments.start.empty())) {
|
||||
have_call_id = true;
|
||||
call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix)) + call_id.suffix);
|
||||
}
|
||||
|
||||
bool matched_atomic = false;
|
||||
common_peg_parser func_parser = p.eps();
|
||||
if (!function.name_suffix.empty()) {
|
||||
func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
|
||||
call_id_section + p.space() + args_seq;
|
||||
matched_atomic = true;
|
||||
} else if (have_call_id) {
|
||||
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
|
||||
call_id_section) + p.space() + args_seq;
|
||||
matched_atomic = true;
|
||||
} else if (!arguments.name_prefix.empty() && !required_parsers.empty()) {
|
||||
// Only peek for an arg tag when there are required args that must follow.
|
||||
// When all args are optional, the model may emit no arg tags at all (#20650).
|
||||
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
|
||||
call_id_section + p.space() + p.peek(p.literal(arguments.name_prefix))) + args_seq;
|
||||
matched_atomic = true;
|
||||
} else {
|
||||
func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
|
||||
call_id_section + p.space() + args_seq;
|
||||
}
|
||||
|
||||
if (!function.close.empty()) {
|
||||
func_parser = func_parser + p.space() + p.tool_close(p.literal(function.close));
|
||||
} else if (!format.per_call_end.empty()) {
|
||||
// When there's no func_close but there is a per_call_end marker, use peek() to ensure
|
||||
// we only emit tool_close when we can actually see the closing marker. This prevents
|
||||
// premature closing during partial parsing when we've seen e.g. "</" which could be
|
||||
// either "</tool_call>" (end) or "<arg_key>" prefix that failed to match.
|
||||
func_parser = func_parser + p.tool_close(p.peek(p.literal(format.per_call_end)));
|
||||
} else {
|
||||
func_parser =
|
||||
func_parser + p.tool_close(p.space()); // force this to process tool closing callbacks in mapper
|
||||
}
|
||||
if (!matched_atomic) {
|
||||
func_parser = p.atomic(func_parser);
|
||||
if (!call_id.suffix.empty()) {
|
||||
call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix)) + call_id.suffix);
|
||||
} else {
|
||||
call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(arguments.start)));
|
||||
}
|
||||
}
|
||||
|
||||
// Only peek for an arg tag when there are required args that must follow.
|
||||
// When all args are optional, the model may emit no arg tags at all (#20650).
|
||||
auto atomic_peek = (!arguments.name_prefix.empty() && !required_parsers.empty()) ?
|
||||
std::optional(p.peek(p.literal(arguments.name_prefix))) : std::nullopt;
|
||||
auto func_parser = build_func_parser(p, name, call_id_section, have_call_id, args_seq, atomic_peek);
|
||||
tool_choice |= p.rule("tool-" + name, func_parser);
|
||||
});
|
||||
|
||||
@@ -536,121 +462,4 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
||||
p.end();
|
||||
}
|
||||
|
||||
common_peg_parser analyze_tools::build_tool_parser_tag_gemma4_dict(parser_build_context & ctx) const {
|
||||
auto & p = ctx.p;
|
||||
const auto & inputs = ctx.inputs;
|
||||
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
common_peg_gemma4_builder g4(p);
|
||||
static const std::string QUOTE = "<|\"|>";
|
||||
|
||||
common_peg_parser tool_choice = p.choice();
|
||||
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & func = tool.at("function");
|
||||
std::string name = func.at("name");
|
||||
const auto & params = func.at("parameters");
|
||||
|
||||
if (!params.contains("properties") || !params.at("properties").is_object()) {
|
||||
auto func_parser = p.atomic(
|
||||
p.tool_open(p.literal(function.name_prefix) + p.tool_name(p.literal(name)) + p.literal("{")) +
|
||||
p.tool_args(p.eps()) +
|
||||
p.tool_close(p.literal("}")));
|
||||
tool_choice |= p.rule("tool-" + name, func_parser);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto & properties = params.at("properties");
|
||||
std::set<std::string> required;
|
||||
if (params.contains("required") && params.at("required").is_array()) {
|
||||
params.at("required").get_to(required);
|
||||
}
|
||||
|
||||
// Build per-argument parsers, sorted alphabetically (matching template's dictsort)
|
||||
struct arg_entry {
|
||||
std::string param_name;
|
||||
common_peg_parser parser;
|
||||
};
|
||||
std::vector<arg_entry> arg_entries;
|
||||
|
||||
for (const auto & [param_name, param_schema] : properties.items()) {
|
||||
std::string type = "object";
|
||||
auto type_v = param_schema.contains("type") ? param_schema.at("type") : json::object();
|
||||
if (type_v.is_string()) type_v.get_to(type);
|
||||
|
||||
common_peg_parser value_parser = p.eps();
|
||||
if (type == "string") {
|
||||
// String values are delimited by <|"|>...<|"|>
|
||||
value_parser =
|
||||
p.literal(QUOTE) +
|
||||
p.tool_arg_string_value(p.schema(p.until(QUOTE),
|
||||
"tool-" + name + "-arg-" + param_name + "-schema", param_schema, true)) +
|
||||
p.literal(QUOTE);
|
||||
} else if (type == "number" || type == "integer") {
|
||||
value_parser = p.tool_arg_value(g4.gemma4_number());
|
||||
} else if (type == "boolean") {
|
||||
value_parser = p.tool_arg_value(g4.gemma4_bool());
|
||||
} else if (type == "null") {
|
||||
value_parser = p.tool_arg_value(g4.gemma4_null());
|
||||
} else if (type == "object") {
|
||||
value_parser = p.tool_arg_value(g4.gemma4_dict());
|
||||
} else if (type == "array") {
|
||||
value_parser = p.tool_arg_value(g4.gemma4_array());
|
||||
} else {
|
||||
value_parser = p.tool_arg_value(g4.gemma4_value());
|
||||
}
|
||||
|
||||
auto arg = p.tool_arg(
|
||||
p.tool_arg_open(p.tool_arg_name(p.literal(param_name)) + p.literal(":")) +
|
||||
value_parser +
|
||||
p.tool_arg_close(p.eps()));
|
||||
|
||||
arg_entries.push_back({param_name, p.rule("tool-" + name + "-arg-" + param_name, arg)});
|
||||
}
|
||||
|
||||
// Sort alphabetically to match Jinja's dictsort
|
||||
std::sort(arg_entries.begin(), arg_entries.end(), [](const auto & a, const auto & b) {
|
||||
return a.param_name < b.param_name;
|
||||
});
|
||||
|
||||
// Build arg sequence: any arg, then zero-or-more comma-separated additional args
|
||||
common_peg_parser args_seq = p.eps();
|
||||
if (!arg_entries.empty()) {
|
||||
common_peg_parser any_arg = p.choice();
|
||||
for (auto & entry : arg_entries) {
|
||||
any_arg |= entry.parser;
|
||||
}
|
||||
args_seq = p.optional(
|
||||
any_arg + p.repeat(p.literal(",") + any_arg, 0, (int) arg_entries.size() - 1));
|
||||
}
|
||||
|
||||
// Full parser: call:name{args}
|
||||
auto func_parser = p.atomic(
|
||||
p.tool_open(p.literal(function.name_prefix) + p.tool_name(p.literal(name)) + p.literal("{")) +
|
||||
p.tool_args(args_seq) +
|
||||
p.tool_close(p.literal("}")));
|
||||
|
||||
tool_choice |= p.rule("tool-" + name, func_parser);
|
||||
});
|
||||
|
||||
// Wrap each call in <|tool_call>...</tool_call|>
|
||||
auto wrapped_call = p.literal(format.per_call_start) + tool_choice + p.literal(format.per_call_end);
|
||||
|
||||
common_peg_parser tool_calls = p.eps();
|
||||
if (inputs.parallel_tool_calls) {
|
||||
tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call));
|
||||
} else {
|
||||
tool_calls = p.trigger_rule("tool-call", wrapped_call);
|
||||
}
|
||||
|
||||
if (!force_tools) {
|
||||
tool_calls = p.optional(tool_calls);
|
||||
}
|
||||
|
||||
auto content_before_tools = p.until_one_of({ format.per_call_start, ctx.reasoning->start });
|
||||
return ctx.reasoning_parser +
|
||||
(force_tools ? p.eps() : p.optional(p.content(content_before_tools) + p.optional(ctx.reasoning_parser))) +
|
||||
tool_calls + p.end();
|
||||
}
|
||||
|
||||
} // namespace autoparser
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "chat-auto-parser.h"
|
||||
#include "peg-parser.h"
|
||||
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "common.h"
|
||||
#include "jinja/caps.h"
|
||||
#include "peg-parser.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
#include <chrono>
|
||||
#include <optional>
|
||||
@@ -144,7 +145,6 @@ enum class tool_format {
|
||||
JSON_NATIVE, // Pure JSON: {"name": "X", "arguments": {...}}
|
||||
TAG_WITH_JSON, // Tag-based with JSON args: <function=X>{...}</function>
|
||||
TAG_WITH_TAGGED, // Tag-based with tagged args: <param=key>value</param>
|
||||
TAG_WITH_GEMMA4_DICT, // Gemma4 custom dict: <|tool_call>call:name{key:<|"|>val<|"|>}<tool_call|>
|
||||
};
|
||||
|
||||
inline std::ostream & operator<<(std::ostream & os, const tool_format & format) {
|
||||
@@ -157,8 +157,6 @@ inline std::ostream & operator<<(std::ostream & os, const tool_format & format)
|
||||
return os << "TAG_WITH_JSON";
|
||||
case tool_format::TAG_WITH_TAGGED:
|
||||
return os << "TAG_WITH_TAGGED";
|
||||
case tool_format::TAG_WITH_GEMMA4_DICT:
|
||||
return os << "TAG_WITH_GEMMA4_DICT";
|
||||
default:
|
||||
return os << "UNKNOWN";
|
||||
}
|
||||
@@ -355,7 +353,13 @@ struct analyze_tools : analyze_base {
|
||||
common_peg_parser build_tool_parser_json_native(parser_build_context & ctx) const;
|
||||
common_peg_parser build_tool_parser_tag_json(parser_build_context & ctx) const;
|
||||
common_peg_parser build_tool_parser_tag_tagged(parser_build_context & ctx) const;
|
||||
common_peg_parser build_tool_parser_tag_gemma4_dict(parser_build_context & ctx) const;
|
||||
|
||||
// Shared helper: builds func_parser from open+call_id+args, handling atomic wrapping and close.
|
||||
// atomic_peek: if present, used as the peek expression in the third atomicity branch.
|
||||
common_peg_parser build_func_parser(common_chat_peg_builder & p, const std::string & name,
|
||||
const common_peg_parser & call_id_section, bool have_call_id,
|
||||
const common_peg_parser & args,
|
||||
std::optional<common_peg_parser> atomic_peek) const;
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
|
||||
@@ -25,6 +25,9 @@ static const std::string ARG_SECOND = "BB_ARG_SND_BB";
|
||||
static const std::string USER_MSG = "U_USER_MSG Hello END_U";
|
||||
static const std::string ASSISTANT_MSG = "A_ASST_MSG I can help END_A";
|
||||
static const std::string THINKING_CONTENT = "REASON_PART I am thinking END_R";
|
||||
static const std::string CALL_ID_001 = "call00001";
|
||||
static const std::string CALL_ID_002 = "call00002";
|
||||
static const std::string CALL_ID_999 = "call99999";
|
||||
|
||||
static std::vector<std::function<void(const common_chat_template & tmpl, autoparser &)>> workarounds(
|
||||
{ // Old reasoning Qwen templates - they don't really display reasoning content, but we still want to
|
||||
@@ -92,34 +95,6 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
|
||||
LOG_DBG(ANSI_ORANGE "[Patch: Functionary 3.1]\n" ANSI_RESET);
|
||||
}
|
||||
},
|
||||
// Gemma4 - custom dict format: <|tool_call>call:name{key:<|"|>val<|"|>}<tool_call|>
|
||||
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
|
||||
if (tmpl.src.find("'<|tool_call>call:'") != std::string::npos) {
|
||||
analysis.tools.format.mode = tool_format::TAG_WITH_GEMMA4_DICT;
|
||||
analysis.tools.format.per_call_start = "<|tool_call>";
|
||||
analysis.tools.format.per_call_end = "<tool_call|>";
|
||||
analysis.tools.format.section_start = "";
|
||||
analysis.tools.format.section_end = "";
|
||||
analysis.tools.function.name_prefix = "call:";
|
||||
analysis.tools.function.name_suffix = "";
|
||||
analysis.tools.arguments.start = "{";
|
||||
analysis.tools.arguments.end = "}";
|
||||
analysis.tools.arguments.name_prefix = "";
|
||||
analysis.tools.arguments.name_suffix = ":";
|
||||
analysis.tools.arguments.separator = ",";
|
||||
analysis.reasoning.mode = reasoning_mode::TAG_BASED;
|
||||
analysis.reasoning.start = "<|channel>thought";
|
||||
analysis.reasoning.end = "<channel|>";
|
||||
analysis.preserved_tokens.clear();
|
||||
analysis.preserved_tokens.push_back("<|tool_call>");
|
||||
analysis.preserved_tokens.push_back("<tool_call|>");
|
||||
analysis.preserved_tokens.push_back("<|tool_response>");
|
||||
analysis.preserved_tokens.push_back("<tool_response|>");
|
||||
analysis.preserved_tokens.push_back("<|\"|>");
|
||||
analysis.preserved_tokens.push_back("<|turn>");
|
||||
LOG_DBG(ANSI_ORANGE "[Patch: Gemma4]\n" ANSI_RESET);
|
||||
}
|
||||
},
|
||||
// DeepSeek-R1-Distill-Qwen
|
||||
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
|
||||
if (tmpl.src.find(
|
||||
@@ -131,6 +106,7 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
|
||||
analysis.tools.function.name_prefix = "<|tool▁sep|>";
|
||||
analysis.tools.format.per_call_end = "<|tool▁call▁end|>";
|
||||
analysis.tools.function.close = "```";
|
||||
LOG_DBG(ANSI_ORANGE "[Patch: DeepSeek-R1-Distill-Qwen]\n" ANSI_RESET);
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -158,7 +134,7 @@ static json user_msg = json{
|
||||
{ "content", USER_MSG }
|
||||
};
|
||||
|
||||
static json build_tool_call(const std::string & name, const json & args, const std::string & id = "call00001") {
|
||||
static json build_tool_call(const std::string & name, const json & args, const std::string & id = CALL_ID_001) {
|
||||
return json{
|
||||
{ "id", id },
|
||||
{ "type", "function" },
|
||||
@@ -166,17 +142,17 @@ static json build_tool_call(const std::string & name, const json & args, const s
|
||||
};
|
||||
}
|
||||
|
||||
static json first_tool_call_zero_args = build_tool_call(FUN_FIRST, json::object(), "call00001");
|
||||
static json first_tool_call_one_arg = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "XXXX" }}, "call00001");
|
||||
static json first_tool_call_one_arg_other_val = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "YYYY" }}, "call00001");
|
||||
static json first_tool_call_other_arg = build_tool_call(FUN_FIRST, {{ ARG_SECOND, "YYYY" }}, "call00001");
|
||||
static json first_tool_call_zero_args = build_tool_call(FUN_FIRST, json::object(), CALL_ID_001);
|
||||
static json first_tool_call_one_arg = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "XXXX" }}, CALL_ID_001);
|
||||
static json first_tool_call_one_arg_other_val = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "YYYY" }}, CALL_ID_001);
|
||||
static json first_tool_call_other_arg = build_tool_call(FUN_FIRST, {{ ARG_SECOND, "YYYY" }}, CALL_ID_001);
|
||||
|
||||
static json first_tool_call =
|
||||
build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call00001");
|
||||
build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, CALL_ID_001);
|
||||
static json second_tool_call =
|
||||
build_tool_call(FUN_SECOND, json{ { ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call00002");
|
||||
build_tool_call(FUN_SECOND, json{ { ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, CALL_ID_002);
|
||||
static json first_tool_call_alt_id =
|
||||
build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call99999");
|
||||
build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, CALL_ID_999);
|
||||
|
||||
template <typename T>
|
||||
static std::string mode_to_str(T mode) {
|
||||
@@ -215,6 +191,11 @@ void autoparser::analyze_template(const common_chat_template & tmpl) {
|
||||
LOG_DBG("func_name_prefix: '%s'\n", tools.function.name_prefix.c_str());
|
||||
LOG_DBG("func_name_suffix: '%s'\n", tools.function.name_suffix.c_str());
|
||||
LOG_DBG("func_close: '%s'\n", tools.function.close.c_str());
|
||||
LOG_DBG("call_id_prefix: '%s'\n", tools.call_id.prefix.c_str());
|
||||
LOG_DBG("call_id_suffix: '%s'\n", tools.call_id.suffix.c_str());
|
||||
LOG_DBG("call_id_pos: '%s'\n", mode_to_str(tools.call_id.pos).c_str());
|
||||
LOG_DBG("args_start: '%s'\n", tools.arguments.start.c_str());
|
||||
LOG_DBG("args_end: '%s'\n", tools.arguments.end.c_str());
|
||||
LOG_DBG("arg_name_prefix: '%s'\n", tools.arguments.name_prefix.c_str());
|
||||
LOG_DBG("arg_name_suffix: '%s'\n", tools.arguments.name_suffix.c_str());
|
||||
LOG_DBG("arg_value_prefix: '%s'\n", tools.arguments.value_prefix.c_str());
|
||||
@@ -583,12 +564,15 @@ analyze_tools::analyze_tools(const common_chat_template & tmpl,
|
||||
if (caps.supports_parallel_tool_calls) {
|
||||
check_per_call_markers();
|
||||
}
|
||||
LOG_DBG(ANSI_ORANGE "Phase 3a: Function call analysis\n" ANSI_RESET);
|
||||
extract_function_markers();
|
||||
LOG_DBG(ANSI_ORANGE "Phase 3b: Argument analysis\n" ANSI_RESET);
|
||||
if (format.mode == tool_format::TAG_WITH_TAGGED) {
|
||||
analyze_arguments();
|
||||
}
|
||||
extract_argument_separator();
|
||||
extract_args_markers();
|
||||
LOG_DBG(ANSI_ORANGE "Phase 3c: Call id analysis\n" ANSI_RESET);
|
||||
extract_call_id_markers();
|
||||
}
|
||||
}
|
||||
@@ -979,8 +963,6 @@ void analyze_tools::extract_function_markers() {
|
||||
}
|
||||
|
||||
void analyze_tools::analyze_arguments() {
|
||||
LOG_DBG(ANSI_ORANGE "Phase 4: Argument analysis\n" ANSI_RESET);
|
||||
|
||||
extract_argument_name_markers();
|
||||
extract_argument_value_markers();
|
||||
}
|
||||
@@ -1189,7 +1171,7 @@ void analyze_tools::extract_args_markers() {
|
||||
|
||||
const auto & diff = comparison->diff;
|
||||
|
||||
if (format.mode != tool_format::JSON_NATIVE) {
|
||||
if (format.mode == tool_format::JSON_NATIVE) {
|
||||
std::string prefix_marker = !format.section_start.empty() ? format.section_start : format.per_call_start;
|
||||
std::string suffix_marker = !format.section_end.empty() ? format.section_end : format.per_call_end;
|
||||
// these might happen earlier in the tools section as an example or somewhere else, so we need to find the closest ones
|
||||
@@ -1211,6 +1193,10 @@ void analyze_tools::extract_args_markers() {
|
||||
if (find_fun != std::string::npos) {
|
||||
args_start = args_start.substr(find_fun + FUN_FIRST.size(), args_start.size() - find_fun - FUN_FIRST.size());
|
||||
}
|
||||
size_t find_call_id = args_start.find(CALL_ID_001);
|
||||
if (find_call_id != std::string::npos) {
|
||||
args_start = args_start.substr(find_call_id + CALL_ID_001.size(), args_start.size() - find_call_id - CALL_ID_001.size());
|
||||
}
|
||||
arguments.start = args_start;
|
||||
arguments.end = args_end;
|
||||
}
|
||||
@@ -1250,8 +1236,8 @@ void analyze_tools::extract_call_id_markers() {
|
||||
return;
|
||||
}
|
||||
|
||||
std::string id_value_1 = "call00001";
|
||||
std::string id_value_2 = "call99999";
|
||||
std::string id_value_1 = CALL_ID_001;
|
||||
std::string id_value_2 = CALL_ID_999;
|
||||
|
||||
size_t common_id_prefix_len = 0;
|
||||
for (size_t i = 0; i < std::min(id_value_1.length(), id_value_2.length()); i++) {
|
||||
@@ -1350,6 +1336,14 @@ void analyze_tools::extract_call_id_markers() {
|
||||
call_id.suffix = find_first_marker(before_func);
|
||||
}
|
||||
|
||||
if (call_id.prefix == arguments.end) {
|
||||
call_id.prefix = "";
|
||||
}
|
||||
|
||||
if (call_id.suffix == arguments.start) {
|
||||
call_id.suffix = "";
|
||||
}
|
||||
|
||||
// When call_id is detected, per_call_end may have been incorrectly set to include
|
||||
// the call_id_suffix and sample args. Clear it if it starts with call_id_suffix.
|
||||
if (call_id.pos != call_id_position::NONE && !call_id.suffix.empty() &&
|
||||
|
||||
@@ -75,84 +75,6 @@ static std::string escape_json_string_inner(const std::string & s) {
|
||||
return escaped;
|
||||
}
|
||||
|
||||
static const std::string GEMMA4_QUOTE = "<|\"|>";
|
||||
|
||||
static std::string normalize_gemma4_to_json(const std::string & input) {
|
||||
std::string result;
|
||||
result.reserve(input.size() * 2);
|
||||
|
||||
enum Ctx { DICT, ARRAY };
|
||||
std::vector<Ctx> ctx;
|
||||
|
||||
auto is_ws = [](char c) { return c == ' ' || c == '\t' || c == '\n' || c == '\r'; };
|
||||
auto skip_ws = [&](size_t & pos) {
|
||||
while (pos < input.size() && is_ws(input[pos])) {
|
||||
result += input[pos++];
|
||||
}
|
||||
};
|
||||
|
||||
auto quote_unquoted_key = [&](size_t & pos) {
|
||||
if (pos < input.size() && input[pos] != '"' && input[pos] != '}') {
|
||||
result += '"';
|
||||
while (pos < input.size() && input[pos] != ':' && !is_ws(input[pos])) {
|
||||
result += input[pos++];
|
||||
}
|
||||
result += '"';
|
||||
skip_ws(pos);
|
||||
}
|
||||
};
|
||||
|
||||
size_t i = 0;
|
||||
while (i < input.size()) {
|
||||
if (i + GEMMA4_QUOTE.size() <= input.size() &&
|
||||
input.compare(i, GEMMA4_QUOTE.size(), GEMMA4_QUOTE) == 0) {
|
||||
result += '"';
|
||||
i += GEMMA4_QUOTE.size();
|
||||
continue;
|
||||
}
|
||||
|
||||
char c = input[i];
|
||||
|
||||
if (c == '{') {
|
||||
result += c;
|
||||
ctx.push_back(DICT);
|
||||
++i;
|
||||
skip_ws(i);
|
||||
quote_unquoted_key(i);
|
||||
continue;
|
||||
}
|
||||
if (c == '}') {
|
||||
result += c;
|
||||
if (!ctx.empty()) ctx.pop_back();
|
||||
++i;
|
||||
continue;
|
||||
}
|
||||
if (c == '[') {
|
||||
result += c;
|
||||
ctx.push_back(ARRAY);
|
||||
++i;
|
||||
continue;
|
||||
}
|
||||
if (c == ']') {
|
||||
result += c;
|
||||
if (!ctx.empty()) ctx.pop_back();
|
||||
++i;
|
||||
continue;
|
||||
}
|
||||
if (c == ',' && !ctx.empty() && ctx.back() == DICT) {
|
||||
result += c;
|
||||
++i;
|
||||
skip_ws(i);
|
||||
quote_unquoted_key(i);
|
||||
continue;
|
||||
}
|
||||
|
||||
result += c;
|
||||
++i;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Convert Python-style single-quoted strings to JSON double-quoted strings
|
||||
// Only converts outer string delimiters, properly handling escape sequences:
|
||||
// - {'key': 'value'} -> {"key": "value"}
|
||||
@@ -296,10 +218,6 @@ std::string common_chat_peg_mapper::normalize_container_value(const std::string
|
||||
return normalize_quotes_to_json(input);
|
||||
}
|
||||
|
||||
std::string common_chat_peg_gemma4_mapper::normalize_container_value(const std::string & input) {
|
||||
return normalize_quotes_to_json(normalize_gemma4_to_json(input));
|
||||
}
|
||||
|
||||
void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena,
|
||||
const common_peg_parse_result & parse_result_arg) {
|
||||
arena.visit(parse_result_arg, [this](const common_peg_ast_node & node) { map(node); });
|
||||
@@ -947,3 +865,143 @@ common_peg_parser common_chat_peg_builder::standard_json_tools(
|
||||
|
||||
return force_tool_calls ? section : optional(section);
|
||||
}
|
||||
|
||||
void common_chat_peg_gemma4_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) {
|
||||
for (const auto & node : result.nodes) {
|
||||
visit(arena, node);
|
||||
}
|
||||
}
|
||||
|
||||
static std::string gemma4_to_json(const common_peg_ast_arena & arena, common_peg_ast_id id) {
|
||||
const auto & node = arena.get(id);
|
||||
|
||||
if (node.text.empty()) {
|
||||
return "";
|
||||
}
|
||||
|
||||
if (node.rule == "gemma4-number" || node.rule == "gemma4-bool" || node.rule == "gemma4-null") {
|
||||
return std::string(node.text);
|
||||
}
|
||||
|
||||
if (node.rule == "gemma4-string-content") {
|
||||
return escape_json_string_inner(std::string(node.text));
|
||||
}
|
||||
|
||||
if (node.rule == "gemma4-string") {
|
||||
std::string result = "\"";
|
||||
if (!node.children.empty()) {
|
||||
result += gemma4_to_json(arena, node.children[0]);
|
||||
if (!node.is_partial) {
|
||||
result += "\"";
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
if (node.rule == "gemma4-array") {
|
||||
std::string result = "[";
|
||||
|
||||
bool add_comma = false;
|
||||
for (auto child_id : node.children) {
|
||||
if (add_comma) {
|
||||
result += ',';
|
||||
}
|
||||
add_comma = true;
|
||||
result += gemma4_to_json(arena, child_id);
|
||||
}
|
||||
|
||||
if (!node.is_partial) {
|
||||
result += ']';
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
if (node.rule == "gemma4-dict-key-name") {
|
||||
return std::string(node.text);
|
||||
}
|
||||
|
||||
if (node.rule == "gemma4-dict-key") {
|
||||
std::string result = "\"";
|
||||
if (!node.children.empty()) {
|
||||
result += escape_json_string_inner(gemma4_to_json(arena, node.children[0]));
|
||||
}
|
||||
if (!node.is_partial) {
|
||||
result += "\":";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
if (node.rule == "gemma4-dict-kv") {
|
||||
std::string result;
|
||||
for (auto child_id : node.children) {
|
||||
result += gemma4_to_json(arena, child_id);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
if (node.rule == "gemma4-dict") {
|
||||
std::string result = "{";
|
||||
|
||||
bool add_comma = false;
|
||||
for (auto child_id : node.children) {
|
||||
if (add_comma) {
|
||||
result += ',';
|
||||
}
|
||||
add_comma = true;
|
||||
result += gemma4_to_json(arena, child_id);
|
||||
}
|
||||
|
||||
if (!node.is_partial) {
|
||||
result += '}';
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
if (node.rule == "gemma4-value") {
|
||||
if (!node.children.empty()) {
|
||||
return gemma4_to_json(arena, node.children[0]);
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
return "";
|
||||
}
|
||||
|
||||
void common_chat_peg_gemma4_mapper::visit(const common_peg_ast_arena & arena, common_peg_ast_id id) {
|
||||
const auto & node = arena.get(id);
|
||||
|
||||
if (node.tag == "reasoning") {
|
||||
result.reasoning_content += std::string(node.text);
|
||||
return;
|
||||
}
|
||||
|
||||
if (node.tag == "content") {
|
||||
result.content += std::string(node.text);
|
||||
return;
|
||||
}
|
||||
|
||||
if (node.tag == "tool") {
|
||||
auto name_id = arena.find_by_tag(node, "tool-name");
|
||||
auto args_id = arena.find_by_tag(node, "tool-args");
|
||||
|
||||
if (name_id != COMMON_PEG_INVALID_AST_ID && args_id != COMMON_PEG_INVALID_AST_ID) {
|
||||
const auto & name_node = arena.get(name_id);
|
||||
const auto & args_node = arena.get(args_id);
|
||||
|
||||
if (!name_node.is_partial) {
|
||||
common_chat_tool_call call;
|
||||
call.name = std::string(name_node.text);
|
||||
if (!args_node.children.empty()) {
|
||||
call.arguments = gemma4_to_json(arena, args_node.children[0]);
|
||||
}
|
||||
result.tool_calls.push_back(call);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto child_id : node.children) {
|
||||
visit(arena, child_id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,8 +35,9 @@ class common_chat_peg_mapper {
|
||||
class common_chat_peg_gemma4_mapper : public common_chat_peg_mapper {
|
||||
public:
|
||||
common_chat_peg_gemma4_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
|
||||
protected:
|
||||
std::string normalize_container_value(const std::string & input) override;
|
||||
virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result);
|
||||
private:
|
||||
void visit(const common_peg_ast_arena & arena, common_peg_ast_id id);
|
||||
};
|
||||
|
||||
struct content_structure;
|
||||
|
||||
345
common/chat.cpp
345
common/chat.cpp
@@ -13,6 +13,8 @@
|
||||
#include "jinja/caps.h"
|
||||
#include "peg-parser.h"
|
||||
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <ctime>
|
||||
@@ -762,12 +764,12 @@ static void foreach_parameter(const json &
|
||||
}
|
||||
}
|
||||
|
||||
std::string common_chat_template_direct_apply(
|
||||
static std::string common_chat_template_direct_apply_impl(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::optional<json> & messages_override,
|
||||
const std::optional<json> & tools_override,
|
||||
const std::optional<json> & additional_context) {
|
||||
const std::optional<json> & messages_override = std::nullopt,
|
||||
const std::optional<json> & tools_override = std::nullopt,
|
||||
const std::optional<json> & additional_context = std::nullopt) {
|
||||
jinja::context ctx(tmpl.source());
|
||||
|
||||
nlohmann::ordered_json inp = nlohmann::ordered_json{
|
||||
@@ -814,6 +816,12 @@ std::string common_chat_template_direct_apply(
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string common_chat_template_direct_apply(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
return common_chat_template_direct_apply_impl(tmpl, inputs, std::nullopt, std::nullopt, std::nullopt);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
@@ -864,7 +872,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "[THINK]";
|
||||
data.thinking_end_tag = "[/THINK]";
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override = */ adjusted_messages);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
"[THINK]",
|
||||
@@ -947,7 +955,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
adjusted_messages.push_back(msg);
|
||||
}
|
||||
|
||||
auto prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
auto prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
|
||||
// Check if we need to replace the return token with end token during
|
||||
// inference and without generation prompt. For more details see:
|
||||
@@ -1069,12 +1077,137 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_gemma4(const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
|
||||
data.supports_thinking = true;
|
||||
|
||||
data.preserved_tokens = {
|
||||
"<|channel>",
|
||||
"<channel|>",
|
||||
"<|tool_call>",
|
||||
"<tool_call|>",
|
||||
"<|turn>",
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
|
||||
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto start = p.rule("start", p.prefix(inputs.generation_prompt, "<|channel>"));
|
||||
|
||||
if (extract_reasoning) {
|
||||
p.rule("thought", p.literal("<|channel>thought\n") + p.reasoning(p.until("<channel|>")) + p.literal("<channel|>"));
|
||||
} else {
|
||||
p.rule("thought", p.content(p.literal("<|channel>thought\n") + p.until("<channel|>") + p.literal("<channel|>")));
|
||||
}
|
||||
|
||||
auto thought = (p.peek(p.literal("<|channel>")) + p.ref("thought")) | p.negate(p.literal("<|channel>"));
|
||||
|
||||
if (has_response_format) {
|
||||
auto response_format = p.literal("```json") <<
|
||||
p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)) <<
|
||||
p.literal("```");
|
||||
return start + p.optional(thought) + response_format;
|
||||
}
|
||||
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
// Gemma4 tool calling syntax
|
||||
// Rules should match traversal logic in gemma4_to_json()
|
||||
p.rule("gemma4-string-content", p.until("<|\"|>"));
|
||||
p.rule("gemma4-string", p.literal("<|\"|>") + p.ref("gemma4-string-content") + p.literal("<|\"|>"));
|
||||
p.rule("gemma4-bool", p.json_bool());
|
||||
p.rule("gemma4-null", p.json_null());
|
||||
p.rule("gemma4-number", p.json_number());
|
||||
p.rule("gemma4-dict-key", p.rule("gemma4-dict-key-name", p.chars("[^:}]", 1, -1)) + p.literal(":"));
|
||||
p.rule("gemma4-dict-kv", p.ref("gemma4-dict-key") + p.space() + p.ref("gemma4-value"));
|
||||
p.rule("gemma4-dict", [&]() {
|
||||
auto ws = p.space();
|
||||
auto member = p.ref("gemma4-dict-kv");
|
||||
auto members = p.sequence({member, p.zero_or_more(p.sequence({p.literal(","), ws, member}))});
|
||||
return p.sequence({
|
||||
p.literal("{"), ws,
|
||||
p.choice({p.literal("}"), p.sequence({members, ws, p.literal("}")})})
|
||||
});
|
||||
});
|
||||
p.rule("gemma4-array", [&]() {
|
||||
auto ws = p.space();
|
||||
auto value = p.ref("gemma4-value");
|
||||
auto elements = p.sequence({value, p.zero_or_more(p.sequence({p.literal(","), ws, value}))});
|
||||
return p.sequence({
|
||||
p.literal("["), ws,
|
||||
p.choice({p.literal("]"), p.sequence({elements, ws, p.literal("]")})})
|
||||
});
|
||||
});
|
||||
p.rule("gemma4-value", [&]() {
|
||||
return p.choice({
|
||||
p.ref("gemma4-string"), p.ref("gemma4-dict"), p.ref("gemma4-array"),
|
||||
p.ref("gemma4-number"), p.ref("gemma4-bool"), p.ref("gemma4-null")
|
||||
});
|
||||
});
|
||||
|
||||
auto tool_choice = p.choice();
|
||||
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
// TODO @aldehir : need to extend json-schema-to-grammar to produce more than JSON rules
|
||||
// const auto & params = function.at("parameters");
|
||||
|
||||
tool_choice |= p.rule("tool-" + name, p.tool(p.sequence({
|
||||
p.tool_open(p.tool_name(p.literal(name)) + p.peek(p.literal("{"))),
|
||||
p.tool_args(p.ref("gemma4-dict")),
|
||||
})));
|
||||
});
|
||||
|
||||
auto tool_call = p.trigger_rule("tool-call", p.repeat(
|
||||
"<|tool_call>call:" + tool_choice + "<tool_call|>",
|
||||
/* min = */ inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0,
|
||||
/* max = */ inputs.parallel_tool_calls ? -1 : 1
|
||||
));
|
||||
|
||||
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<|tool_call>"})));
|
||||
auto message = p.rule("message", thought + content);
|
||||
return start + p.zero_or_more(message) + tool_call;
|
||||
}
|
||||
|
||||
auto content = p.rule("content", p.content(p.until("<|channel>")));
|
||||
auto message = p.rule("message", thought + content);
|
||||
return start + p.one_or_more(message);
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
auto schema = function.at("parameters");
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
data.grammar_triggers = {
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_call>" },
|
||||
};
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
// Functionary v3.2 - uses recipient-based format: >>>recipient\n{content}
|
||||
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
">>>all",
|
||||
@@ -1168,7 +1301,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
@@ -1291,7 +1424,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
@@ -1370,7 +1503,7 @@ static common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
@@ -1441,7 +1574,7 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
||||
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = false;
|
||||
data.preserved_tokens = {
|
||||
@@ -1548,46 +1681,146 @@ static void requires_non_null_content(json & messages) {
|
||||
}
|
||||
|
||||
// Gemma4 uses a custom tool_responses field instead of role:tool messages.
|
||||
// Convert consecutive role:tool messages into a single user message with tool_responses.
|
||||
//
|
||||
// This will transform a sequence of messages:
|
||||
// assistant(tool_call+) -> tool+ -> assistant(content)
|
||||
//
|
||||
// Into a single assistant message containing a tool_responses field:
|
||||
// assistant(content + tool_call + tool_responses)
|
||||
//
|
||||
// This is necessary for the Gemma4 chat template to properly format the prompt.
|
||||
// See https://ai.google.dev/gemma/docs/core/prompt-formatting-gemma4
|
||||
struct gemma4_model_turn_builder {
|
||||
json & messages;
|
||||
size_t pos;
|
||||
json tool_calls = json::array();
|
||||
json tool_responses = json::array();
|
||||
json content;
|
||||
json reasoning_content;
|
||||
|
||||
gemma4_model_turn_builder(json & msgs, size_t pos) : messages(msgs), pos(pos) {}
|
||||
|
||||
void collect() {
|
||||
// Collect the first assistant message
|
||||
auto & msg = messages[pos];
|
||||
if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
|
||||
// According to the prompt formatting guide, we need to preserve reasoning_content
|
||||
// between function calls. The current chat templates do not support this, but we will do it anyway.
|
||||
reasoning_content = msg.at("reasoning_content");
|
||||
}
|
||||
for (auto & tc : msg.at("tool_calls")) {
|
||||
tool_calls.push_back(tc);
|
||||
}
|
||||
pos++;
|
||||
|
||||
// Collect tool call results
|
||||
while (pos < messages.size() && messages[pos].value("role", "") == "tool") {
|
||||
collect_result(messages[pos]);
|
||||
pos++;
|
||||
}
|
||||
|
||||
// Check if the next assistant message is the final message
|
||||
if (pos < messages.size() && messages[pos].value("role", "") == "assistant") {
|
||||
auto & next = messages[pos];
|
||||
if (!has_tool_calls(next) && has_content(next)) {
|
||||
content = next.at("content");
|
||||
pos++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void collect_result(const json & curr) {
|
||||
json response;
|
||||
if (curr.contains("content")) {
|
||||
const auto & content = curr.at("content");
|
||||
if (content.is_string()) {
|
||||
// Try to parse the content as JSON; fall back to raw string
|
||||
try {
|
||||
response = json::parse(content.get<std::string>());
|
||||
} catch (...) {
|
||||
response = content;
|
||||
}
|
||||
} else {
|
||||
response = content;
|
||||
}
|
||||
}
|
||||
|
||||
std::string name;
|
||||
|
||||
// Match name with corresponding tool call
|
||||
size_t idx = tool_responses.size();
|
||||
if (idx < tool_calls.size()) {
|
||||
auto & tc = tool_calls[idx];
|
||||
if (tc.contains("function")) {
|
||||
name = tc.at("function").value("name", "");
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to the tool call id
|
||||
if (name.empty()) {
|
||||
name = curr.value("tool_call_id", "");
|
||||
}
|
||||
|
||||
tool_responses.push_back({{"name", name}, {"response", response}});
|
||||
}
|
||||
|
||||
json build() {
|
||||
collect();
|
||||
|
||||
json msg = {
|
||||
{"role", "assistant"},
|
||||
{"tool_calls", tool_calls},
|
||||
};
|
||||
if (!tool_responses.empty()) {
|
||||
msg["tool_responses"] = tool_responses;
|
||||
}
|
||||
if (!content.is_null()) {
|
||||
msg["content"] = content;
|
||||
}
|
||||
if (!reasoning_content.is_null()) {
|
||||
msg["reasoning_content"] = reasoning_content;
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
static bool has_content(const json & msg) {
|
||||
if (!msg.contains("content") || msg.at("content").is_null()) {
|
||||
return false;
|
||||
}
|
||||
const auto & content = msg.at("content");
|
||||
if (content.is_string() && !content.get<std::string>().empty()) {
|
||||
return true;
|
||||
}
|
||||
if (content.is_array() && !content.empty()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool has_tool_calls(const json & msg) {
|
||||
return msg.contains("tool_calls") && msg.at("tool_calls").is_array() && !msg.at("tool_calls").empty();
|
||||
}
|
||||
};
|
||||
|
||||
static void convert_tool_responses_gemma4(json & messages) {
|
||||
json result = json::array();
|
||||
size_t i = 0;
|
||||
|
||||
while (i < messages.size()) {
|
||||
if (messages[i].contains("role") && messages[i].at("role") == "tool") {
|
||||
json tool_responses = json::array();
|
||||
while (i < messages.size() &&
|
||||
messages[i].contains("role") &&
|
||||
messages[i].at("role") == "tool") {
|
||||
const auto & tool_msg = messages[i];
|
||||
std::string name;
|
||||
if (tool_msg.contains("tool_call_id") && tool_msg.at("tool_call_id").is_string()) {
|
||||
name = tool_msg.at("tool_call_id");
|
||||
} else if (tool_msg.contains("name") && tool_msg.at("name").is_string()) {
|
||||
name = tool_msg.at("name");
|
||||
}
|
||||
json response;
|
||||
if (tool_msg.contains("content")) {
|
||||
const auto & content = tool_msg.at("content");
|
||||
if (content.is_string()) {
|
||||
// Try to parse the content as JSON; fall back to raw string
|
||||
try {
|
||||
response = json::parse(content.get<std::string>());
|
||||
} catch (...) {
|
||||
response = content;
|
||||
}
|
||||
} else {
|
||||
response = content;
|
||||
}
|
||||
}
|
||||
tool_responses.push_back({{"name", name}, {"response", response}});
|
||||
i++;
|
||||
}
|
||||
result.push_back({{"role", "user"}, {"tool_responses", tool_responses}});
|
||||
} else {
|
||||
result.push_back(messages[i]);
|
||||
auto & msg = messages[i];
|
||||
|
||||
if (msg.value("role", "") != "assistant" || !msg.contains("tool_calls") ||
|
||||
!msg.at("tool_calls").is_array() || msg.at("tool_calls").empty()) {
|
||||
result.push_back(msg);
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
|
||||
gemma4_model_turn_builder builder(messages, i);
|
||||
result.push_back(builder.build());
|
||||
i = builder.pos;
|
||||
}
|
||||
|
||||
messages = result;
|
||||
}
|
||||
|
||||
@@ -1623,10 +1856,10 @@ static json common_chat_extra_context() {
|
||||
return ctx;
|
||||
}
|
||||
|
||||
static std::optional<common_chat_params> try_specialized_template(
|
||||
std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
const common_chat_template & tmpl,
|
||||
const std::string & src,
|
||||
const autoparser::generation_params & params) {
|
||||
autoparser::generation_params & params) {
|
||||
// Ministral/Mistral Large 3 - uses special reasoning structure fixes, can't use autoparser
|
||||
// Note: Mistral Small 3.2 uses [CALL_ID] which Ministral doesn't have, so we can distinguish them
|
||||
if (src.find("[SYSTEM_PROMPT]") != std::string::npos && src.find("[TOOL_CALLS]") != std::string::npos &&
|
||||
@@ -1679,6 +1912,12 @@ static std::optional<common_chat_params> try_specialized_template(
|
||||
return common_chat_params_init_gigachat_v3(tmpl, params);
|
||||
}
|
||||
|
||||
// Gemma4 format detection
|
||||
if (src.find("'<|tool_call>call:'") != std::string::npos) {
|
||||
workaround::convert_tool_responses_gemma4(params.messages);
|
||||
return common_chat_params_init_gemma4(tmpl, params);
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
@@ -1719,16 +1958,12 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
workaround::func_args_not_string(params.messages);
|
||||
}
|
||||
|
||||
if (src.find("'<|tool_call>call:'") != std::string::npos) {
|
||||
workaround::convert_tool_responses_gemma4(params.messages);
|
||||
}
|
||||
|
||||
params.add_generation_prompt = false;
|
||||
std::string no_gen_prompt = common_chat_template_direct_apply(tmpl, params);
|
||||
std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||
params.add_generation_prompt = true;
|
||||
std::string gen_prompt = common_chat_template_direct_apply(tmpl, params);
|
||||
std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
|
||||
params.generation_prompt = diff.right;
|
||||
params.generation_prompt = diff.right + diff.suffix;
|
||||
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
|
||||
@@ -1760,7 +1995,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
common_chat_params data;
|
||||
auto params_copy = params;
|
||||
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, params_copy);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.generation_prompt = params.generation_prompt;
|
||||
auto parser = build_chat_peg_parser([¶ms](common_chat_peg_builder &p) {
|
||||
@@ -1770,7 +2005,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
return data;
|
||||
}
|
||||
|
||||
if (auto result = try_specialized_template(tmpl, src, params)) {
|
||||
if (auto result = common_chat_try_specialized_template(tmpl, src, params)) {
|
||||
result->generation_prompt = params.generation_prompt;
|
||||
return *result;
|
||||
}
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "jinja/parser.h"
|
||||
#include "nlohmann/json_fwd.hpp"
|
||||
#include "peg-parser.h"
|
||||
#include "jinja/parser.h"
|
||||
#include "jinja/runtime.h"
|
||||
#include "jinja/caps.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
#include "nlohmann/json_fwd.hpp"
|
||||
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
@@ -19,8 +19,6 @@
|
||||
using chat_template_caps = jinja::caps;
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
struct common_chat_templates;
|
||||
|
||||
namespace autoparser {
|
||||
@@ -75,41 +73,9 @@ struct common_chat_template {
|
||||
const std::string & bos_token() const { return bos_tok; }
|
||||
const std::string & eos_token() const { return eos_tok; }
|
||||
|
||||
// TODO: this is ugly, refactor it somehow
|
||||
json add_system(const json & messages, const std::string & system_prompt) const {
|
||||
GGML_ASSERT(messages.is_array());
|
||||
auto msgs_copy = messages;
|
||||
if (!caps.supports_system_role) {
|
||||
if (msgs_copy.empty()) {
|
||||
msgs_copy.insert(msgs_copy.begin(), json{
|
||||
{"role", "user"},
|
||||
{"content", system_prompt}
|
||||
});
|
||||
} else {
|
||||
auto & first_msg = msgs_copy[0];
|
||||
if (!first_msg.contains("content")) {
|
||||
first_msg["content"] = "";
|
||||
}
|
||||
first_msg["content"] = system_prompt + "\n\n"
|
||||
+ first_msg["content"].get<std::string>();
|
||||
}
|
||||
} else {
|
||||
if (msgs_copy.empty() || msgs_copy[0].at("role") != "system") {
|
||||
msgs_copy.insert(msgs_copy.begin(), json{
|
||||
{"role", "system"},
|
||||
{"content", system_prompt}
|
||||
});
|
||||
} else if (msgs_copy[0].at("role") == "system") {
|
||||
msgs_copy[0]["content"] = system_prompt;
|
||||
}
|
||||
}
|
||||
return msgs_copy;
|
||||
}
|
||||
|
||||
chat_template_caps original_caps() const {
|
||||
return caps;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct common_chat_msg {
|
||||
@@ -257,8 +223,8 @@ common_chat_templates_ptr common_chat_templates_init(const struct llama_model *
|
||||
const std::string & bos_token_override = "",
|
||||
const std::string & eos_token_override = "");
|
||||
|
||||
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
|
||||
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
|
||||
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
|
||||
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
|
||||
|
||||
struct common_chat_params common_chat_templates_apply(const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs);
|
||||
@@ -275,9 +241,9 @@ std::string common_chat_format_example(const struct common_chat_templates *
|
||||
bool use_jinja,
|
||||
const std::map<std::string, std::string> & chat_template_kwargs);
|
||||
|
||||
const char * common_chat_format_name(common_chat_format format);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params);
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params);
|
||||
const char * common_chat_format_name(common_chat_format format);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params);
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params);
|
||||
|
||||
// used by arg and server
|
||||
const char * common_reasoning_format_name(common_reasoning_format format);
|
||||
@@ -303,7 +269,9 @@ std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_tem
|
||||
|
||||
std::string common_chat_template_direct_apply(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::optional<json> & messages_override = std::nullopt,
|
||||
const std::optional<json> & tools_override = std::nullopt,
|
||||
const std::optional<json> & additional_context = std::nullopt);
|
||||
const autoparser::generation_params & inputs);
|
||||
|
||||
std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
const common_chat_template & tmpl,
|
||||
const std::string & src,
|
||||
autoparser::generation_params & params);
|
||||
|
||||
@@ -579,8 +579,9 @@ struct common_params {
|
||||
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
|
||||
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
|
||||
bool cache_prompt = true; // whether to enable prompt caching
|
||||
int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
|
||||
int32_t checkpoint_every_nt = 8192; // make a checkpoint every n tokens during prefill
|
||||
bool clear_idle = true; // save and clear idle slots upon starting a new task
|
||||
int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
|
||||
int32_t checkpoint_every_nt = 8192; // make a checkpoint every n tokens during prefill
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
|
||||
std::string hostname = "127.0.0.1";
|
||||
|
||||
@@ -700,13 +700,13 @@ namespace console {
|
||||
std::vector<std::string> entries;
|
||||
size_t viewing_idx = SIZE_MAX;
|
||||
std::string backup_line; // current line before viewing history
|
||||
void add(const std::string & line) {
|
||||
void add(std::string_view line) {
|
||||
if (line.empty()) {
|
||||
return;
|
||||
}
|
||||
// avoid duplicates with the last entry
|
||||
if (entries.empty() || entries.back() != line) {
|
||||
entries.push_back(line);
|
||||
entries.emplace_back(line);
|
||||
}
|
||||
// also clear viewing state
|
||||
end_viewing();
|
||||
@@ -1031,11 +1031,12 @@ namespace console {
|
||||
|
||||
if (!end_of_stream && !line.empty()) {
|
||||
// remove the trailing newline for history storage
|
||||
std::string_view hline = line;
|
||||
if (!line.empty() && line.back() == '\n') {
|
||||
line.pop_back();
|
||||
hline.remove_suffix(1);
|
||||
}
|
||||
// TODO: maybe support multiline history entries?
|
||||
history.add(line);
|
||||
history.add(hline);
|
||||
}
|
||||
|
||||
fflush(out);
|
||||
|
||||
@@ -591,14 +591,25 @@ static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
|
||||
for (const auto & f : files) {
|
||||
if (gguf_filename_is_model(f.path) &&
|
||||
std::regex_search(f.path, pattern)) {
|
||||
auto split = get_gguf_split_info(f.path);
|
||||
if (split.count > 1 && split.index != 1) {
|
||||
continue;
|
||||
}
|
||||
return f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & f : files) {
|
||||
if (gguf_filename_is_model(f.path)) {
|
||||
return f;
|
||||
// fallback to first available model only if tag is empty
|
||||
if (tag.empty()) {
|
||||
for (const auto & f : files) {
|
||||
if (gguf_filename_is_model(f.path)) {
|
||||
auto split = get_gguf_split_info(f.path);
|
||||
if (split.count > 1 && split.index != 1) {
|
||||
continue;
|
||||
}
|
||||
return f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -615,6 +626,7 @@ static void list_available_gguf_files(const hf_cache::hf_files & files) {
|
||||
}
|
||||
|
||||
struct hf_plan {
|
||||
hf_cache::hf_file primary;
|
||||
hf_cache::hf_files model_files;
|
||||
hf_cache::hf_file mmproj;
|
||||
};
|
||||
@@ -660,6 +672,7 @@ static hf_plan get_hf_plan(const common_params_model & model,
|
||||
}
|
||||
}
|
||||
|
||||
plan.primary = primary;
|
||||
plan.model_files = get_split_files(all, primary);
|
||||
|
||||
if (opts.download_mmproj) {
|
||||
@@ -746,7 +759,7 @@ common_download_model_result common_download_model(const common_params_model
|
||||
for (const auto & f : hf.model_files) {
|
||||
hf_cache::finalize_file(f);
|
||||
}
|
||||
result.model_path = hf.model_files[0].final_path;
|
||||
result.model_path = hf.primary.final_path;
|
||||
|
||||
if (!hf.mmproj.path.empty()) {
|
||||
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
|
||||
|
||||
@@ -251,6 +251,23 @@ value binary_expression::execute_impl(context & ctx) {
|
||||
return res;
|
||||
}
|
||||
|
||||
// Python-style string repetition
|
||||
// TODO: support array/tuple repetition (e.g., [1, 2] * 3 → [1, 2, 1, 2, 1, 2])
|
||||
if (op.value == "*" &&
|
||||
((is_val<value_string>(left_val) && is_val<value_int>(right_val)) ||
|
||||
(is_val<value_int>(left_val) && is_val<value_string>(right_val)))) {
|
||||
const auto & str = is_val<value_string>(left_val) ? left_val->as_string() : right_val->as_string();
|
||||
const int64_t repeat = is_val<value_int>(right_val) ? right_val->as_int() : left_val->as_int();
|
||||
auto res = mk_val<value_string>();
|
||||
if (repeat <= 0) {
|
||||
return res;
|
||||
}
|
||||
for (int64_t i = 0; i < repeat; ++i) {
|
||||
res->val_str = res->val_str.append(str);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
// String membership
|
||||
if (is_val<value_string>(left_val) && is_val<value_string>(right_val)) {
|
||||
// case: "a" in "abc"
|
||||
@@ -306,6 +323,19 @@ value filter_expression::execute_impl(context & ctx) {
|
||||
filter_id = "strip"; // alias
|
||||
}
|
||||
JJ_DEBUG("Applying filter '%s' to %s", filter_id.c_str(), input->type().c_str());
|
||||
// TODO: Refactor filters so this coercion can be done automatically
|
||||
if (!input->is_undefined() && !is_val<value_string>(input) && (
|
||||
filter_id == "capitalize" ||
|
||||
filter_id == "lower" ||
|
||||
filter_id == "replace" ||
|
||||
filter_id == "strip" ||
|
||||
filter_id == "title" ||
|
||||
filter_id == "upper" ||
|
||||
filter_id == "wordcount"
|
||||
)) {
|
||||
JJ_DEBUG("Coercing %s to String for '%s' filter", input->type().c_str(), filter_id.c_str());
|
||||
input = mk_val<value_string>(input->as_string());
|
||||
}
|
||||
return try_builtin_func(ctx, filter_id, input)->invoke(func_args(ctx));
|
||||
|
||||
} else if (is_stmt<call_expression>(filter)) {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "runtime.h"
|
||||
#include "unicode.h"
|
||||
#include "value.h"
|
||||
|
||||
// for converting from JSON to jinja values
|
||||
@@ -154,6 +155,83 @@ static value test_compare_fn(const func_args & args) {
|
||||
return mk_val<value_bool>(value_compare(args.get_pos(0), args.get_pos(1), op));
|
||||
}
|
||||
|
||||
static void append_codepoint_as_ascii_json_escape(std::string & out, uint32_t codepoint) {
|
||||
auto append_u16 = [&out](uint32_t value) {
|
||||
char buf[8];
|
||||
snprintf(buf, sizeof(buf), "\\u%04x", static_cast<unsigned int>(value));
|
||||
out += buf;
|
||||
};
|
||||
|
||||
if (codepoint <= 0xFFFF) {
|
||||
append_u16(codepoint);
|
||||
return;
|
||||
}
|
||||
|
||||
codepoint -= 0x10000;
|
||||
append_u16(0xD800 + ((codepoint >> 10) & 0x3FF));
|
||||
append_u16(0xDC00 + (codepoint & 0x3FF));
|
||||
}
|
||||
|
||||
static std::string json_ensure_ascii_preserving_format(const std::string & json_str) {
|
||||
std::string output;
|
||||
output.reserve(json_str.size());
|
||||
|
||||
bool in_string = false;
|
||||
bool escaped = false;
|
||||
|
||||
for (size_t pos = 0; pos < json_str.size();) {
|
||||
const char ch = json_str[pos];
|
||||
if (!in_string) {
|
||||
output.push_back(ch);
|
||||
if (ch == '"') {
|
||||
in_string = true;
|
||||
}
|
||||
++pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (escaped) {
|
||||
output.push_back(ch);
|
||||
escaped = false;
|
||||
++pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ch == '\\') {
|
||||
output.push_back(ch);
|
||||
escaped = true;
|
||||
++pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ch == '"') {
|
||||
output.push_back(ch);
|
||||
in_string = false;
|
||||
++pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
const unsigned char uch = static_cast<unsigned char>(ch);
|
||||
if (uch < 0x80) {
|
||||
output.push_back(ch);
|
||||
++pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto parsed = common_parse_utf8_codepoint(json_str, pos);
|
||||
if (parsed.status != utf8_parse_result::SUCCESS) {
|
||||
output += "\\ufffd";
|
||||
++pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
append_codepoint_as_ascii_json_escape(output, parsed.codepoint);
|
||||
pos += parsed.bytes_consumed;
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
static value tojson(const func_args & args) {
|
||||
args.ensure_count(1, 5);
|
||||
value val_ascii = args.get_kwarg_or_pos("ensure_ascii", 1);
|
||||
@@ -169,16 +247,17 @@ static value tojson(const func_args & args) {
|
||||
if (is_val<value_int>(val_indent)) {
|
||||
indent = static_cast<int>(val_indent->as_int());
|
||||
}
|
||||
if (val_ascii->as_bool()) { // undefined == false
|
||||
throw not_implemented_exception("tojson ensure_ascii=true not implemented");
|
||||
}
|
||||
if (val_sort->as_bool()) { // undefined == false
|
||||
throw not_implemented_exception("tojson sort_keys=true not implemented");
|
||||
}
|
||||
const bool ensure_ascii = val_ascii->as_bool(); // undefined == false
|
||||
auto separators = (is_val<value_array>(val_separators) ? val_separators : mk_val<value_array>())->as_array();
|
||||
std::string item_sep = separators.size() > 0 ? separators[0]->as_string().str() : (indent < 0 ? ", " : ",");
|
||||
std::string key_sep = separators.size() > 1 ? separators[1]->as_string().str() : ": ";
|
||||
std::string json_str = value_to_json(args.get_pos(0), indent, item_sep, key_sep);
|
||||
if (ensure_ascii) {
|
||||
json_str = json_ensure_ascii_preserving_format(json_str);
|
||||
}
|
||||
return mk_val<value_string>(json_str);
|
||||
}
|
||||
|
||||
@@ -460,13 +539,18 @@ const func_builtins & value_int_t::get_builtins() const {
|
||||
int64_t val = args.get_pos(0)->as_int();
|
||||
return mk_val<value_int>(val < 0 ? -val : val);
|
||||
}},
|
||||
{"int", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_int>();
|
||||
return mk_val<value_int>(args.get_pos(0)->as_int());
|
||||
}},
|
||||
{"float", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_int>();
|
||||
double val = static_cast<double>(args.get_pos(0)->as_int());
|
||||
return mk_val<value_float>(val);
|
||||
}},
|
||||
{"tojson", tojson},
|
||||
{"safe", tojson},
|
||||
{"string", tojson},
|
||||
{"tojson", tojson},
|
||||
};
|
||||
return builtins;
|
||||
}
|
||||
@@ -485,8 +569,13 @@ const func_builtins & value_float_t::get_builtins() const {
|
||||
int64_t val = static_cast<int64_t>(args.get_pos(0)->as_float());
|
||||
return mk_val<value_int>(val);
|
||||
}},
|
||||
{"tojson", tojson},
|
||||
{"float", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_float>();
|
||||
return mk_val<value_float>(args.get_pos(0)->as_float());
|
||||
}},
|
||||
{"safe", tojson},
|
||||
{"string", tojson},
|
||||
{"tojson", tojson},
|
||||
};
|
||||
return builtins;
|
||||
}
|
||||
@@ -771,6 +860,11 @@ const func_builtins & value_string_t::get_builtins() const {
|
||||
|
||||
|
||||
const func_builtins & value_bool_t::get_builtins() const {
|
||||
static const func_handler tostring = [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_bool>();
|
||||
bool val = args.get_pos(0)->as_bool();
|
||||
return mk_val<value_string>(val ? "True" : "False");
|
||||
};
|
||||
static const func_builtins builtins = {
|
||||
{"default", default_value},
|
||||
{"int", [](const func_args & args) -> value {
|
||||
@@ -783,11 +877,8 @@ const func_builtins & value_bool_t::get_builtins() const {
|
||||
bool val = args.get_pos(0)->as_bool();
|
||||
return mk_val<value_float>(val ? 1.0 : 0.0);
|
||||
}},
|
||||
{"string", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_bool>();
|
||||
bool val = args.get_pos(0)->as_bool();
|
||||
return mk_val<value_string>(val ? "True" : "False");
|
||||
}},
|
||||
{"safe", tostring},
|
||||
{"string", tostring},
|
||||
{"tojson", tojson},
|
||||
};
|
||||
return builtins;
|
||||
@@ -1100,18 +1191,14 @@ const func_builtins & value_object_t::get_builtins() const {
|
||||
}
|
||||
|
||||
const func_builtins & value_none_t::get_builtins() const {
|
||||
static const func_handler tostring = [](const func_args &) -> value {
|
||||
return mk_val<value_string>("None");
|
||||
};
|
||||
static const func_builtins builtins = {
|
||||
{"default", default_value},
|
||||
{"tojson", tojson},
|
||||
{"string", [](const func_args &) -> value {
|
||||
return mk_val<value_string>("None");
|
||||
}},
|
||||
{"safe", [](const func_args &) -> value {
|
||||
return mk_val<value_string>("None");
|
||||
}},
|
||||
{"strip", [](const func_args &) -> value {
|
||||
return mk_val<value_string>("None");
|
||||
}},
|
||||
{"string", tostring},
|
||||
{"safe", tostring},
|
||||
{"items", empty_value_fn<value_array>},
|
||||
{"map", empty_value_fn<value_array>},
|
||||
{"reject", empty_value_fn<value_array>},
|
||||
|
||||
@@ -256,6 +256,38 @@ static std::pair<std::vector<common_peg_chars_parser::char_range>, bool> parse_c
|
||||
return {ranges, negated};
|
||||
}
|
||||
|
||||
common_peg_ast_id common_peg_ast_arena::find_by_tag(const common_peg_ast_node & parent, const std::string & tag, int max_depth) const {
|
||||
for (auto child_id : parent.children) {
|
||||
const auto & child = get(child_id);
|
||||
if (child.tag == tag) {
|
||||
return child_id;
|
||||
}
|
||||
if (max_depth > 1) {
|
||||
auto result = find_by_tag(child, tag, max_depth - 1);
|
||||
if (result != COMMON_PEG_INVALID_AST_ID) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
return COMMON_PEG_INVALID_AST_ID;
|
||||
}
|
||||
|
||||
common_peg_ast_id common_peg_ast_arena::find_by_rule(const common_peg_ast_node & parent, const std::string & rule, int max_depth) const {
|
||||
for (auto child_id : parent.children) {
|
||||
const auto & child = get(child_id);
|
||||
if (child.rule == rule) {
|
||||
return child_id;
|
||||
}
|
||||
if (max_depth > 1) {
|
||||
auto result = find_by_rule(child, rule, max_depth - 1);
|
||||
if (result != COMMON_PEG_INVALID_AST_ID) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
return COMMON_PEG_INVALID_AST_ID;
|
||||
}
|
||||
|
||||
void common_peg_ast_arena::visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const {
|
||||
if (id == COMMON_PEG_INVALID_AST_ID) {
|
||||
return;
|
||||
@@ -1561,7 +1593,23 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
if (!s.schema) {
|
||||
return true;
|
||||
}
|
||||
if (s.raw && s.schema->contains("type") && s.schema->at("type").is_string() && s.schema->at("type") == "string") {
|
||||
if (s.raw && s.schema->contains("type")) {
|
||||
const auto & type_val = s.schema->at("type");
|
||||
if (type_val.is_string() && type_val == "string") {
|
||||
return true;
|
||||
}
|
||||
// Handle nullable types like ["string", "null"] - delegate when the
|
||||
// non-null type is string, since the tagged format uses raw text
|
||||
if (type_val.is_array()) {
|
||||
for (const auto & t : type_val) {
|
||||
if (t.is_string() && t.get<std::string>() != "null") {
|
||||
return t.get<std::string>() == "string";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Delegate for enum schemas in raw mode - enum values are literal strings
|
||||
if (s.raw && !s.schema->contains("type") && s.schema->contains("enum")) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
||||
@@ -106,6 +106,9 @@ class common_peg_ast_arena {
|
||||
|
||||
const common_peg_ast_node & get(common_peg_ast_id id) const { return nodes_.at(id); }
|
||||
|
||||
common_peg_ast_id find_by_tag(const common_peg_ast_node & parent, const std::string & tag, int max_depth = 3) const;
|
||||
common_peg_ast_id find_by_rule(const common_peg_ast_node & parent, const std::string & tag, int max_depth = 3) const;
|
||||
|
||||
size_t size() const { return nodes_.size(); }
|
||||
|
||||
void clear() { nodes_.clear(); }
|
||||
|
||||
@@ -1229,15 +1229,15 @@ class TextModel(ModelBase):
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
|
||||
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
|
||||
assert max(tokenizer.vocab.values()) < vocab_size
|
||||
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) # ty: ignore[unresolved-attribute]
|
||||
assert max(tokenizer.vocab.values()) < vocab_size # ty: ignore[unresolved-attribute]
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} # ty: ignore[unresolved-attribute]
|
||||
added_vocab = tokenizer.get_added_vocab() # ty: ignore[unresolved-attribute]
|
||||
|
||||
added_tokens_decoder = tokenizer.added_tokens_decoder
|
||||
added_tokens_decoder = tokenizer.added_tokens_decoder # ty: ignore[unresolved-attribute]
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
@@ -1250,7 +1250,7 @@ class TextModel(ModelBase):
|
||||
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
|
||||
if not added_tokens_decoder[i].normalized:
|
||||
previous_token = token
|
||||
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
|
||||
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False)) # ty: ignore[unresolved-attribute, invalid-assignment]
|
||||
if previous_token != token:
|
||||
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
|
||||
|
||||
@@ -1583,13 +1583,13 @@ class TextModel(ModelBase):
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
|
||||
vocab_size = hparams["vocab_size"]
|
||||
assert max(tokenizer.get_vocab().values()) < vocab_size
|
||||
assert max(tokenizer.get_vocab().values()) < vocab_size # ty: ignore[unresolved-attribute]
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
merges = []
|
||||
vocab = {}
|
||||
mergeable_ranks = tokenizer.mergeable_ranks
|
||||
mergeable_ranks = tokenizer.mergeable_ranks # ty: ignore[unresolved-attribute]
|
||||
for token, rank in mergeable_ranks.items():
|
||||
vocab[QwenModel.token_bytes_to_string(token)] = rank
|
||||
if len(token) == 1:
|
||||
@@ -1599,7 +1599,7 @@ class TextModel(ModelBase):
|
||||
merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
|
||||
|
||||
# for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
|
||||
added_vocab = tokenizer.special_tokens
|
||||
added_vocab = tokenizer.special_tokens # ty: ignore[unresolved-attribute]
|
||||
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}
|
||||
|
||||
for i in range(vocab_size):
|
||||
@@ -1622,10 +1622,10 @@ class TextModel(ModelBase):
|
||||
special_vocab.merges = merges
|
||||
# only add special tokens when they were not already loaded from config.json
|
||||
if len(special_vocab.special_token_ids) == 0:
|
||||
special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"])
|
||||
special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"])
|
||||
special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
# this one is usually not in config.json anyway
|
||||
special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
|
||||
special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_sentencepiece(self, add_to_gguf=True):
|
||||
@@ -1877,10 +1877,10 @@ class TextModel(ModelBase):
|
||||
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_glm(self):
|
||||
@@ -1894,10 +1894,10 @@ class TextModel(ModelBase):
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
# Special tokens
|
||||
# Note: Using <|endoftext|> (151329) for eot causes endless generation
|
||||
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
|
||||
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
|
||||
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # ty: ignore[unresolved-attribute] # 151331
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # ty: ignore[unresolved-attribute] # 151336
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute] # 151329
|
||||
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # ty: ignore[unresolved-attribute] # 151338
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_interns1(self):
|
||||
@@ -1906,16 +1906,16 @@ class TextModel(ModelBase):
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
|
||||
vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
|
||||
vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab()) # ty: ignore[unresolved-attribute]
|
||||
vocab_size = self.hparams.get("vocab_size", len(vocab))
|
||||
assert max(vocab.values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
added_vocab = tokenizer.get_added_vocab() # ty: ignore[unresolved-attribute]
|
||||
|
||||
added_tokens_decoder = tokenizer.added_tokens_decoder
|
||||
added_tokens_decoder = tokenizer.added_tokens_decoder # ty: ignore[unresolved-attribute]
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
@@ -1928,7 +1928,7 @@ class TextModel(ModelBase):
|
||||
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
|
||||
if not added_tokens_decoder[i].normalized:
|
||||
previous_token = token
|
||||
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
|
||||
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False)) # ty: ignore[unresolved-attribute, invalid-assignment]
|
||||
if previous_token != token:
|
||||
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
|
||||
|
||||
@@ -2219,10 +2219,10 @@ class MmprojModel(ModelBase):
|
||||
self.image_size = self.find_vparam(["image_size"])
|
||||
self.gguf_writer.add_vision_image_size(self.image_size)
|
||||
self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"]))
|
||||
self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size", "vt_hidden_size"]))
|
||||
self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size", "width", "vt_hidden_size"]))
|
||||
self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size", "vt_intermediate_size"]))
|
||||
self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
|
||||
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads", "vt_num_attention_heads"]))
|
||||
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads", "heads", "vt_num_attention_heads"]))
|
||||
|
||||
# preprocessor config
|
||||
image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
|
||||
@@ -2516,15 +2516,15 @@ class XverseModel(TextModel):
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model)
|
||||
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
|
||||
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab)) # ty: ignore[unresolved-attribute]
|
||||
# Since we are checking the maximum index, we need to ensure it's strictly less than vocab_size,
|
||||
# because vocab_size is the count of items, and indexes start at 0.
|
||||
max_vocab_index = max(tokenizer.get_vocab().values())
|
||||
max_vocab_index = max(tokenizer.get_vocab().values()) # ty: ignore[unresolved-attribute]
|
||||
if max_vocab_index >= vocab_size:
|
||||
raise ValueError("Vocabulary size exceeds expected maximum size.")
|
||||
|
||||
reverse_vocab: dict[int, str] = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
reverse_vocab: dict[int, str] = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} # ty: ignore[unresolved-attribute]
|
||||
added_vocab = tokenizer.get_added_vocab() # ty: ignore[unresolved-attribute]
|
||||
|
||||
for token_id in range(vocab_size):
|
||||
token_text = reverse_vocab[token_id].encode('utf-8')
|
||||
@@ -2535,7 +2535,7 @@ class XverseModel(TextModel):
|
||||
elif re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
|
||||
toktype = gguf.TokenType.BYTE # special
|
||||
elif reverse_vocab[token_id] in added_vocab:
|
||||
if tokenizer.added_tokens_decoder[token_id].special:
|
||||
if tokenizer.added_tokens_decoder[token_id].special: # ty: ignore[unresolved-attribute]
|
||||
toktype = gguf.TokenType.CONTROL
|
||||
else:
|
||||
toktype = gguf.TokenType.USER_DEFINED
|
||||
@@ -3752,7 +3752,7 @@ class QwenModel(TextModel):
|
||||
|
||||
@staticmethod
|
||||
def token_bytes_to_string(b):
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode # ty: ignore[unresolved-import]
|
||||
byte_encoder = bytes_to_unicode()
|
||||
return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
|
||||
|
||||
@@ -3777,7 +3777,14 @@ class QwenModel(TextModel):
|
||||
self._set_vocab_qwen()
|
||||
|
||||
|
||||
@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration", "KORMoForCausalLM", "AudioFlamingo3ForConditionalGeneration")
|
||||
@ModelBase.register(
|
||||
"Qwen2Model",
|
||||
"Qwen2ForCausalLM",
|
||||
"Qwen2AudioForConditionalGeneration",
|
||||
"KORMoForCausalLM",
|
||||
"AudioFlamingo3ForConditionalGeneration",
|
||||
"DotsOCRForCausalLM",
|
||||
)
|
||||
class Qwen2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN2
|
||||
|
||||
@@ -3798,7 +3805,8 @@ class Qwen2Model(TextModel):
|
||||
name = name.replace("language_model.", "") # for InternVL
|
||||
if name.startswith("mlp") or name.startswith("multi_modal_projector") \
|
||||
or name.startswith("vision_model") or name.startswith("audio_tower") \
|
||||
or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
|
||||
or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") \
|
||||
or name.startswith("vision_tower."):
|
||||
# skip vision and audio tensors
|
||||
return
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
@@ -3815,14 +3823,14 @@ class DreamModel(TextModel):
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
|
||||
|
||||
vocab_dict = tokenizer.get_vocab()
|
||||
vocab_dict = tokenizer.get_vocab() # ty: ignore[unresolved-attribute]
|
||||
vocab_size = self.hparams.get("vocab_size", len(vocab_dict))
|
||||
assert max(vocab_dict.values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab_dict.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
added_vocab = tokenizer.get_added_vocab() # ty: ignore[unresolved-attribute]
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
@@ -3880,14 +3888,14 @@ class LLaDAModel(TextModel):
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
|
||||
|
||||
vocab_dict = tokenizer.get_vocab()
|
||||
vocab_dict = tokenizer.get_vocab() # ty: ignore[unresolved-attribute]
|
||||
vocab_size = self.hparams.get("vocab_size", len(vocab_dict))
|
||||
assert max(vocab_dict.values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab_dict.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
added_vocab = tokenizer.get_added_vocab() # ty: ignore[unresolved-attribute]
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
@@ -4665,9 +4673,9 @@ class Qwen3Model(Qwen2Model):
|
||||
|
||||
self.is_rerank = True
|
||||
self.is_tied_embeddings = self.hparams.get("tie_word_embeddings", False)
|
||||
self.token_false_id = tokenizer.convert_tokens_to_ids("no")
|
||||
self.token_true_id = tokenizer.convert_tokens_to_ids("yes")
|
||||
self.sep_token_id = tokenizer.convert_tokens_to_ids("|")
|
||||
self.token_false_id = tokenizer.convert_tokens_to_ids("no") # ty: ignore[unresolved-attribute, invalid-assignment]
|
||||
self.token_true_id = tokenizer.convert_tokens_to_ids("yes") # ty: ignore[unresolved-attribute, invalid-assignment]
|
||||
self.sep_token_id = tokenizer.convert_tokens_to_ids("|") # ty: ignore[unresolved-attribute]
|
||||
|
||||
assert self.token_false_id is not None and self.token_true_id is not None
|
||||
|
||||
@@ -4949,6 +4957,73 @@ class Glm4VVisionModel(Qwen3VLVisionModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("StepVLForConditionalGeneration")
|
||||
class Step3VLVisionModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.hparams_vision is not None
|
||||
|
||||
if not self.hparams_vision.get("intermediate_size"):
|
||||
hidden_size = self.hparams_vision.get("hidden_size") or self.hparams_vision.get("width") or 0
|
||||
assert hidden_size > 0
|
||||
mlp_ratio = float(self.hparams_vision.get("mlp_ratio", 8960 / 1536))
|
||||
self.hparams_vision["intermediate_size"] = int(round(hidden_size * mlp_ratio))
|
||||
|
||||
self.preprocessor_config.setdefault("image_mean", list(_MISTRAL_COMMON_DATASET_MEAN))
|
||||
self.preprocessor_config.setdefault("image_std", list(_MISTRAL_COMMON_DATASET_STD))
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
assert self.hparams_vision is not None
|
||||
|
||||
projector_stride = int(self.global_config.get("understand_projector_stride", -1))
|
||||
hidden_size = int(self.hparams_vision.get("hidden_size", self.hparams_vision.get("width", -1)))
|
||||
num_layers = int(self.hparams_vision.get("num_hidden_layers", self.hparams_vision.get("layers", -1)))
|
||||
assert (projector_stride, int(self.hparams_vision.get("image_size", -1)), hidden_size, num_layers) == (2, 728, 1536, 47), (
|
||||
"current Step3-VL conversion path is only validated for Step3-VL-10B"
|
||||
)
|
||||
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.STEP3VL)
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(float(self.hparams_vision.get("layer_norm_eps", 1e-5)))
|
||||
self.gguf_writer.add_vision_projector_scale_factor(projector_stride ** 2)
|
||||
# 3024 max resize comes from step3-vl-10b processing_step3.py.
|
||||
self.gguf_writer.add_vision_preproc_image_size(3024)
|
||||
|
||||
def tensor_force_quant(self, name, new_name, bid, n_dims):
|
||||
if ".position_embd." in new_name:
|
||||
return gguf.GGMLQuantizationType.F32
|
||||
return super().tensor_force_quant(name, new_name, bid, n_dims)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if name.startswith("model.") or name.startswith("lm_head."):
|
||||
return
|
||||
|
||||
if name.startswith("vision_model.vit_downsampler"):
|
||||
match = re.match(r"vision_model\.vit_downsampler(\d+)\.(weight|bias)", name)
|
||||
if match is None:
|
||||
raise ValueError(f"Unexpected Step3-VL projector tensor {name!r}")
|
||||
|
||||
proj_id = int(match.group(1)) - 1
|
||||
suffix = f".{match.group(2)}"
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, proj_id, suffix=suffix), data_torch)
|
||||
return
|
||||
|
||||
if name == "vit_large_projector.weight":
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ_FC), data_torch)
|
||||
return
|
||||
|
||||
if name.startswith("vision_model."):
|
||||
if name == "vision_model.positional_embedding":
|
||||
name += ".weight"
|
||||
elif name.endswith(".gamma") and ".ls_" in name:
|
||||
name = name.removesuffix(".gamma") + ".weight"
|
||||
|
||||
name = name.replace("attn.in_proj_weight", "attn.in_proj.weight")
|
||||
name = name.replace("attn.in_proj_bias", "attn.in_proj.bias")
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3VLForConditionalGeneration")
|
||||
class Qwen3VLTextModel(Qwen3Model):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3VL
|
||||
@@ -4969,6 +5044,16 @@ class Qwen3VLTextModel(Qwen3Model):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("StepVLForConditionalGeneration")
|
||||
class Step3VLTextModel(Qwen3Model):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if name.startswith("vision_model.") or name.startswith("model.vision_model.") or name.startswith("vit_large_projector."):
|
||||
return
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3VLMoeForConditionalGeneration")
|
||||
class Qwen3VLMoeTextModel(Qwen3MoeModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3VLMOE
|
||||
@@ -5859,7 +5944,7 @@ class KimiLinearModel(TextModel):
|
||||
# Build merges list using the approach similar to HunYuanMoE
|
||||
merges = []
|
||||
vocab = {}
|
||||
mergeable_ranks = tokenizer.model._mergeable_ranks
|
||||
mergeable_ranks = tokenizer.model._mergeable_ranks # ty: ignore[unresolved-attribute]
|
||||
for token, rank in mergeable_ranks.items():
|
||||
vocab[QwenModel.token_bytes_to_string(token)] = rank
|
||||
if len(token) == 1:
|
||||
@@ -5869,7 +5954,7 @@ class KimiLinearModel(TextModel):
|
||||
merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
|
||||
# Build token list
|
||||
vocab_size = self.hparams["vocab_size"]
|
||||
special_tokens = tokenizer.special_tokens
|
||||
special_tokens = tokenizer.special_tokens # ty: ignore[unresolved-attribute]
|
||||
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
@@ -5895,7 +5980,7 @@ class KimiLinearModel(TextModel):
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
# override eos id in config.json with tiktoken eos id
|
||||
self.gguf_writer.add_eos_token_id(tokenizer.eos_id)
|
||||
self.gguf_writer.add_eos_token_id(tokenizer.eos_id) # ty: ignore[unresolved-attribute]
|
||||
else:
|
||||
raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!")
|
||||
|
||||
@@ -6389,11 +6474,11 @@ class BertModel(TextModel):
|
||||
with open(tokenizer_config_path, "r", encoding="utf-8") as fp:
|
||||
tokenizer_config_json = json.load(fp)
|
||||
|
||||
add_prefix = tokenizer.add_prefix_space
|
||||
remove_whitespaces = tokenizer.clean_up_tokenization_spaces
|
||||
add_prefix = tokenizer.add_prefix_space # ty: ignore[unresolved-attribute]
|
||||
remove_whitespaces = tokenizer.clean_up_tokenization_spaces # ty: ignore[unresolved-attribute]
|
||||
precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"])
|
||||
|
||||
vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size)
|
||||
vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size) # ty: ignore[unresolved-attribute]
|
||||
else:
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
|
||||
@@ -6410,7 +6495,7 @@ class BertModel(TextModel):
|
||||
|
||||
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
|
||||
scores: list[float] = [-10000.0] * vocab_size
|
||||
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
|
||||
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size # ty: ignore[invalid-assignment]
|
||||
|
||||
if isinstance(tokenizer, SentencePieceProcessor):
|
||||
for token_id in range(tokenizer.vocab_size()):
|
||||
@@ -6432,20 +6517,20 @@ class BertModel(TextModel):
|
||||
scores[token_id] = score
|
||||
toktypes[token_id] = toktype
|
||||
else:
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
added_vocab = tokenizer.get_added_vocab() # ty: ignore[unresolved-attribute]
|
||||
unk_token = tokenizer_config_json.get("unk_token")
|
||||
unk_token_id = added_vocab.get(unk_token, tokenizer_json["model"].get("unk_id", 3))
|
||||
unk_token_id = added_vocab.get(unk_token, tokenizer_json["model"].get("unk_id", 3)) # ty: ignore[no-matching-overload]
|
||||
|
||||
for token_id in range(tokenizer.vocab_size):
|
||||
piece = tokenizer._convert_id_to_token(token_id)
|
||||
if (piece := tokenizer._convert_id_to_token(token_id)) is not None:
|
||||
for token_id in range(tokenizer.vocab_size): # ty: ignore[unresolved-attribute]
|
||||
piece = tokenizer._convert_id_to_token(token_id) # ty: ignore[unresolved-attribute]
|
||||
if (piece := tokenizer._convert_id_to_token(token_id)) is not None: # ty: ignore[unresolved-attribute]
|
||||
text = piece.encode("utf-8")
|
||||
score = tokenizer_json["model"]["vocab"][token_id][1]
|
||||
|
||||
toktype = SentencePieceTokenTypes.NORMAL
|
||||
if token_id == unk_token_id:
|
||||
toktype = SentencePieceTokenTypes.UNKNOWN
|
||||
elif token_id in tokenizer.all_special_ids:
|
||||
elif token_id in tokenizer.all_special_ids: # ty: ignore[unresolved-attribute]
|
||||
toktype = SentencePieceTokenTypes.CONTROL
|
||||
elif token_id in added_vocab.values():
|
||||
toktype = SentencePieceTokenTypes.USER_DEFINED
|
||||
@@ -7464,9 +7549,6 @@ class Gemma4Model(Gemma3Model):
|
||||
|
||||
assert len(tokens) == vocab.vocab_size
|
||||
|
||||
# TODO @ngxson : there are some known (rare) issues with the tokenizer during development
|
||||
# but I don't have time to dive into them right now;
|
||||
# using a dedicated tokenizer name so that we can fix later without re-converting GGUF
|
||||
self.gguf_writer.add_tokenizer_model("gemma4")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_scores(scores)
|
||||
@@ -7475,7 +7557,7 @@ class Gemma4Model(Gemma3Model):
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
self.gguf_writer.add_add_space_prefix(False)
|
||||
self.gguf_writer.add_add_bos_token(False) # already added via the chat template
|
||||
self.gguf_writer.add_add_bos_token(True)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
@@ -8757,7 +8839,7 @@ class DeepseekV2Model(TextModel):
|
||||
# Build merges list using the approach similar to HunYuanMoE
|
||||
merges = []
|
||||
vocab = {}
|
||||
mergeable_ranks = tokenizer.model._mergeable_ranks
|
||||
mergeable_ranks = tokenizer.model._mergeable_ranks # ty: ignore[unresolved-attribute]
|
||||
for token, rank in mergeable_ranks.items():
|
||||
vocab[QwenModel.token_bytes_to_string(token)] = rank
|
||||
if len(token) == 1:
|
||||
@@ -8768,7 +8850,7 @@ class DeepseekV2Model(TextModel):
|
||||
|
||||
# Build token list
|
||||
vocab_size = self.hparams["vocab_size"]
|
||||
special_tokens = tokenizer.special_tokens
|
||||
special_tokens = tokenizer.special_tokens # ty: ignore[unresolved-attribute]
|
||||
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
@@ -9739,10 +9821,10 @@ class Glm4Model(TextModel):
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
@@ -9970,12 +10052,12 @@ class ChatGLMModel(TextModel):
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
|
||||
vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab()))
|
||||
assert max(tokenizer.get_vocab().values()) < vocab_size
|
||||
vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab())) # ty: ignore[unresolved-attribute]
|
||||
assert max(tokenizer.get_vocab().values()) < vocab_size # ty: ignore[unresolved-attribute]
|
||||
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
|
||||
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
|
||||
for token_id in range(vocab_size):
|
||||
piece = tokenizer._convert_id_to_token(token_id)
|
||||
piece = tokenizer._convert_id_to_token(token_id) # ty: ignore[unresolved-attribute]
|
||||
if token_id == 0:
|
||||
piece = "<unk>"
|
||||
elif token_id == 1:
|
||||
@@ -9983,17 +10065,17 @@ class ChatGLMModel(TextModel):
|
||||
elif token_id == 2:
|
||||
piece = "<eos>"
|
||||
|
||||
text = piece.encode("utf-8")
|
||||
text = piece.encode("utf-8") # ty: ignore[unresolved-attribute]
|
||||
score = 0.0
|
||||
# Referencing the tokenizer Python implementation(https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py),
|
||||
# it is only valid if it is less than tokenizer.tokenizer.sp_model.vocab_size()
|
||||
if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size():
|
||||
score = tokenizer.tokenizer.sp_model.get_score(token_id)
|
||||
if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size(): # ty: ignore[unresolved-attribute, invalid-argument-type]
|
||||
score = tokenizer.tokenizer.sp_model.get_score(token_id) # ty: ignore[unresolved-attribute]
|
||||
|
||||
if token_id >= tokenizer.tokenizer.sp_model.vocab_size():
|
||||
if token_id >= tokenizer.tokenizer.sp_model.vocab_size(): # ty: ignore[unresolved-attribute]
|
||||
if piece in special_tokens:
|
||||
toktype = SentencePieceTokenTypes.CONTROL
|
||||
elif len(piece) == 0:
|
||||
elif len(piece) == 0: # ty: ignore[invalid-argument-type]
|
||||
text = f"[PAD{token_id}]".encode("utf-8")
|
||||
toktype = SentencePieceTokenTypes.UNUSED
|
||||
else:
|
||||
@@ -10004,13 +10086,13 @@ class ChatGLMModel(TextModel):
|
||||
continue
|
||||
|
||||
toktype = SentencePieceTokenTypes.NORMAL
|
||||
if tokenizer.tokenizer.sp_model.is_unknown(token_id):
|
||||
if tokenizer.tokenizer.sp_model.is_unknown(token_id): # ty: ignore[unresolved-attribute]
|
||||
toktype = SentencePieceTokenTypes.UNKNOWN
|
||||
elif tokenizer.tokenizer.sp_model.is_control(token_id):
|
||||
elif tokenizer.tokenizer.sp_model.is_control(token_id): # ty: ignore[unresolved-attribute]
|
||||
toktype = SentencePieceTokenTypes.CONTROL
|
||||
elif tokenizer.tokenizer.sp_model.is_unused(token_id):
|
||||
elif tokenizer.tokenizer.sp_model.is_unused(token_id): # ty: ignore[unresolved-attribute]
|
||||
toktype = SentencePieceTokenTypes.UNUSED
|
||||
elif tokenizer.tokenizer.sp_model.is_byte(token_id):
|
||||
elif tokenizer.tokenizer.sp_model.is_byte(token_id): # ty: ignore[unresolved-attribute]
|
||||
toktype = SentencePieceTokenTypes.BYTE
|
||||
|
||||
tokens.append(text)
|
||||
@@ -10030,7 +10112,7 @@ class ChatGLMModel(TextModel):
|
||||
|
||||
@staticmethod
|
||||
def token_bytes_to_string(b):
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode # ty: ignore[unresolved-import]
|
||||
byte_encoder = bytes_to_unicode()
|
||||
return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
|
||||
|
||||
@@ -10064,7 +10146,7 @@ class ChatGLMModel(TextModel):
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
|
||||
vocab_size = hparams.get("padded_vocab_size",hparams["vocab_size"])
|
||||
assert max(tokenizer.get_vocab().values()) < vocab_size
|
||||
assert max(tokenizer.get_vocab().values()) < vocab_size # ty: ignore[unresolved-attribute]
|
||||
|
||||
tokens, toktypes, tokpre = self.get_vocab_base()
|
||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||
@@ -10073,10 +10155,10 @@ class ChatGLMModel(TextModel):
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
# only add special tokens when they were not already loaded from config.json
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # ty: ignore[unresolved-attribute]
|
||||
# this one is usually not in config.json anyway
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
@@ -11342,7 +11424,7 @@ class HunYuanMoEModel(TextModel):
|
||||
# 2. Reverse-engineer the merges list from mergeable_ranks
|
||||
merges = []
|
||||
vocab = {}
|
||||
mergeable_ranks = tokenizer.mergeable_ranks
|
||||
mergeable_ranks = tokenizer.mergeable_ranks # ty: ignore[unresolved-attribute]
|
||||
for token, rank in mergeable_ranks.items():
|
||||
vocab[QwenModel.token_bytes_to_string(token)] = rank
|
||||
if len(token) == 1:
|
||||
@@ -11353,8 +11435,8 @@ class HunYuanMoEModel(TextModel):
|
||||
|
||||
# 3. Generate the tokens and toktypes lists
|
||||
vocab_size = self.hparams["vocab_size"]
|
||||
assert tokenizer.vocab_size == vocab_size
|
||||
special_tokens = tokenizer.special_tokens
|
||||
assert tokenizer.vocab_size == vocab_size # ty: ignore[unresolved-attribute]
|
||||
special_tokens = tokenizer.special_tokens # ty: ignore[unresolved-attribute]
|
||||
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
@@ -11524,13 +11606,50 @@ class LLaDAMoEModel(TextModel):
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("HunYuanDenseV1ForCausalLM")
|
||||
@ModelBase.register("HunYuanDenseV1ForCausalLM", "HunYuanVLForConditionalGeneration")
|
||||
class HunYuanModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE
|
||||
|
||||
def _get_eod_token_id(self) -> int | None:
|
||||
"""Get the actual end-of-generation token from config (eod_token_id)."""
|
||||
return self.hparams.get("eod_token_id")
|
||||
|
||||
def _get_eot_token_id(self) -> int | None:
|
||||
"""Get the end-of-turn token from generation_config.json.
|
||||
This is the first entry in eos_token_id when it's a list."""
|
||||
gen_cfg_path = self.dir_model / "generation_config.json"
|
||||
if gen_cfg_path.is_file():
|
||||
with open(gen_cfg_path, encoding="utf-8") as f:
|
||||
gen_cfg = json.load(f)
|
||||
eos = gen_cfg.get("eos_token_id")
|
||||
if isinstance(eos, list) and len(eos) >= 2:
|
||||
return eos[0]
|
||||
return None
|
||||
|
||||
def _fix_special_tokens(self):
|
||||
"""Fix EOS/EOT tokens that are incorrect in upstream configs."""
|
||||
eod_id = self._get_eod_token_id()
|
||||
if eod_id is not None:
|
||||
self.gguf_writer.add_eos_token_id(eod_id)
|
||||
eot_id = self._get_eot_token_id()
|
||||
if eot_id is not None:
|
||||
self.gguf_writer.add_eot_token_id(eot_id)
|
||||
|
||||
def set_vocab(self):
|
||||
if (self.dir_model / "tokenizer.json").is_file():
|
||||
self._set_vocab_gpt2()
|
||||
tokens, toktypes, tokpre = self.get_vocab_base()
|
||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
# HunyuanOCR has pad_token_id=-1 in config.json; exclude pad from SpecialVocab
|
||||
token_types = None
|
||||
if (self.hparams.get("pad_token_id") or 0) < 0:
|
||||
token_types = ('bos', 'eos', 'unk', 'sep', 'cls', 'mask')
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True, special_token_types=token_types)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
self._fix_special_tokens()
|
||||
else:
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
|
||||
@@ -11541,7 +11660,7 @@ class HunYuanModel(TextModel):
|
||||
# 2. Reverse-engineer the merges list from mergeable_ranks
|
||||
merges = []
|
||||
vocab = {}
|
||||
mergeable_ranks = tokenizer.mergeable_ranks
|
||||
mergeable_ranks = tokenizer.mergeable_ranks # ty: ignore[unresolved-attribute]
|
||||
for token, rank in mergeable_ranks.items():
|
||||
vocab[QwenModel.token_bytes_to_string(token)] = rank
|
||||
if len(token) == 1:
|
||||
@@ -11552,8 +11671,8 @@ class HunYuanModel(TextModel):
|
||||
|
||||
# 3. Generate the tokens and toktypes lists
|
||||
vocab_size = self.hparams["vocab_size"]
|
||||
assert tokenizer.vocab_size == vocab_size
|
||||
special_tokens = tokenizer.special_tokens
|
||||
assert tokenizer.vocab_size == vocab_size # ty: ignore[unresolved-attribute]
|
||||
special_tokens = tokenizer.special_tokens # ty: ignore[unresolved-attribute]
|
||||
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
@@ -11582,13 +11701,18 @@ class HunYuanModel(TextModel):
|
||||
# FIX for BOS token: Overwrite incorrect id read from config.json
|
||||
if self.hparams['hidden_size'] == 4096:
|
||||
self.gguf_writer.add_bos_token_id(127958) # only for 7b dense, fix <|bos|> token
|
||||
self._fix_special_tokens()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
# HunyuanOCR has num_experts=1 which is not MoE, prevent parent from writing it
|
||||
saved_num_experts = self.hparams.pop("num_experts", None)
|
||||
super().set_gguf_parameters()
|
||||
if saved_num_experts is not None and saved_num_experts > 1:
|
||||
self.hparams["num_experts"] = saved_num_experts
|
||||
hparams = self.hparams
|
||||
|
||||
# Rope
|
||||
if self.rope_parameters.get("rope_type") == "dynamic":
|
||||
if self.rope_parameters.get("rope_type") in ("dynamic", "xdrope"):
|
||||
# HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||
# 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
|
||||
alpha = self.rope_parameters.get("alpha", 50)
|
||||
@@ -11598,13 +11722,14 @@ class HunYuanModel(TextModel):
|
||||
self.gguf_writer.add_rope_freq_base(scaled_base)
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
|
||||
self.gguf_writer.add_rope_scaling_factor(1)
|
||||
# There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) # 256k context length
|
||||
self.gguf_writer.add_context_length(256 * 1024) # 256k context length
|
||||
if self.rope_parameters.get("rope_type") == "dynamic":
|
||||
# There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) # 256k context length
|
||||
self.gguf_writer.add_context_length(256 * 1024) # 256k context length
|
||||
|
||||
# if any of our assumptions about the values are wrong, something has changed and this may need to be updated
|
||||
assert base == 10000.0 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024] , \
|
||||
"HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"
|
||||
# if any of our assumptions about the values are wrong, something has changed and this may need to be updated
|
||||
assert base == 10000.0 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024] , \
|
||||
"HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if name == "lm_head.weight":
|
||||
@@ -11612,9 +11737,48 @@ class HunYuanModel(TextModel):
|
||||
logger.info("Skipping tied output layer 'lm_head.weight'")
|
||||
return
|
||||
|
||||
# skip vision tensors for HunyuanVL models
|
||||
if name.startswith("vit."):
|
||||
return
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("HunYuanVLForConditionalGeneration")
|
||||
class HunyuanOCRVisionModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.hparams_vision is not None
|
||||
# HunyuanOCR uses max_image_size instead of image_size
|
||||
if "image_size" not in self.hparams_vision:
|
||||
self.hparams_vision["image_size"] = self.hparams_vision.get("max_image_size", 2048)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
assert self.hparams_vision is not None
|
||||
hparams = self.hparams_vision
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANOCR)
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("rms_norm_eps", 1e-5))
|
||||
self.gguf_writer.add_vision_spatial_merge_size(hparams.get("spatial_merge_size", 2))
|
||||
self.gguf_writer.add_vision_min_pixels(self.preprocessor_config["min_pixels"])
|
||||
self.gguf_writer.add_vision_max_pixels(self.preprocessor_config["max_pixels"])
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if not name.startswith("vit."):
|
||||
return # skip text tensors
|
||||
# strip CLS token (row 0) from position embeddings so resize_position_embeddings works
|
||||
if "position_embedding" in name:
|
||||
data_torch = data_torch[1:] # [n_patches+1, n_embd] -> [n_patches, n_embd]
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
def tensor_force_quant(self, name, new_name, bid, n_dims):
|
||||
# force conv weights to F32 or F16 to avoid BF16 IM2COL issues on Metal
|
||||
if ("mm.0." in new_name or "mm.2." in new_name) and new_name.endswith(".weight"):
|
||||
return gguf.GGMLQuantizationType.F16 if self.ftype == gguf.LlamaFileType.MOSTLY_F16 else gguf.GGMLQuantizationType.F32
|
||||
return super().tensor_force_quant(name, new_name, bid, n_dims)
|
||||
|
||||
|
||||
@ModelBase.register("SmolLM3ForCausalLM")
|
||||
class SmolLM3Model(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.SMOLLM3
|
||||
@@ -11739,10 +11903,8 @@ class LFM2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.LFM2
|
||||
|
||||
def _add_feed_forward_length(self):
|
||||
ff_dim = self.hparams["block_ff_dim"]
|
||||
|
||||
ff_dim = self.find_hparam(["block_ff_dim", "intermediate_size"])
|
||||
auto_adjust_ff_dim = self.hparams["block_auto_adjust_ff_dim"]
|
||||
ff_dim = self.hparams["block_ff_dim"]
|
||||
ffn_dim_multiplier = self.hparams["block_ffn_dim_multiplier"]
|
||||
multiple_of = self.hparams["block_multiple_of"]
|
||||
|
||||
@@ -12658,13 +12820,44 @@ class SolarOpenModel(Glm4MoeModel):
|
||||
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<unk>"])
|
||||
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|startoftext|>"])
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<unk>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|startoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
|
||||
@ModelBase.register("DotsOCRForCausalLM")
|
||||
class DotsOCRVisionModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.hparams_vision is not None
|
||||
self.hparams_vision["image_size"] = 0 # dynamic resolution
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.DOTSOCR)
|
||||
self.gguf_writer.add_vision_min_pixels(self.preprocessor_config["min_pixels"])
|
||||
self.gguf_writer.add_vision_max_pixels(self.preprocessor_config["max_pixels"])
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(self.find_vparam(["rms_norm_eps"]))
|
||||
self.gguf_writer.add_vision_projector_scale_factor(self.find_vparam(["spatial_merge_size"]))
|
||||
self.gguf_writer.add_vision_use_silu(True)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if name.startswith("vision_tower."):
|
||||
if "vision_tower.blocks." in name and ".mlp." in name:
|
||||
# note: to avoid naming conflicts in tensor_mapping.py, we need to handle FFN renaming here
|
||||
# x = F.silu(self.fc1(x)) * self.fc3(x)
|
||||
# x = self.fc2(x)
|
||||
# fc1 -> gate, fc2 -> down, fc3 -> up
|
||||
# mapping original names to Qwen2.5 naming scheme
|
||||
name = name.replace("vision_tower.blocks.", "visual.blocks.")
|
||||
name = name.replace(".fc1", ".gate_proj")
|
||||
name = name.replace(".fc2", ".down_proj")
|
||||
name = name.replace(".fc3", ".up_proj")
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
@@ -12917,6 +13110,12 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
|
||||
# For non-hf Mamba and Mamba2 models
|
||||
arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
|
||||
|
||||
# Step3-VL keeps text config under text_config but uses a custom top-level architecture.
|
||||
# For text conversion we route to a dedicated text-only class.
|
||||
# TODO: refactor this later to avoid adding exception here
|
||||
if model_type == ModelType.TEXT and arch == "StepVLForConditionalGeneration":
|
||||
return arch
|
||||
|
||||
# if "architectures" is found in the sub-config, use that instead
|
||||
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
|
||||
arch = text_config["architectures"][0]
|
||||
|
||||
@@ -296,7 +296,7 @@ for model in [*pre_computed_hashes, *all_models]:
|
||||
except Exception as e:
|
||||
raise OSError(f"Error loading tokenizer for model {name}.") from e
|
||||
|
||||
chktok = tokenizer.encode(CHK_TXT)
|
||||
chktok = tokenizer.encode(CHK_TXT) # ty: ignore[unresolved-attribute]
|
||||
chkhsh = sha256(str(chktok).encode()).hexdigest()
|
||||
|
||||
logger.info(f"model: {name}")
|
||||
@@ -468,7 +468,7 @@ for model in models:
|
||||
|
||||
with open(f"models/ggml-vocab-{name}.gguf.out", "w") as f:
|
||||
for text in tests:
|
||||
res = tokenizer.encode(text, add_special_tokens=False)
|
||||
res = tokenizer.encode(text, add_special_tokens=False) # ty: ignore[unresolved-attribute]
|
||||
for r in res:
|
||||
f.write(f" {r}")
|
||||
f.write("\n")
|
||||
|
||||
@@ -402,7 +402,7 @@ if __name__ == '__main__':
|
||||
# the invocation string includes the "<|start_of_turn|>"
|
||||
# token, but the adapters themselves were trained to
|
||||
# activate _after_ that first token, so we drop it here.
|
||||
alora_invocation_tokens = tokenizer(invocation_string)["input_ids"][1:]
|
||||
alora_invocation_tokens = tokenizer(invocation_string)["input_ids"][1:] # ty: ignore[call-non-callable]
|
||||
if alora_invocation_tokens:
|
||||
logger.debug("GGUF KV: %s = %s", gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS, alora_invocation_tokens)
|
||||
self.gguf_writer.add_key_value(
|
||||
|
||||
@@ -57,13 +57,14 @@ ZenDNN is optimized for AMD EPYC™ processors and AMD Ryzen™ processors based
|
||||
|
||||
## Supported Operations
|
||||
|
||||
The ZenDNN backend currently accelerates **matrix multiplication (MUL_MAT)** operations only. Other operations are handled by the standard CPU backend.
|
||||
The ZenDNN backend accelerates **matrix multiplication (MUL_MAT)** and **expert-based matrix multiplication (MUL_MAT_ID)** operations. Other operations are handled by the standard CPU backend.
|
||||
|
||||
| Operation | Status | Notes |
|
||||
|:-------------|:-------:|:----------------------------------------------:|
|
||||
| MUL_MAT | Support | Accelerated via ZenDNN LowOHA MatMul |
|
||||
| MUL_MAT_ID | Support | Accelerated via ZenDNN LowOHA MatMul (MoE) |
|
||||
|
||||
*Note:* Since only MUL_MAT is accelerated, models will benefit most from ZenDNN when matrix multiplications dominate the computational workload (which is typical for transformer-based LLMs).
|
||||
*Note:* Since MUL_MAT and MUL_MAT_ID are accelerated, models will benefit most from ZenDNN when matrix multiplications dominate the computational workload (which is typical for transformer-based LLMs and Mixture-of-Experts models).
|
||||
|
||||
## DataType Supports
|
||||
|
||||
@@ -181,7 +182,7 @@ For detailed profiling and logging options, refer to the [ZenDNN Logging Documen
|
||||
|
||||
## Known Issues
|
||||
|
||||
- **Limited operation support**: Currently only matrix multiplication (MUL_MAT) is accelerated via ZenDNN. Other operations fall back to the standard CPU backend.
|
||||
- **Limited operation support**: Currently matrix multiplication (MUL_MAT) and expert-based matrix multiplication (MUL_MAT_ID) are accelerated via ZenDNN. Other operations fall back to the standard CPU backend. Future updates may expand supported operations.
|
||||
- **BF16 support**: BF16 operations require AMD Zen 4 or Zen 5 architecture (EPYC 9004/9005 series). On older CPUs, operations will use FP32.
|
||||
- **NUMA awareness**: For multi-socket systems, manual NUMA binding may be required for optimal performance.
|
||||
|
||||
@@ -216,4 +217,4 @@ Please add the **[ZenDNN]** prefix/tag in issues/PRs titles to help the ZenDNN-t
|
||||
|
||||
## TODO
|
||||
|
||||
- Expand operation support beyond MUL_MAT (attention operations, activations, etc.)
|
||||
- Expand operation support beyond MUL_MAT and MUL_MAT_ID (attention operations, activations, etc.)
|
||||
|
||||
@@ -389,7 +389,7 @@ You can download it from your Linux distro's package manager or from here: [ROCm
|
||||
|
||||
|
||||
The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used.
|
||||
If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3.
|
||||
If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3. Note that [`HSA_OVERRIDE_GFX_VERSION`] is [not supported on Windows](https://github.com/ROCm/ROCm/issues/2654)
|
||||
|
||||
### Unified Memory
|
||||
|
||||
@@ -741,7 +741,7 @@ cmake --build build --config Release
|
||||
|
||||
WebGPU allows cross-platform access to the GPU from supported browsers. We utilize [Emscripten](https://emscripten.org/) to compile ggml's WebGPU backend to WebAssembly. Emscripten does not officially support WebGPU bindings yet, but Dawn currently maintains its own WebGPU bindings called emdawnwebgpu.
|
||||
|
||||
Follow the instructions [here](https://dawn.googlesource.com/dawn/+/refs/heads/main/src/emdawnwebgpu/) to download or build the emdawnwebgpu package (Note that it might be safer to build the emdawbwebgpu package locally, so that it stays in sync with the version of Dawn you have installed above). When building using CMake, the path to the emdawnwebgpu port file needs to be set with the flag `EMDAWNWEBGPU_DIR`.
|
||||
Follow the instructions [here](https://dawn.googlesource.com/dawn/+/refs/heads/main/src/emdawnwebgpu/) to download or build the emdawnwebgpu package (Note that it might be safer to build the emdawnwebgpu package locally, so that it stays in sync with the version of Dawn you have installed above). When building using CMake, the path to the emdawnwebgpu port file needs to be set with the flag `EMDAWNWEBGPU_DIR`.
|
||||
|
||||
## IBM Z & LinuxONE
|
||||
|
||||
|
||||
@@ -37,6 +37,8 @@ llama-server -hf ggml-org/gemma-3-4b-it-GGUF --no-mmproj-offload
|
||||
> - PaddleOCR-VL: https://github.com/ggml-org/llama.cpp/pull/18825
|
||||
> - GLM-OCR: https://github.com/ggml-org/llama.cpp/pull/19677
|
||||
> - Deepseek-OCR: https://github.com/ggml-org/llama.cpp/pull/17400
|
||||
> - Dots.OCR: https://github.com/ggml-org/llama.cpp/pull/17575
|
||||
> - HunyuanOCR: https://github.com/ggml-org/llama.cpp/pull/21395
|
||||
|
||||
## Pre-quantized models
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ Legend:
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
|
||||
1145
docs/ops/WebGPU.csv
1145
docs/ops/WebGPU.csv
File diff suppressed because it is too large
Load Diff
9986
docs/ops/ZenDNN.csv
9986
docs/ops/ZenDNN.csv
File diff suppressed because it is too large
Load Diff
@@ -9,6 +9,7 @@
|
||||
#include <vector>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <optional>
|
||||
#include <regex>
|
||||
|
||||
static void print_usage(int /*argc*/, char ** argv) {
|
||||
@@ -222,7 +223,10 @@ int main(int argc, char ** argv) {
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
base_callback_data cb_data(params, params.tensor_filter);
|
||||
std::optional<base_callback_data> cb_data;
|
||||
if (!params.save_logits) {
|
||||
cb_data.emplace(params, params.tensor_filter);
|
||||
}
|
||||
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
|
||||
@@ -53,10 +53,10 @@ model_name = os.path.basename(model_path)
|
||||
print(f"Model name: {model_name}")
|
||||
|
||||
prompt = "Hello world today"
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids # ty: ignore[call-non-callable]
|
||||
print(f"Input tokens: {input_ids}")
|
||||
print(f"Input text: {repr(prompt)}")
|
||||
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
|
||||
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}") # ty: ignore[unresolved-attribute]
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(input_ids, output_hidden_states=True)
|
||||
@@ -92,7 +92,7 @@ with torch.no_grad():
|
||||
|
||||
# Print embeddings per token in the requested format
|
||||
print("\nToken embeddings:")
|
||||
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
||||
tokens = tokenizer.convert_ids_to_tokens(input_ids[0]) # ty: ignore[unresolved-attribute]
|
||||
for i, embedding in enumerate(token_embeddings):
|
||||
# Format: show first few values, ..., then last few values
|
||||
if len(embedding) > 10:
|
||||
|
||||
@@ -207,8 +207,8 @@ def main():
|
||||
else:
|
||||
model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True)
|
||||
|
||||
encoded = tokenizer(prompt, return_tensors="pt")
|
||||
tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])
|
||||
encoded = tokenizer(prompt, return_tensors="pt") # ty: ignore[call-non-callable]
|
||||
tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0]) # ty: ignore[unresolved-attribute]
|
||||
n_tokens = len(tokens)
|
||||
print(f"n_tokens: {n_tokens}");
|
||||
print(f"hidden_size: {model.config.hidden_size}")
|
||||
|
||||
@@ -428,7 +428,8 @@ extern "C" {
|
||||
// GGML_TYPE_IQ4_NL_8_8 = 38,
|
||||
GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
|
||||
GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale)
|
||||
GGML_TYPE_COUNT = 41,
|
||||
GGML_TYPE_Q1_0 = 41,
|
||||
GGML_TYPE_COUNT = 42,
|
||||
};
|
||||
|
||||
// precision
|
||||
@@ -465,6 +466,7 @@ extern "C" {
|
||||
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_NVFP4 = 26, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_Q1_0 = 27, // except 1d tensors
|
||||
};
|
||||
|
||||
// available tensor operations:
|
||||
@@ -900,15 +902,17 @@ extern "C" {
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * ids);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_add1(
|
||||
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_add1(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
struct ggml_tensor * b),
|
||||
"use ggml_add instead");
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_add1_inplace(
|
||||
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_add1_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
struct ggml_tensor * b),
|
||||
"use ggml_add_inplace instead");
|
||||
|
||||
// dst = a
|
||||
// view(dst, nb1, nb2, nb3, offset) += b
|
||||
|
||||
@@ -93,6 +93,10 @@ typedef sycl::half2 ggml_half2;
|
||||
// QR = QK / number of values before dequantization
|
||||
// QI = number of 32 bit integers before dequantization
|
||||
|
||||
#define QI1_0 (QK1_0 / 32)
|
||||
#define QR1_0 1
|
||||
|
||||
|
||||
#define QI4_0 (QK4_0 / (4 * QR4_0))
|
||||
#define QR4_0 2
|
||||
|
||||
@@ -170,6 +174,13 @@ typedef sycl::half2 ggml_half2;
|
||||
#define GGML_EXTENSION __extension__
|
||||
#endif // _MSC_VER
|
||||
|
||||
#define QK1_0 128
|
||||
typedef struct {
|
||||
ggml_half d; // delta
|
||||
uint8_t qs[QK1_0 / 8]; // bits / quants
|
||||
} block_q1_0;
|
||||
static_assert(sizeof(block_q1_0) == sizeof(ggml_half) + QK1_0 / 8, "wrong q1_0 block size/padding");
|
||||
|
||||
#define QK4_0 32
|
||||
typedef struct {
|
||||
ggml_half d; // delta
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
|
||||
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
|
||||
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
|
||||
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
|
||||
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
|
||||
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
|
||||
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
|
||||
@@ -82,6 +83,7 @@
|
||||
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
|
||||
// quants.c
|
||||
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
|
||||
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
|
||||
@@ -112,6 +114,7 @@
|
||||
// quants.c
|
||||
#define quantize_row_q8_K_generic quantize_row_q8_K
|
||||
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
|
||||
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
|
||||
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
|
||||
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
|
||||
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
|
||||
@@ -160,6 +163,7 @@
|
||||
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
|
||||
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
|
||||
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
|
||||
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
@@ -200,6 +204,7 @@
|
||||
#elif defined(__riscv)
|
||||
// quants.c
|
||||
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
|
||||
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
@@ -240,6 +245,7 @@
|
||||
// quants.c
|
||||
#define quantize_row_q8_K_generic quantize_row_q8_K
|
||||
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
|
||||
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
|
||||
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
|
||||
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
|
||||
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
|
||||
@@ -303,6 +309,7 @@
|
||||
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
|
||||
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
|
||||
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
|
||||
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
|
||||
@@ -137,6 +137,109 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
|
||||
|
||||
//===================================== Dot products =================================
|
||||
|
||||
void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
const int qk = QK1_0; // 128
|
||||
const int nb = n / qk;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q1_0 * GGML_RESTRICT x = vx;
|
||||
const block_q8_0 * GGML_RESTRICT y = vy;
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
#if defined(__ARM_NEON)
|
||||
float32x4_t sumv = vdupq_n_f32(0.0f);
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d);
|
||||
|
||||
// Process 4 Q8_0 blocks (each has 32 elements)
|
||||
for (int k = 0; k < 4; k++) {
|
||||
const block_q8_0 * GGML_RESTRICT yb = &y[i * 4 + k];
|
||||
const float d1 = GGML_CPU_FP16_TO_FP32(yb->d);
|
||||
|
||||
// Get the 4 bytes of bits for this Q8_0 block (32 bits = 4 bytes)
|
||||
// Bits are at offset k*4 bytes in x[i].qs
|
||||
const uint8_t * bits = &x[i].qs[k * 4];
|
||||
|
||||
// Load 32 int8 values from y
|
||||
const int8x16_t y0 = vld1q_s8(yb->qs);
|
||||
const int8x16_t y1 = vld1q_s8(yb->qs + 16);
|
||||
|
||||
// Byte 0-1: bits for y0[0..15]
|
||||
const uint64_t expand0 = table_b2b_0[bits[0]];
|
||||
const uint64_t expand1 = table_b2b_0[bits[1]];
|
||||
// Byte 2-3: bits for y1[0..15]
|
||||
const uint64_t expand2 = table_b2b_0[bits[2]];
|
||||
const uint64_t expand3 = table_b2b_0[bits[3]];
|
||||
|
||||
// Build the sign vectors by reinterpreting the table values
|
||||
uint8x8_t e0 = vcreate_u8(expand0);
|
||||
uint8x8_t e1 = vcreate_u8(expand1);
|
||||
uint8x8_t e2 = vcreate_u8(expand2);
|
||||
uint8x8_t e3 = vcreate_u8(expand3);
|
||||
|
||||
// Shift right by 4 to get 0 or 1
|
||||
int8x8_t s0 = vreinterpret_s8_u8(vshr_n_u8(e0, 4));
|
||||
int8x8_t s1 = vreinterpret_s8_u8(vshr_n_u8(e1, 4));
|
||||
int8x8_t s2 = vreinterpret_s8_u8(vshr_n_u8(e2, 4));
|
||||
int8x8_t s3 = vreinterpret_s8_u8(vshr_n_u8(e3, 4));
|
||||
|
||||
// Convert 0/1 to -1/+1: sign = 2*val - 1
|
||||
int8x8_t one = vdup_n_s8(1);
|
||||
s0 = vsub_s8(vadd_s8(s0, s0), one); // 2*s0 - 1
|
||||
s1 = vsub_s8(vadd_s8(s1, s1), one);
|
||||
s2 = vsub_s8(vadd_s8(s2, s2), one);
|
||||
s3 = vsub_s8(vadd_s8(s3, s3), one);
|
||||
|
||||
// Combine into 16-element vectors
|
||||
int8x16_t signs0 = vcombine_s8(s0, s1);
|
||||
int8x16_t signs1 = vcombine_s8(s2, s3);
|
||||
|
||||
// Multiply signs with y values and accumulate
|
||||
// dot(signs, y) where signs are +1/-1
|
||||
int32x4_t p0 = ggml_vdotq_s32(vdupq_n_s32(0), signs0, y0);
|
||||
int32x4_t p1 = ggml_vdotq_s32(p0, signs1, y1);
|
||||
|
||||
// Scale by d1 and accumulate
|
||||
sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(p1), d0 * d1);
|
||||
}
|
||||
}
|
||||
|
||||
sumf = vaddvq_f32(sumv);
|
||||
#else
|
||||
// Scalar fallback
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const float d0 = GGML_FP16_TO_FP32(x[i].d);
|
||||
|
||||
// Process 4 Q8_0 blocks
|
||||
for (int k = 0; k < 4; k++) {
|
||||
const float d1 = GGML_FP16_TO_FP32(y[i*4 + k].d);
|
||||
|
||||
int sumi = 0;
|
||||
for (int j = 0; j < QK8_0; j++) {
|
||||
const int bit_index = k * QK8_0 + j;
|
||||
const int byte_index = bit_index / 8;
|
||||
const int bit_offset = bit_index % 8;
|
||||
|
||||
const int xi = ((x[i].qs[byte_index] >> bit_offset) & 1) ? 1 : -1;
|
||||
sumi += xi * y[i*4 + k].qs[j];
|
||||
}
|
||||
sumf += d0 * d1 * sumi;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
|
||||
|
||||
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
||||
@@ -2156,4 +2156,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
|
||||
ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -2302,4 +2302,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
|
||||
ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -1463,4 +1463,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
|
||||
ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -1218,4 +1218,3 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -217,6 +217,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
|
||||
.vec_dot_type = GGML_TYPE_F16,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_Q1_0] = {
|
||||
.from_float = quantize_row_q1_0,
|
||||
.vec_dot = ggml_vec_dot_q1_0_q8_0,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_Q4_0] = {
|
||||
.from_float = quantize_row_q4_0,
|
||||
.vec_dot = ggml_vec_dot_q4_0_q8_0,
|
||||
|
||||
@@ -4829,6 +4829,7 @@ void ggml_compute_forward_get_rows(
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
@@ -5554,6 +5555,7 @@ void ggml_compute_forward_clamp(
|
||||
ggml_compute_forward_clamp_f16(params, dst);
|
||||
} break;
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
|
||||
@@ -22,6 +22,10 @@
|
||||
|
||||
#define UNUSED GGML_UNUSED
|
||||
|
||||
void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
|
||||
quantize_row_q1_0_ref(x, y, k);
|
||||
}
|
||||
|
||||
void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
|
||||
quantize_row_q4_0_ref(x, y, k);
|
||||
}
|
||||
@@ -116,6 +120,51 @@ void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRI
|
||||
|
||||
//===================================== Dot products =================================
|
||||
|
||||
void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
const int qk = QK1_0;
|
||||
const int nb = n / qk;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q1_0 * GGML_RESTRICT x = vx;
|
||||
const block_q8_0 * GGML_RESTRICT y = vy;
|
||||
|
||||
float sumf = 0.0;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const float d0 = GGML_FP16_TO_FP32(x[i].d);
|
||||
|
||||
float sumi = 0.0f;
|
||||
|
||||
for (int k = 0; k < 4; k++) {
|
||||
const float d1 = GGML_FP16_TO_FP32(y[i*4 + k].d);
|
||||
|
||||
int sumi_block = 0;
|
||||
|
||||
for (int j = 0; j < QK8_0; j++) {
|
||||
const int bit_index = k * QK8_0 + j;
|
||||
const int byte_index = bit_index / 8;
|
||||
const int bit_offset = bit_index % 8;
|
||||
|
||||
const int xi = ((x[i].qs[byte_index] >> bit_offset) & 1) ? 1 : -1;
|
||||
sumi_block += xi * y[i*4 + k].qs[j];
|
||||
}
|
||||
|
||||
sumi += d1 * sumi_block;
|
||||
}
|
||||
|
||||
sumf += d0 * sumi;
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
|
||||
|
||||
void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
||||
@@ -12,6 +12,7 @@ extern "C" {
|
||||
#endif
|
||||
|
||||
// Quantization
|
||||
void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
@@ -36,6 +37,7 @@ void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y,
|
||||
void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
// Dot product
|
||||
void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
@@ -68,6 +70,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
void quantize_row_q8_0_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
@@ -60,24 +60,24 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
if (nrows == 1) {
|
||||
DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream);
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
offset_iterator, offset_iterator + 1, stream);
|
||||
offset_iterator, offset_iterator + 1, stream));
|
||||
}
|
||||
} else {
|
||||
if (nrows == 1) {
|
||||
DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream);
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
||||
dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
|
||||
stream);
|
||||
stream));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,22 +86,22 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
if (nrows == 1) {
|
||||
DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream);
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
||||
ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
||||
ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream));
|
||||
}
|
||||
} else {
|
||||
if (nrows == 1) {
|
||||
DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream);
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
||||
offset_iterator + 1, stream);
|
||||
offset_iterator + 1, stream));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@
|
||||
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
|
||||
#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
|
||||
#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
|
||||
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing
|
||||
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x90a) // MI210 (gfx90a), minimum acc register renaming
|
||||
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300
|
||||
|
||||
// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32
|
||||
@@ -1157,19 +1157,6 @@ struct ggml_tensor_extra_gpu {
|
||||
#define USE_CUDA_GRAPH
|
||||
#endif
|
||||
|
||||
struct ggml_cuda_graph_node_properties {
|
||||
void * node_data;
|
||||
ggml_op node_op;
|
||||
enum ggml_type node_type;
|
||||
int32_t flags;
|
||||
int64_t ne[GGML_MAX_DIMS];
|
||||
size_t nb[GGML_MAX_DIMS];
|
||||
void * src_data[GGML_MAX_SRC];
|
||||
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
|
||||
};
|
||||
|
||||
static_assert(std::is_trivial<ggml_cuda_graph_node_properties>::value, "ggml_cuda_graph_node_properties must be trivial");
|
||||
|
||||
struct ggml_cuda_graph {
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
~ggml_cuda_graph() {
|
||||
@@ -1186,13 +1173,11 @@ struct ggml_cuda_graph {
|
||||
std::vector<cudaGraphNode_t> nodes;
|
||||
bool disable_due_to_gpu_arch = false;
|
||||
bool warmup_complete = false;
|
||||
std::vector<ggml_cuda_graph_node_properties> props;
|
||||
|
||||
// these are extra tensors (inputs) that participate in the ggml graph but are not nodes
|
||||
// they properties also have to match in order to be able to safely reuse a CUDA graph
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18583
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/19165
|
||||
std::vector<ggml_cuda_graph_node_properties> extra;
|
||||
struct node_properties {
|
||||
ggml_tensor node;
|
||||
void * node_src_data_ptrs[GGML_MAX_SRC];
|
||||
};
|
||||
std::vector<node_properties> node_props;
|
||||
|
||||
bool is_enabled() const {
|
||||
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
|
||||
|
||||
@@ -676,9 +676,96 @@ static __global__ void flash_attn_mask_to_KV_max(
|
||||
|
||||
template<int D, int ncols1, int ncols2> // D == head size
|
||||
__launch_bounds__(D, 1)
|
||||
static __global__ void flash_attn_stream_k_fixup(
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
|
||||
const int ne11, const int ne12, const int nbatch_fa) {
|
||||
static __global__ void flash_attn_stream_k_fixup_uniform(
|
||||
float * __restrict__ dst,
|
||||
const float2 * __restrict__ dst_fixup,
|
||||
const int ne01, const int ne02,
|
||||
const int ne12, const int nblocks_stream_k,
|
||||
const int gqa_ratio,
|
||||
const int blocks_per_tile,
|
||||
const uint3 fd_iter_j_z_ne12,
|
||||
const uint3 fd_iter_j_z,
|
||||
const uint3 fd_iter_j) {
|
||||
constexpr int ncols = ncols1*ncols2;
|
||||
|
||||
const int tile_idx = blockIdx.x; // One block per output tile.
|
||||
const int j = blockIdx.y;
|
||||
const int c = blockIdx.z;
|
||||
const int jc = j*ncols2 + c;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
// nblocks_stream_k is a multiple of ntiles_dst (== gridDim.x), so each tile gets the same number of blocks.
|
||||
const int b_first = tile_idx * blocks_per_tile;
|
||||
const int b_last = b_first + blocks_per_tile - 1;
|
||||
|
||||
const float * dst_fixup_data = ((const float *) dst_fixup) + nblocks_stream_k*(2*2*ncols);
|
||||
|
||||
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
|
||||
const uint2 dm0 = fast_div_modulo(tile_idx, fd_iter_j_z_ne12);
|
||||
const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_j_z);
|
||||
const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_j);
|
||||
|
||||
const int sequence = dm0.x;
|
||||
const int z_KV = dm1.x;
|
||||
const int zt_gqa = dm2.x;
|
||||
const int jt = dm2.y;
|
||||
|
||||
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
|
||||
|
||||
if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
|
||||
|
||||
// Load the partial result that needs a fixup
|
||||
float dst_val = *dst;
|
||||
float max_val;
|
||||
float rowsum;
|
||||
{
|
||||
const float2 tmp = dst_fixup[b_last*ncols + jc];
|
||||
max_val = tmp.x;
|
||||
rowsum = tmp.y;
|
||||
}
|
||||
|
||||
// Combine with all previous blocks in this tile.
|
||||
for (int bidx = b_last - 1; bidx >= b_first; --bidx) {
|
||||
const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
|
||||
|
||||
const float2 tmp = dst_fixup[(nblocks_stream_k + bidx)*ncols + jc];
|
||||
|
||||
const float max_val_new = fmaxf(max_val, tmp.x);
|
||||
|
||||
const float diff_val = max_val - max_val_new;
|
||||
const float diff_add = tmp.x - max_val_new;
|
||||
|
||||
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
|
||||
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
|
||||
|
||||
dst_val = scale_val*dst_val + scale_add*dst_add;
|
||||
rowsum = scale_val*rowsum + scale_add*tmp.y;
|
||||
|
||||
max_val = max_val_new;
|
||||
}
|
||||
|
||||
// Write back final result:
|
||||
*dst = dst_val / rowsum;
|
||||
}
|
||||
|
||||
// General fixup kernel for the case where the number of blocks per tile is not uniform across tiles
|
||||
// (blocks_num.x not a multiple of ntiles_dst)
|
||||
template <int D, int ncols1, int ncols2> // D == head size
|
||||
__launch_bounds__(D, 1)
|
||||
static __global__ void flash_attn_stream_k_fixup_general(
|
||||
float * __restrict__ dst,
|
||||
const float2 * __restrict__ dst_fixup,
|
||||
const int ne01, const int ne02,
|
||||
const int gqa_ratio,
|
||||
const int total_work,
|
||||
const uint3 fd_iter_k_j_z_ne12,
|
||||
const uint3 fd_iter_k_j_z,
|
||||
const uint3 fd_iter_k_j,
|
||||
const uint3 fd_iter_k) {
|
||||
constexpr int ncols = ncols1*ncols2;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
@@ -689,27 +776,26 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
|
||||
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
|
||||
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
|
||||
|
||||
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
|
||||
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
|
||||
const int kbc0 = int64_t(bidx0 + 0)*total_work / gridDim.x;
|
||||
const int kbc0_stop = int64_t(bidx0 + 1)*total_work / gridDim.x;
|
||||
|
||||
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
||||
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
||||
const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
|
||||
const bool wrote_beginning_of_tile = fastmodulo(kbc0, fd_iter_k) == 0;
|
||||
const bool did_not_write_last = fastdiv(kbc0, fd_iter_k) == fastdiv(kbc0_stop, fd_iter_k) && fastmodulo(kbc0_stop, fd_iter_k) != 0;
|
||||
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
|
||||
return;
|
||||
}
|
||||
|
||||
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
|
||||
const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
|
||||
const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
|
||||
const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
|
||||
const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
|
||||
const uint2 dm0 = fast_div_modulo(kbc0, fd_iter_k_j_z_ne12);
|
||||
const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_k_j_z);
|
||||
const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_k_j);
|
||||
const uint2 dm3 = fast_div_modulo(dm2.y, fd_iter_k);
|
||||
|
||||
const int sequence = dm0.x;
|
||||
const int z_KV = dm1.x;
|
||||
const int zt_gqa = dm2.x;
|
||||
const int jt = dm3.x;
|
||||
|
||||
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
|
||||
|
||||
@@ -733,10 +819,11 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
|
||||
// Iterate over previous blocks and compute the combined results.
|
||||
// All CUDA blocks that get here must have a previous block that needs a fixup.
|
||||
const int tile_kbc0 = fastdiv(kbc0, fd_iter_k);
|
||||
int bidx = bidx0 - 1;
|
||||
int kbc_stop = kbc0;
|
||||
while(true) {
|
||||
const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
|
||||
const int kbc = int64_t(bidx)*total_work / gridDim.x;
|
||||
if (kbc == kbc_stop) { // Did not have any data.
|
||||
bidx--;
|
||||
kbc_stop = kbc;
|
||||
@@ -762,7 +849,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
max_val = max_val_new;
|
||||
|
||||
// If this block started in a previous tile we are done and don't need to combine additional partial results.
|
||||
if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
|
||||
if (fastmodulo(kbc, fd_iter_k) == 0 || fastdiv(kbc, fd_iter_k) < tile_kbc0) {
|
||||
break;
|
||||
}
|
||||
bidx--;
|
||||
@@ -976,14 +1063,28 @@ void launch_fattn(
|
||||
const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks;
|
||||
const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves);
|
||||
|
||||
const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst);
|
||||
|
||||
const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
|
||||
|
||||
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst;
|
||||
blocks_num.x = ntiles_dst;
|
||||
blocks_num.y = 1;
|
||||
blocks_num.z = 1;
|
||||
|
||||
if(use_stream_k) {
|
||||
const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst);
|
||||
// Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks (avoids fixup).
|
||||
// Only do this if the occupancy loss from rounding is acceptable.
|
||||
const int nblocks_stream_k_rounded = (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst;
|
||||
const int max_efficiency_loss_percent = 5;
|
||||
const int efficiency_loss_percent = nblocks_stream_k_rounded > 0
|
||||
? 100 * (nblocks_stream_k_raw - nblocks_stream_k_rounded) / nblocks_stream_k_raw
|
||||
: 100;
|
||||
const int nblocks_stream_k = efficiency_loss_percent <= max_efficiency_loss_percent
|
||||
? nblocks_stream_k_rounded
|
||||
: nblocks_stream_k_raw;
|
||||
|
||||
blocks_num.x = nblocks_stream_k;
|
||||
}
|
||||
|
||||
if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
||||
dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
|
||||
}
|
||||
@@ -1063,13 +1164,40 @@ void launch_fattn(
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
if (stream_k) {
|
||||
if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
||||
if ((int)blocks_num.x % ntiles_dst == 0 && (int)blocks_num.x > ntiles_dst) {
|
||||
// Optimized fixup: nblocks_stream_k is a multiple of ntiles_dst, launch one block per tile.
|
||||
const int nblocks_sk = (int)blocks_num.x;
|
||||
const int bpt = nblocks_sk / ntiles_dst;
|
||||
|
||||
const uint3 fd0 = init_fastdiv_values(ntiles_x * ntiles_z_gqa * K->ne[2]);
|
||||
const uint3 fd1 = init_fastdiv_values(ntiles_x * ntiles_z_gqa);
|
||||
const uint3 fd2 = init_fastdiv_values(ntiles_x);
|
||||
|
||||
const dim3 block_dim_combine(DV, 1, 1);
|
||||
const dim3 blocks_num_combine = {(unsigned)ntiles_dst, ncols1, ncols2};
|
||||
|
||||
flash_attn_stream_k_fixup_uniform<DV, ncols1, ncols2>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr,
|
||||
Q->ne[1], Q->ne[2], K->ne[2], nblocks_sk,
|
||||
gqa_ratio, bpt, fd0, fd1, fd2);
|
||||
} else if (ntiles_dst % blocks_num.x != 0) {
|
||||
// General fixup for the cases where nblocks_stream_k < ntiles_dst.
|
||||
const int total_work = ntiles_KV * ntiles_dst;
|
||||
|
||||
const uint3 fd_k_j_z_ne12 = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa * K->ne[2]);
|
||||
const uint3 fd_k_j_z = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa);
|
||||
const uint3 fd_k_j = init_fastdiv_values(ntiles_KV * ntiles_x);
|
||||
const uint3 fd_k = init_fastdiv_values(ntiles_KV);
|
||||
|
||||
const dim3 block_dim_combine(DV, 1, 1);
|
||||
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
|
||||
|
||||
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
|
||||
flash_attn_stream_k_fixup_general<DV, ncols1, ncols2>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
|
||||
((float *) KQV->data, dst_tmp_meta.ptr,
|
||||
Q->ne[1], Q->ne[2], gqa_ratio, total_work,
|
||||
fd_k_j_z_ne12, fd_k_j_z, fd_k_j, fd_k);
|
||||
}
|
||||
} else if (parallel_blocks > 1) {
|
||||
const dim3 block_dim_combine(DV, 1, 1);
|
||||
|
||||
@@ -82,7 +82,6 @@
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_set>
|
||||
|
||||
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||
|
||||
@@ -2969,74 +2968,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
return use_cuda_graph;
|
||||
}
|
||||
|
||||
static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
|
||||
memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
|
||||
props->node_data = node->data;
|
||||
props->node_op = node->op;
|
||||
props->node_type = node->type;
|
||||
props->flags = node->flags;
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
props->ne[i] = node->ne[i];
|
||||
props->nb[i] = node->nb[i];
|
||||
}
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (!node->src[i]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
props->src_data[i] = node->src[i]->data;
|
||||
}
|
||||
memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
|
||||
}
|
||||
|
||||
static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
|
||||
if (node->data != props->node_data && node->op != GGML_OP_VIEW) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node->op != props->node_op) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node->type != props->node_type) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
if (node->ne[i] != props->ne[i]) {
|
||||
return false;
|
||||
}
|
||||
if (node->nb[i] != props->nb[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (node->op != GGML_OP_VIEW) {
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (!node->src[i]) {
|
||||
if (props->src_data[i] != nullptr) {
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node->src[i]->data != props->src_data[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
|
||||
return cgraph->nodes[0];
|
||||
}
|
||||
@@ -3048,52 +2979,25 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
|
||||
// Check if the graph size has changed
|
||||
if (graph->props.size() != (size_t)cgraph->n_nodes) {
|
||||
if ((int)graph->node_props.size() != cgraph->n_nodes) {
|
||||
res = true;
|
||||
graph->props.resize(cgraph->n_nodes);
|
||||
graph->node_props.resize(cgraph->n_nodes);
|
||||
}
|
||||
|
||||
// Loop over nodes in GGML graph to determine if CUDA graph update is required
|
||||
// and store properties to allow this comparison for the next token
|
||||
std::unordered_set<ggml_tensor *> seen_node;
|
||||
std::vector<ggml_tensor *> srcs_extra;
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
bool props_match = true;
|
||||
ggml_cuda_graph::node_properties prop = {};
|
||||
memcpy(&prop.node, cgraph->nodes[i], sizeof(ggml_tensor));
|
||||
|
||||
seen_node.insert(cgraph->nodes[i]);
|
||||
|
||||
if (!res) {
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
|
||||
// if the backend scheduler is making copies of CPU tensors, the src pointers can be the same but with different data, see:
|
||||
// https://github.com/ggml-org/llama.cpp/pull/21472#discussion_r3052235188
|
||||
for (int j = 0; j < GGML_MAX_SRC; ++j) {
|
||||
prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j] ? cgraph->nodes[i]->src[j]->data : nullptr;
|
||||
}
|
||||
if (!props_match) {
|
||||
|
||||
if (!res && memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) {
|
||||
res = true;
|
||||
}
|
||||
ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
|
||||
|
||||
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
|
||||
ggml_tensor * src = cgraph->nodes[i]->src[src_idx];
|
||||
if (src && seen_node.find(src) == seen_node.end()) {
|
||||
srcs_extra.push_back(src);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (graph->extra.size() != (size_t) srcs_extra.size()) {
|
||||
res = true;
|
||||
graph->extra.resize(srcs_extra.size());
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < srcs_extra.size(); ++i) {
|
||||
bool props_match = true;
|
||||
|
||||
if (!res) {
|
||||
props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]);
|
||||
}
|
||||
|
||||
if (!props_match) {
|
||||
res = true;
|
||||
}
|
||||
ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]);
|
||||
graph->node_props[i] = prop;
|
||||
}
|
||||
|
||||
return res;
|
||||
@@ -3308,6 +3212,71 @@ static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int nod
|
||||
return true;
|
||||
}
|
||||
|
||||
// returns whether the write (out) nodes overwrite the read nodes in operation
|
||||
static bool ggml_cuda_check_fusion_memory_ranges(const ggml_cgraph * cgraph,
|
||||
const int node_idx,
|
||||
const int node_count,
|
||||
const int * out_nodes,
|
||||
const int out_count,
|
||||
const bool is_topk_moe = false) {
|
||||
auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) {
|
||||
const int64_t a_start = (int64_t) a->data;
|
||||
const int64_t a_end = a_start + ggml_backend_buft_get_alloc_size(a->buffer->buft, a);
|
||||
|
||||
const int64_t b_start = (int64_t) b->data;
|
||||
const int64_t b_end = b_start + ggml_backend_buft_get_alloc_size(b->buffer->buft, b);
|
||||
|
||||
if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
bool is_ok = true;
|
||||
// exception for topk-moe, as each row is read entirely before writing
|
||||
if (ggml_nrows(cgraph->nodes[node_idx]) == 1 && is_topk_moe) {
|
||||
return true;
|
||||
}
|
||||
|
||||
for (int i = 0; i < out_count; ++i) {
|
||||
const ggml_tensor * dst = cgraph->nodes[out_nodes[i]];
|
||||
|
||||
for (int j = node_idx; j < node_idx + node_count; ++j) {
|
||||
// Loop over all srcs of all nodes in the fusion. If the src overlaps
|
||||
// the destination and the src is not an intermediate node that's being
|
||||
// elided, then disable fusion.
|
||||
|
||||
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
|
||||
const ggml_tensor * src = cgraph->nodes[j]->src[src_idx];
|
||||
|
||||
if (!src || src->op == GGML_OP_NONE) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (nodes_overlap(dst, src)) {
|
||||
bool found = false;
|
||||
|
||||
for (int k = node_idx; k < j; ++k) {
|
||||
if (cgraph->nodes[k] == src) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!found) {
|
||||
is_ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return is_ok;
|
||||
}
|
||||
|
||||
|
||||
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
|
||||
int node_idx,
|
||||
std::initializer_list<enum ggml_op> ops,
|
||||
@@ -3337,7 +3306,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
|
||||
const ggml_tensor * glu = cgraph->nodes[node_idx + 4];
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
|
||||
return true;
|
||||
int out_nodes[] = { node_idx + 4 };
|
||||
return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3348,7 +3318,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
|
||||
const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
|
||||
return true;
|
||||
int out_nodes[] = { node_idx + 2 };
|
||||
return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3474,69 +3445,6 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
|
||||
return false;
|
||||
}
|
||||
|
||||
// returns whether the write (out) nodes overwrite the read nodes in operation
|
||||
static bool ggml_cuda_check_fusion_memory_ranges(ggml_cgraph * cgraph,
|
||||
int node_idx,
|
||||
int node_count,
|
||||
int * out_nodes,
|
||||
int out_count) {
|
||||
auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) {
|
||||
const int64_t a_start = (int64_t) a->data;
|
||||
const int64_t a_end = a_start + ggml_nbytes(a);
|
||||
|
||||
const int64_t b_start = (int64_t) b->data;
|
||||
const int64_t b_end = b_start + ggml_nbytes(b);
|
||||
|
||||
if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
bool is_ok = true;
|
||||
// for nrows=1, all fusion operations correctly read the src before writing dst or do it elementwise, so we should be ok
|
||||
if (ggml_nrows(cgraph->nodes[node_idx]) == 1) {
|
||||
return true;
|
||||
}
|
||||
|
||||
for (int i = 0; i < out_count; ++i) {
|
||||
const ggml_tensor * dst = cgraph->nodes[out_nodes[i]];
|
||||
|
||||
for (int j = node_idx; j < node_idx + node_count; ++j) {
|
||||
// Loop over all srcs of all nodes in the fusion. If the src overlaps
|
||||
// the destination and the src is not an intermediate node that's being
|
||||
// elided, then disable fusion.
|
||||
|
||||
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
|
||||
const ggml_tensor * src = cgraph->nodes[j]->src[src_idx];
|
||||
|
||||
if (!src || src->op == GGML_OP_NONE) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (nodes_overlap(dst, src)) {
|
||||
bool found = false;
|
||||
|
||||
for (int k = node_idx; k < j; ++k) {
|
||||
if (cgraph->nodes[k] == src) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!found) {
|
||||
is_ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return is_ok;
|
||||
}
|
||||
|
||||
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
|
||||
bool graph_evaluated_or_captured = false;
|
||||
|
||||
@@ -3734,7 +3642,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
||||
|
||||
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
|
||||
ggml_cuda_should_use_topk_moe(node, logits, weights, ids) &&
|
||||
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
|
||||
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) {
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
|
||||
i += ops.size() - 1;
|
||||
continue;
|
||||
@@ -3750,7 +3658,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
||||
int out_nodes[2] = { i + 1, i + 5 };
|
||||
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
|
||||
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) &&
|
||||
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
|
||||
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) {
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
|
||||
i += ops.size() - 1;
|
||||
continue;
|
||||
|
||||
@@ -386,17 +386,25 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
|
||||
|
||||
int u[2*VDR_Q4_0_Q8_1_MMQ];
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
|
||||
u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
|
||||
u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
|
||||
constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes();
|
||||
constexpr int mcpy_int = max_cpy / sizeof(int);
|
||||
static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ");
|
||||
|
||||
int tmp0[4], tmp1[4];
|
||||
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) {
|
||||
ggml_cuda_memcpy_1<max_cpy>(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] );
|
||||
ggml_cuda_memcpy_1<max_cpy>(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0 + l0 * mcpy_int]);
|
||||
}
|
||||
|
||||
u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3];
|
||||
u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3];
|
||||
|
||||
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
|
||||
(&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
|
||||
x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
||||
@@ -489,17 +497,25 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
|
||||
|
||||
int u[2*VDR_Q4_1_Q8_1_MMQ];
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
|
||||
u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
|
||||
u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
|
||||
constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes();
|
||||
constexpr int mcpy_int = max_cpy / sizeof(int);
|
||||
static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ");
|
||||
|
||||
int tmp0[4], tmp1[4];
|
||||
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) {
|
||||
ggml_cuda_memcpy_1<max_cpy>(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] );
|
||||
ggml_cuda_memcpy_1<max_cpy>(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_1 + l0 * mcpy_int]);
|
||||
}
|
||||
|
||||
u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3];
|
||||
u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3];
|
||||
|
||||
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
|
||||
(&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
|
||||
x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
||||
@@ -4170,3 +4186,4 @@ void ggml_cuda_op_mul_mat_q(
|
||||
const int64_t src1_padded_row_size, cudaStream_t stream);
|
||||
|
||||
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);
|
||||
|
||||
|
||||
@@ -25,14 +25,14 @@ static void top_k_cub(ggml_cuda_pool & pool,
|
||||
auto indexes_in = cuda::make_counting_iterator(0);
|
||||
|
||||
size_t temp_storage_bytes = 0;
|
||||
DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k,
|
||||
env);
|
||||
CUDA_CHECK(DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k,
|
||||
env));
|
||||
|
||||
ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
|
||||
void * d_temp_storage = temp_storage_alloc.get();
|
||||
|
||||
DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst,
|
||||
ncols, k, env);
|
||||
CUDA_CHECK(DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst,
|
||||
ncols, k, env));
|
||||
}
|
||||
|
||||
#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE
|
||||
|
||||
@@ -164,6 +164,12 @@ static void quicksort_values_indices_desc(float * values, int32_t * indices, int
|
||||
if (i < right) quicksort_values_indices_desc(values, indices, i, right);
|
||||
}
|
||||
|
||||
// LUT for ramp initialization of argsort output (first 32 members)
|
||||
int32_t argosrt_ramp_lut[32] __attribute__((aligned(VLEN))) = {
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
|
||||
};
|
||||
|
||||
static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_argsort_context * actx = (struct htp_argsort_context *)data;
|
||||
struct htp_ops_context * octx = actx->octx;
|
||||
@@ -205,8 +211,12 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
|
||||
// Padded to 128 bytes.
|
||||
|
||||
size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
|
||||
size_t num_vec_ind_values = hmx_ceil_div(ne00, VLEN/(sizeof(int32_t)));
|
||||
float * values_buf = (float *) spad;
|
||||
int32_t * indices_buf = (int32_t *) (spad + values_size);
|
||||
HVX_Vector * indices_buf_vec = (HVX_Vector *) (spad + values_size);
|
||||
const HVX_Vector ind_init_vec = *(HVX_Vector *)argosrt_ramp_lut;
|
||||
const HVX_Vector ind_diff_vec = Q6_V_vsplat_R(32);
|
||||
|
||||
for (uint32_t r = start_row; r < end_row; r++) {
|
||||
uint32_t src_offset = r * nb01;
|
||||
@@ -218,9 +228,11 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
|
||||
hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);
|
||||
hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00);
|
||||
|
||||
// Initialize indices
|
||||
for (uint32_t j = 0; j < ne00; j++) {
|
||||
indices_buf[j] = j;
|
||||
// Initialize indices - Start with values 0..31, add 32 for additional vec iterations
|
||||
HVX_Vector curr_ind_vec = ind_init_vec;
|
||||
for (uint32_t j_vec = 0; j_vec < num_vec_ind_values; j_vec++) {
|
||||
indices_buf_vec[j_vec] = curr_ind_vec;
|
||||
curr_ind_vec = Q6_Vw_vadd_VwVw(curr_ind_vec, ind_diff_vec);
|
||||
}
|
||||
|
||||
// Sort values and mirror swaps to indices
|
||||
|
||||
@@ -736,6 +736,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
|
||||
suffix = ne00 % 4 == 0 ? "_4" : "";
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_Q1_0:
|
||||
{
|
||||
nsg = N_SG_Q1_0;
|
||||
nr0 = N_R0_Q1_0;
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
nsg = N_SG_Q4_0;
|
||||
@@ -948,6 +953,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m
|
||||
smem = 32*sizeof(float)*nr0;
|
||||
suffix = ne00 % 4 == 0 ? "_4" : "";
|
||||
} break;
|
||||
case GGML_TYPE_Q1_0:
|
||||
{
|
||||
nsg = N_SG_Q1_0;
|
||||
nr0 = N_R0_Q1_0;
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
nsg = N_SG_Q4_0;
|
||||
|
||||
@@ -1184,6 +1184,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
@@ -1210,6 +1211,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
|
||||
@@ -8,6 +8,9 @@
|
||||
//
|
||||
// TODO: for optimal performance, become function of the device and work size
|
||||
|
||||
#define N_R0_Q1_0 8
|
||||
#define N_SG_Q1_0 2
|
||||
|
||||
#define N_R0_Q4_0 4
|
||||
#define N_SG_Q4_0 2
|
||||
|
||||
|
||||
@@ -2047,6 +2047,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
|
||||
op->src[0]->type == GGML_TYPE_F16 ||
|
||||
op->src[0]->type == GGML_TYPE_BF16 ||
|
||||
op->src[0]->type == GGML_TYPE_Q1_0 ||
|
||||
op->src[0]->type == GGML_TYPE_Q4_0 ||
|
||||
op->src[0]->type == GGML_TYPE_Q4_1 ||
|
||||
op->src[0]->type == GGML_TYPE_Q5_0 ||
|
||||
|
||||
@@ -118,6 +118,56 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint8_t * qs = xb->qs;
|
||||
const float d = xb->d;
|
||||
const float neg_d = -d;
|
||||
|
||||
const int byte_offset = il * 2; // il*16 bits = il*2 bytes
|
||||
const uint8_t b0 = qs[byte_offset];
|
||||
const uint8_t b1 = qs[byte_offset + 1];
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
|
||||
reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
|
||||
reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
|
||||
reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
|
||||
reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
|
||||
reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
|
||||
reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
|
||||
reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));
|
||||
|
||||
reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
|
||||
reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
|
||||
reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
|
||||
reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
|
||||
reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
|
||||
reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
|
||||
reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
|
||||
reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
|
||||
const float d = xb->d;
|
||||
const float neg_d = -d;
|
||||
const int base = il * 4;
|
||||
const uint8_t byte = xb->qs[base / 8];
|
||||
const int s = base % 8;
|
||||
|
||||
float4 reg_f;
|
||||
reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1));
|
||||
reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
|
||||
reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
|
||||
reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));
|
||||
|
||||
reg = (type4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
||||
@@ -152,6 +202,23 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q1_0(device const float * src, device block_q1_0 & dst) {
|
||||
float sum_abs = 0.0f;
|
||||
for (int j = 0; j < QK1_0; j++) {
|
||||
sum_abs += fabs(src[j]);
|
||||
}
|
||||
dst.d = sum_abs / QK1_0;
|
||||
|
||||
for (int j = 0; j < QK1_0 / 8; j++) {
|
||||
dst.qs[j] = 0;
|
||||
}
|
||||
for (int j = 0; j < QK1_0; j++) {
|
||||
if (src[j] >= 0.0f) {
|
||||
dst.qs[j / 8] |= (1 << (j % 8));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float amax = 0.0f; // absolute max
|
||||
@@ -3116,6 +3183,35 @@ kernel void kernel_group_norm_f32(
|
||||
}
|
||||
}
|
||||
|
||||
// Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy)
|
||||
inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
|
||||
device const uint8_t * qs = qb_curr->qs + il / 8;
|
||||
const uint8_t b0 = qs[0];
|
||||
const uint8_t b1 = qs[1];
|
||||
|
||||
float acc = 0.0f;
|
||||
|
||||
acc += select(0.0f, yl[ 0], bool(b0 & 0x01));
|
||||
acc += select(0.0f, yl[ 1], bool(b0 & 0x02));
|
||||
acc += select(0.0f, yl[ 2], bool(b0 & 0x04));
|
||||
acc += select(0.0f, yl[ 3], bool(b0 & 0x08));
|
||||
acc += select(0.0f, yl[ 4], bool(b0 & 0x10));
|
||||
acc += select(0.0f, yl[ 5], bool(b0 & 0x20));
|
||||
acc += select(0.0f, yl[ 6], bool(b0 & 0x40));
|
||||
acc += select(0.0f, yl[ 7], bool(b0 & 0x80));
|
||||
|
||||
acc += select(0.0f, yl[ 8], bool(b1 & 0x01));
|
||||
acc += select(0.0f, yl[ 9], bool(b1 & 0x02));
|
||||
acc += select(0.0f, yl[10], bool(b1 & 0x04));
|
||||
acc += select(0.0f, yl[11], bool(b1 & 0x08));
|
||||
acc += select(0.0f, yl[12], bool(b1 & 0x10));
|
||||
acc += select(0.0f, yl[13], bool(b1 & 0x20));
|
||||
acc += select(0.0f, yl[14], bool(b1 & 0x40));
|
||||
acc += select(0.0f, yl[15], bool(b1 & 0x80));
|
||||
|
||||
return qb_curr->d * (2.0f * acc - sumy);
|
||||
}
|
||||
|
||||
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
||||
// il indicates where the q4 quants begin (0 or QK4_0/4)
|
||||
// we assume that the yl's have been multiplied with the appropriate scale factor
|
||||
@@ -3337,6 +3433,85 @@ void mul_vec_q_n_f32_impl(
|
||||
}
|
||||
}
|
||||
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_q1_0_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
|
||||
const int nb = args.ne00/QK1_0;
|
||||
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
|
||||
const int first_row = (r0 * NSG + sgitg) * nr0;
|
||||
|
||||
const uint i12 = im%args.ne12;
|
||||
const uint i13 = im/args.ne12;
|
||||
|
||||
const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;
|
||||
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
device const block_q1_0 * ax[nr0];
|
||||
for (int row = 0; row < nr0; ++row) {
|
||||
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0);
|
||||
}
|
||||
|
||||
float yl[16];
|
||||
float sumf[nr0] = {0.f};
|
||||
|
||||
const short ix = (tiisg/8);
|
||||
const short il = (tiisg%8)*16;
|
||||
|
||||
device const float * yb = y + ix*QK1_0 + il;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
|
||||
float sumy = 0.f;
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
yl[i] = yb[i];
|
||||
sumy += yb[i];
|
||||
}
|
||||
|
||||
FOR_UNROLL (short row = 0; row < nr0; row++) {
|
||||
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il);
|
||||
}
|
||||
|
||||
yb += QK1_0 * (N_SIMDWIDTH/8);
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < nr0; ++row) {
|
||||
const float tot = simd_sum(sumf[row]);
|
||||
|
||||
if (tiisg == 0 && first_row + row < args.ne01) {
|
||||
dst_f32[first_row + row] = tot;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_q1_0_f32")]]
|
||||
kernel void kernel_mul_mv_q1_0_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mv_q4_0_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const char * src0,
|
||||
@@ -3729,6 +3904,11 @@ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4
|
||||
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
|
||||
#endif
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||
@@ -7133,6 +7313,7 @@ kernel void kernel_cpy_f32_q(
|
||||
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
|
||||
|
||||
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
|
||||
template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
|
||||
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
|
||||
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
|
||||
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
|
||||
@@ -7173,12 +7354,14 @@ kernel void kernel_cpy_q_f32(
|
||||
|
||||
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
|
||||
|
||||
template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>;
|
||||
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
|
||||
|
||||
template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>;
|
||||
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
|
||||
@@ -9776,6 +9959,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro
|
||||
|
||||
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
||||
|
||||
template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>;
|
||||
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
|
||||
@@ -9838,6 +10022,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
||||
@@ -9861,6 +10046,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
|
||||
|
||||
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
||||
@@ -9893,6 +10079,7 @@ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_m
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
||||
@@ -9916,6 +10103,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
||||
@@ -10070,6 +10258,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
|
||||
|
||||
@@ -589,6 +589,7 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) {
|
||||
ggml_backend_buffer_free(opt_ctx->buf_cpu);
|
||||
ggml_free(opt_ctx->ctx_static);
|
||||
ggml_free(opt_ctx->ctx_cpu);
|
||||
ggml_free(opt_ctx->ctx_copy);
|
||||
delete opt_ctx;
|
||||
}
|
||||
|
||||
|
||||
@@ -32,6 +32,41 @@ static inline int best_index_int8(int n, const int8_t * val, float x) {
|
||||
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
|
||||
}
|
||||
|
||||
// reference implementation for deterministic creation of model files
|
||||
void quantize_row_q1_0_ref(const float * GGML_RESTRICT x, block_q1_0 * GGML_RESTRICT y, int64_t k) {
|
||||
static const int qk = QK1_0;
|
||||
|
||||
assert(k % qk == 0);
|
||||
|
||||
const int nb = k / qk;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float sum_abs = 0.0f;
|
||||
for (int j = 0; j < qk; j++) {
|
||||
sum_abs += fabsf(x[i*qk + j]);
|
||||
}
|
||||
const float d = sum_abs / qk;
|
||||
|
||||
y[i].d = GGML_FP32_TO_FP16(d);
|
||||
|
||||
// Clear all bits first
|
||||
for (int j = 0; j < qk / 8; ++j) {
|
||||
y[i].qs[j] = 0;
|
||||
}
|
||||
|
||||
// Just store sign of each weight directly (no normalization)
|
||||
for (int j = 0; j < qk; ++j) {
|
||||
const int bit_index = j;
|
||||
const int byte_index = bit_index / 8;
|
||||
const int bit_offset = bit_index % 8;
|
||||
|
||||
if (x[i*qk + j] >= 0.0f) {
|
||||
y[i].qs[byte_index] |= (1 << bit_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reference implementation for deterministic creation of model files
|
||||
void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) {
|
||||
static const int qk = QK4_0;
|
||||
@@ -339,6 +374,26 @@ void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RE
|
||||
}
|
||||
}
|
||||
|
||||
void dequantize_row_q1_0(const block_q1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
static const int qk = QK1_0;
|
||||
|
||||
assert(k % qk == 0);
|
||||
|
||||
const int nb = k / qk;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const float d = GGML_FP16_TO_FP32(x[i].d);
|
||||
const float neg_d = -d;
|
||||
|
||||
for (int j = 0; j < qk; ++j) {
|
||||
const int byte_index = j / 8;
|
||||
const int bit_offset = j % 8;
|
||||
const uint8_t bit = (x[i].qs[byte_index] >> bit_offset) & 1;
|
||||
y[i*qk + j] = bit ? d : neg_d;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
static const int qk = QK4_0;
|
||||
|
||||
@@ -1978,6 +2033,22 @@ static void quantize_row_q4_0_impl(const float * GGML_RESTRICT x, block_q4_0 * G
|
||||
}
|
||||
}
|
||||
|
||||
size_t quantize_q1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
||||
if (!quant_weights) {
|
||||
quantize_row_q1_0_ref(src, dst, (int64_t)nrow*n_per_row);
|
||||
return nrow * ggml_row_size(GGML_TYPE_Q1_0, n_per_row);
|
||||
}
|
||||
size_t row_size = ggml_row_size(GGML_TYPE_Q1_0, n_per_row);
|
||||
char * qrow = (char *)dst;
|
||||
for (int64_t row = 0; row < nrow; ++row) {
|
||||
quantize_row_q1_0_ref(src, (block_q1_0*)qrow, n_per_row);
|
||||
src += n_per_row;
|
||||
qrow += row_size;
|
||||
}
|
||||
return nrow * row_size;
|
||||
}
|
||||
|
||||
|
||||
size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
||||
if (!quant_weights) {
|
||||
quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row);
|
||||
@@ -5286,6 +5357,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_Q1_0:
|
||||
{
|
||||
VALIDATE_ROW_DATA_D_F16_IMPL(block_q1_0, data, nb);
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb);
|
||||
|
||||
@@ -14,6 +14,7 @@ extern "C" {
|
||||
// NOTE: these functions are defined as GGML_API because they used by the CPU backend
|
||||
|
||||
// Quantization
|
||||
GGML_API void quantize_row_q1_0_ref(const float * GGML_RESTRICT x, block_q1_0 * GGML_RESTRICT y, int64_t k);
|
||||
GGML_API void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k);
|
||||
GGML_API void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k);
|
||||
GGML_API void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k);
|
||||
@@ -41,6 +42,7 @@ GGML_API void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_
|
||||
GGML_API void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
// Dequantization
|
||||
GGML_API void dequantize_row_q1_0(const block_q1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
GGML_API void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
GGML_API void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
GGML_API void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
@@ -90,6 +92,7 @@ GGML_API size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTR
|
||||
GGML_API size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
GGML_API size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
GGML_API size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
GGML_API size_t quantize_q1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
GGML_API size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
GGML_API size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
|
||||
@@ -1009,8 +1009,8 @@ public:
|
||||
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
|
||||
|
||||
struct stored_graph {
|
||||
ggml_context_ptr ctx_ptr;
|
||||
ggml_cgraph * graph;
|
||||
std::vector<uint8_t> buffer;
|
||||
ggml_cgraph * graph;
|
||||
};
|
||||
|
||||
private:
|
||||
@@ -1518,10 +1518,12 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
|
||||
LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
|
||||
|
||||
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
|
||||
|
||||
if (stored_graphs[device].buffer.size() < buf_size) {
|
||||
stored_graphs[device].buffer.resize(buf_size);
|
||||
}
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_size,
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.mem_buffer =*/ stored_graphs[device].buffer.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
||||
@@ -1551,7 +1553,6 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
|
||||
}
|
||||
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
|
||||
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
|
||||
stored_graphs[device].ctx_ptr.swap(ctx_ptr);
|
||||
stored_graphs[device].graph = graph;
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -143,6 +143,22 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib,
|
||||
#endif // GGML_SYCL_F16
|
||||
}
|
||||
|
||||
static __dpct_inline__ void dequantize_q8_0_reorder(const void *d_ptr, const int64_t ib, const void *qs,
|
||||
const int iqs, dfloat2 &v) {
|
||||
const dfloat d = (const dfloat)*((const sycl::half*)d_ptr + ib);
|
||||
|
||||
v.x() = ((const int8_t *)qs)[iqs + 0];
|
||||
v.y() = ((const int8_t *)qs)[iqs + 1];
|
||||
|
||||
#ifdef GGML_SYCL_F16
|
||||
v.s0() *= d;
|
||||
v.s1() *= d;
|
||||
#else
|
||||
v.x() *= d;
|
||||
v.y() *= d;
|
||||
#endif // GGML_SYCL_F16
|
||||
}
|
||||
|
||||
static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib,
|
||||
const int iqs, dfloat2 &v) {
|
||||
const block_q8_0 * x = (const block_q8_0 *) vx;
|
||||
|
||||
@@ -972,6 +972,103 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
|
||||
}
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q8_0_sycl_reorder(const void *vx, const dfloat *y,
|
||||
float *dst, const int ncols,
|
||||
const int nrows,
|
||||
dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
// Q8_0 reorder layout: [all qs (ncols*nrows bytes)][all d values]
|
||||
// Cannot reuse dequantize_mul_mat_vec_reorder template because it has
|
||||
// Q4_0-specific constants hardcoded (d_ptr offset and qs stride).
|
||||
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
|
||||
item_ct1.get_local_id(1);
|
||||
if (row >= nrows) return;
|
||||
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int iter_stride = 8*2*GGML_SYCL_DMMV_X;
|
||||
const int vals_per_iter = iter_stride / WARP_SIZE;
|
||||
const int ncols_left = ncols % (QK8_0*WARP_SIZE);
|
||||
const int ncols_align = ncols - ncols_left;
|
||||
|
||||
#ifdef GGML_SYCL_F16
|
||||
sycl::half2 tmp = {0.0f, 0.0f};
|
||||
#else
|
||||
float tmp = 0.0f;
|
||||
#endif
|
||||
const char *d_ptr = (const char*)vx + ncols*nrows; // d after all qs
|
||||
|
||||
int i = 0;
|
||||
for (i = 0; i < ncols_align; i += iter_stride) {
|
||||
const int col = i + vals_per_iter*tid;
|
||||
const int ib = (row*ncols + col)/QK8_0;
|
||||
const int iqs = col % QK8_0;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < vals_per_iter; j += 2) {
|
||||
dfloat2 v;
|
||||
dequantize_q8_0_reorder((const void *)d_ptr, ib, (const void *)vx,
|
||||
ib * QK8_0 + iqs + j, v);
|
||||
|
||||
#ifdef GGML_SYCL_F16
|
||||
dfloat2 t1{y[col + j + 0], y[col + j + 1]};
|
||||
tmp += v * t1;
|
||||
#else
|
||||
tmp += v.x() * y[col + j + 0];
|
||||
tmp += v.y() * y[col + j + 1];
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// handle remaining columns
|
||||
for (; i < ncols; i += iter_stride) {
|
||||
if (tid >= ncols_left/QK8_0) continue;
|
||||
const int col = i + vals_per_iter*tid;
|
||||
const int ib = (row*ncols + col)/QK8_0;
|
||||
const int iqs = col % QK8_0;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < vals_per_iter; j += 2) {
|
||||
dfloat2 v;
|
||||
dequantize_q8_0_reorder((const void *)d_ptr, ib, (const void *)vx,
|
||||
ib * QK8_0 + iqs + j, v);
|
||||
|
||||
#ifdef GGML_SYCL_F16
|
||||
dfloat2 t1{y[col + j + 0], y[col + j + 1]};
|
||||
tmp += v * t1;
|
||||
#else
|
||||
tmp += v.x() * y[col + j + 0];
|
||||
tmp += v.y() * y[col + j + 1];
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// reduce
|
||||
const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
|
||||
for (int mask = mask_start; mask > 0; mask >>= 1) {
|
||||
tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
#ifdef GGML_SYCL_F16
|
||||
dst[row] = tmp.x() + tmp.y();
|
||||
#else
|
||||
dst[row] = tmp;
|
||||
#endif
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
|
||||
float *dst, const int ncols,
|
||||
const int nrows,
|
||||
@@ -1122,7 +1219,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
|
||||
dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
|
||||
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
||||
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
dequantize_mul_mat_vec_q8_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
|
||||
} else {
|
||||
dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
|
||||
}
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
|
||||
|
||||
@@ -44,6 +44,10 @@ void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_sycl_flash_attn_ext_tile_case<256, 256>(ctx, dst);
|
||||
} break;
|
||||
case 512: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_sycl_flash_attn_ext_tile_case<512, 512>(ctx, dst);
|
||||
} break;
|
||||
case 576: {
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
ggml_sycl_flash_attn_ext_tile_case<576, 512>(ctx, dst);
|
||||
|
||||
@@ -67,6 +67,12 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, co
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
@@ -124,6 +130,12 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, co
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 2, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
||||
@@ -131,134 +143,6 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, co
|
||||
return 0;
|
||||
}
|
||||
|
||||
static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) {
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 3, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 2, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 128, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 256, 2, 128, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) {
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 8, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 64, 8, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 5, 128, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 5, 128, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 8, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 8, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 8, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
|
||||
if(fast_fp16_available(cc))
|
||||
return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols);
|
||||
@@ -1252,6 +1136,16 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
constexpr int cols_per_block = ncols2*2;
|
||||
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2,
|
||||
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
@@ -1283,6 +1177,16 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_sycl_context & ctx, ggm
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
// ncols2=2 and ncols2=1 fallbacks only for cases where ncols=2 config exists (DKQ == DV).
|
||||
// For DKQ == 576, DV == 512 only GQA-optimized variants are implemented.
|
||||
if constexpr (DKQ == DV) {
|
||||
if (use_gqa_opt && gqa_ratio % 2 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
@@ -1337,5 +1241,6 @@ extern DECL_FATTN_TILE_CASE( 96, 96);
|
||||
extern DECL_FATTN_TILE_CASE(112, 112);
|
||||
extern DECL_FATTN_TILE_CASE(128, 128);
|
||||
extern DECL_FATTN_TILE_CASE(256, 256);
|
||||
extern DECL_FATTN_TILE_CASE(512, 512);
|
||||
extern DECL_FATTN_TILE_CASE(576, 512);
|
||||
|
||||
|
||||
@@ -664,4 +664,11 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
|
||||
|
||||
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_F16)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q4_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q4_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q5_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q5_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q8_0)
|
||||
|
||||
#endif // GGML_SYCL_FATTN_VEC_HPP
|
||||
|
||||
@@ -34,6 +34,7 @@
|
||||
FATTN_VEC_CASE( 64, type_K, type_V) \
|
||||
FATTN_VEC_CASE(128, type_K, type_V) \
|
||||
FATTN_VEC_CASE(256, type_K, type_V) \
|
||||
FATTN_VEC_CASE(512, type_K, type_V) \
|
||||
|
||||
static void ggml_sycl_flash_attn_ext_vec(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * Q = dst->src[0];
|
||||
@@ -141,6 +142,7 @@ static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const
|
||||
case 128:
|
||||
case 112:
|
||||
case 256:
|
||||
case 512:
|
||||
if (V->ne[0] != K->ne[0]) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
@@ -185,7 +187,7 @@ static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const
|
||||
}
|
||||
|
||||
// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 512 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
// Todo: Use the XMX kernel if possible:
|
||||
|
||||
|
||||
@@ -411,11 +411,22 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
|
||||
assert(tensor->view_src->buffer->buft == buffer->buft);
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
|
||||
!g_ggml_sycl_disable_optimize) {
|
||||
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
|
||||
tensor->extra = extra;
|
||||
ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
|
||||
|
||||
if (!g_ggml_sycl_disable_optimize) {
|
||||
// set reorder extra buffer based on supported type
|
||||
switch (tensor->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q6_K:{
|
||||
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
|
||||
tensor->extra = extra;
|
||||
ctx->tensor_extras.push_back(extra);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (ggml_is_quantized(tensor->type)) {
|
||||
@@ -3254,6 +3265,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
|
||||
inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
return true;
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
@@ -3266,6 +3278,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
|
||||
inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@@ -3275,6 +3288,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
|
||||
inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
return true;
|
||||
@@ -3364,6 +3378,40 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr
|
||||
sycl_ext_free(stream, tmp_buf);
|
||||
}
|
||||
|
||||
static void reorder_qw_q8_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
|
||||
dpct::queue_ptr stream) {
|
||||
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
|
||||
|
||||
sycl::event copy_event;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
|
||||
if (!g_ggml_sycl_use_async_mem_op) {
|
||||
copy_event.wait();
|
||||
}
|
||||
|
||||
GGML_ASSERT((size % sizeof(block_q8_0) == 0));
|
||||
GGML_ASSERT((offset % sizeof(block_q8_0) == 0));
|
||||
int offset_blks = offset / sizeof(block_q8_0);
|
||||
auto qs_ptr = data_device + offset_blks * QK8_0;
|
||||
auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows) + offset_blks;
|
||||
|
||||
auto reorder_event = stream->parallel_for(
|
||||
size / sizeof(block_q8_0),
|
||||
[=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
const block_q8_0* x = (const block_q8_0*)tmp_buf;
|
||||
const int ib = i;
|
||||
|
||||
for (int j = 0; j < QK8_0; j++)
|
||||
{
|
||||
*((int8_t*)qs_ptr + ib * QK8_0 + j) = x[ib].qs[j];
|
||||
}
|
||||
*(d_ptr + ib) = x[ib].d;
|
||||
});
|
||||
if (!g_ggml_sycl_use_async_mem_op) {
|
||||
reorder_event.wait_and_throw();
|
||||
}
|
||||
sycl_ext_free(stream, tmp_buf);
|
||||
}
|
||||
|
||||
static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(size % sizeof(block_q4_K) == 0);
|
||||
GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
|
||||
@@ -3460,6 +3508,9 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
reorder_qw_q4_k(data_device, size, 0, stream);
|
||||
break;
|
||||
|
||||
@@ -679,6 +679,25 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
}
|
||||
}
|
||||
|
||||
static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
||||
const int nrows, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ncols % QK8_0 == 0);
|
||||
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
||||
constexpr size_t num_subgroups = 16;
|
||||
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
||||
|
||||
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
|
||||
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
||||
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
||||
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0>>(vx, vy, dst, ncols, nrows,
|
||||
nd_item);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
float *dst, const int ncols,
|
||||
const int nrows,
|
||||
@@ -1101,7 +1120,13 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
|
||||
mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
||||
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n");
|
||||
reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
} else {
|
||||
mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
}
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
|
||||
@@ -105,6 +105,27 @@ template <> struct block_q_t<GGML_TYPE_Q6_K> {
|
||||
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||
};
|
||||
|
||||
template <> struct block_q_t<GGML_TYPE_Q8_0> {
|
||||
struct traits {
|
||||
static constexpr uint32_t qk = QK8_0; // 32
|
||||
static constexpr uint32_t qi = QI8_0; // 8
|
||||
static constexpr uint32_t qr = QR8_0; // 1
|
||||
static constexpr uint32_t vdr_mmvq = 4;
|
||||
};
|
||||
|
||||
// Q8_0 reorder layout: [qs0|qs1|...|qsN][d0|d1|...|dN]
|
||||
// Each block has 32 int8 weights (32 bytes) followed by all scales
|
||||
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
|
||||
return { block_index * QK8_0, 0 };
|
||||
}
|
||||
|
||||
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
|
||||
return { (ncols * nrows) + block_index * sizeof(ggml_half), 0 };
|
||||
}
|
||||
|
||||
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } // 1
|
||||
};
|
||||
|
||||
} // namespace ggml_sycl_reordered
|
||||
|
||||
#endif // GGML_SYCL_QUANTS_HPP
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.hpp"
|
||||
|
||||
DECL_FATTN_TILE_CASE(512, 512);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user