Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2026-04-09 16:17:31 +03:00.
Compare commits
76 Commits
| SHA1 |
|---|
| 15f786e658 |
| 94ca829b60 |
| 4aa962e2b0 |
| 941146b3f1 |
| 482d862bcb |
| 3979f2bb08 |
| 400ac8e194 |
| f51fd36d79 |
| 25eec6f327 |
| 58190cc84d |
| af76639f72 |
| 761797ffdf |
| 5d3a4a7da5 |
| c08d28d088 |
| 661e9acb36 |
| b8635075ff |
| 9c699074c9 |
| d01f6274c0 |
| 650bf14eb9 |
| b7ad48ebda |
| d006858316 |
| e439700992 |
| 50e0ad08fb |
| f1f793ad06 |
| af5c13841f |
| 277ff5fff7 |
| 384c0076bc |
| 1f34806c44 |
| 887535c33f |
| d3416a4aa9 |
| 43a4ee4a2c |
| f851fa5ab0 |
| f1ac84119c |
| b069b10ab4 |
| 0c58ba3365 |
| 57ace0d612 |
| 39b27f0da0 |
| f49e917876 |
| 7c7d6ce5c7 |
| 5208e2d5ba |
| 7992aa7c8e |
| a1cfb64530 |
| 5803c8d115 |
| 63f8fe0ef4 |
| 223373742b |
| e15efe007d |
| 6137c325a1 |
| 17193cce34 |
| d6dac92bfd |
| dae2bf41c9 |
| bc07d55922 |
| 4888137b17 |
| fbd441c379 |
| c30e012253 |
| 95a6ebabb2 |
| 12dbf1da95 |
| 86221cf6da |
| 6de97b9d3e |
| 5a0ed5150a |
| 8710e5f9b9 |
| 1d6d4cf7a5 |
| 744c0c7310 |
| 0356e33aaf |
| 6422036fcb |
| 296bc0538b |
| 6b949d1078 |
| 84f82e846c |
| e1cb817483 |
| 88d5f8ffc3 |
| d43375ff7f |
| 2b86e5cae6 |
| 88458164c7 |
| 4951250235 |
| 82764c341a |
| 825eb91a66 |
| 0fcb3760b2 |
@@ -1,97 +0,0 @@
-ARG UBUNTU_VERSION=24.04
-# This needs to generally match the container host's environment.
-ARG CUDA_VERSION=13.1.1
-# Target the CUDA build image
-ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
-
-ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
-
-FROM ${BASE_CUDA_DEV_CONTAINER} AS build
-
-# CUDA architecture to build for (defaults to all supported archs)
-ARG CUDA_DOCKER_ARCH=default
-
-RUN apt-get update && \
-    apt-get install -y gcc-14 g++-14 build-essential cmake python3 python3-pip git libssl-dev libgomp1
-
-ENV CC=gcc-14 CXX=g++-14 CUDAHOSTCXX=g++-14
-
-WORKDIR /app
-
-COPY . .
-
-RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \
-        export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \
-    fi && \
-    cmake -B build -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_BUILD_TESTS=OFF ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
-    cmake --build build --config Release -j$(nproc)
-
-RUN mkdir -p /app/lib && \
-    find build -name "*.so*" -exec cp -P {} /app/lib \;
-
-RUN mkdir -p /app/full \
-    && cp build/bin/* /app/full \
-    && cp *.py /app/full \
-    && cp -r gguf-py /app/full \
-    && cp -r requirements /app/full \
-    && cp requirements.txt /app/full \
-    && cp .devops/tools.sh /app/full/tools.sh
-
-## Base image
-FROM ${BASE_CUDA_RUN_CONTAINER} AS base
-
-RUN apt-get update \
-    && apt-get install -y libgomp1 curl \
-    && apt autoremove -y \
-    && apt clean -y \
-    && rm -rf /tmp/* /var/tmp/* \
-    && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \
-    && find /var/cache -type f -delete
-
-COPY --from=build /app/lib/ /app
-
-### Full
-FROM base AS full
-
-COPY --from=build /app/full /app
-
-WORKDIR /app
-
-RUN apt-get update \
-    && apt-get install -y \
-    git \
-    python3 \
-    python3-pip \
-    python3-wheel \
-    && pip install --break-system-packages --upgrade setuptools \
-    && pip install --break-system-packages -r requirements.txt \
-    && apt autoremove -y \
-    && apt clean -y \
-    && rm -rf /tmp/* /var/tmp/* \
-    && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \
-    && find /var/cache -type f -delete
-
-ENTRYPOINT ["/app/tools.sh"]
-
-### Light, CLI only
-FROM base AS light
-
-COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
-
-WORKDIR /app
-
-ENTRYPOINT [ "/app/llama-cli" ]
-
-### Server, Server only
-FROM base AS server
-
-ENV LLAMA_ARG_HOST=0.0.0.0
-
-COPY --from=build /app/full/llama-server /app
-
-WORKDIR /app
-
-HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ]
-
-ENTRYPOINT [ "/app/llama-server" ]
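As a rough sketch of how a multi-stage CUDA Dockerfile like the one deleted above is consumed (the surviving Dockerfile path, image tag, and model path below are assumptions, not taken from this diff):

```sh
# Build only the server stage, pinning one CUDA arch so the conditional
# CMAKE_ARGS branch above applies instead of the fat "default" build.
docker build -f .devops/cuda.Dockerfile \
    --build-arg CUDA_DOCKER_ARCH="86" \
    --target server -t llama-cpp:cuda-server .

# LLAMA_ARG_HOST=0.0.0.0 is baked into the server stage; publish the port
# that the HEALTHCHECK probes.
docker run --gpus all -p 8080:8080 \
    -v "$PWD/models:/models" llama-cpp:cuda-server -m /models/model.gguf
```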
@@ -16,7 +16,7 @@
  rocmPackages,
  vulkan-headers,
  vulkan-loader,
  curl,
  openssl,
  shaderc,
  useBlas ?
    builtins.all (x: !x) [

@@ -160,7 +160,8 @@ effectiveStdenv.mkDerivation (finalAttrs: {
    ++ optionals useMpi [ mpi ]
    ++ optionals useRocm rocmBuildInputs
    ++ optionals useBlas [ blas ]
-   ++ optionals useVulkan vulkanBuildInputs;
+   ++ optionals useVulkan vulkanBuildInputs
+   ++ [ openssl ];

  cmakeFlags =
    [
@@ -1,8 +1,8 @@
 ARG UBUNTU_VERSION=24.04

 # This needs to generally match the container host's environment.
-ARG ROCM_VERSION=7.2
-ARG AMDGPU_VERSION=7.2
+ARG ROCM_VERSION=7.2.1
+ARG AMDGPU_VERSION=7.2.1

 # Target the ROCm build image
 ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete

@@ -12,11 +12,11 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build

 # Unless otherwise specified, we make a fat build.
 # This is mostly tied to rocBLAS supported archs.
-# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-7.2.0/reference/system-requirements.html
+# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-7.2.1/reference/system-requirements.html
 # check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityrad/native_linux/native_linux_compatibility.html
 # check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityryz/native_linux/native_linux_compatibility.html

-ARG ROCM_DOCKER_ARCH='gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201'
+ARG ROCM_DOCKER_ARCH='gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1150;gfx1200;gfx1201'

 # Set ROCm architectures
 ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH}
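The fat arch list is only a default; it can be trimmed at build time. A minimal sketch, assuming the file lives at a `.devops` path like the others in this diff (the exact filename is not shown here):

```sh
# Restrict the ROCm image to a single GPU family instead of the fat build;
# ROCM_DOCKER_ARCH feeds ENV AMDGPU_TARGETS above.
docker build -f .devops/rocm.Dockerfile \
    --build-arg ROCM_VERSION=7.2.1 \
    --build-arg AMDGPU_VERSION=7.2.1 \
    --build-arg ROCM_DOCKER_ARCH='gfx1100' \
    -t llama-cpp:rocm-gfx1100 .
```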
.github/labeler.yml (vendored, 5 changed lines)
@@ -27,6 +27,11 @@ IBM zDNN:
     - any-glob-to-any-file:
       - ggml/include/ggml-zdnn.h
       - ggml/src/ggml-zdnn/**
+AMD ZenDNN:
+  - changed-files:
+    - any-glob-to-any-file:
+      - ggml/include/ggml-zendnn.h
+      - ggml/src/ggml-zendnn/**
 documentation:
   - changed-files:
     - any-glob-to-any-file:
.github/workflows/build-riscv.yml (vendored, 38 changed lines)
@@ -35,7 +35,7 @@ env:

 jobs:
   ubuntu-riscv64-native-sanitizer:
-    runs-on: RISCV64
+    runs-on: ubuntu-24.04-riscv

     continue-on-error: true

@@ -50,17 +50,18 @@
           sudo apt-get update

           # Install necessary packages
-          sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache git-lfs
+          sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 cmake build-essential wget git-lfs

           # Set gcc-14 and g++-14 as the default compilers
-          sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100
-          sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100
+          sudo ln -sf /usr/bin/gcc-14 /usr/bin/gcc
+          sudo ln -sf /usr/bin/g++-14 /usr/bin/g++

-          # Install Rust stable version
-          rustup install stable
-          rustup default stable
+          if ! which rustc; then
+              # Install Rust stable version
+              sudo apt-get install -y rustup
+              rustup install stable
+              rustup default stable
+          fi

           git lfs install

@@ -73,23 +74,12 @@
         id: checkout
         uses: actions/checkout@v6

-      - name: Setup ccache
-        run: |
-          # Unique cache directory per matrix combination
-          export CCACHE_DIR="$HOME/.ccache/sanitizer-${{ matrix.sanitizer }}-${{ matrix.build_type }}"
-          mkdir -p "$CCACHE_DIR"
-
-          # Configure ccache
-          ccache --set-config=max_size=5G
-          ccache --set-config=compression=true
-          ccache --set-config=compression_level=6
-          ccache --set-config=cache_dir="$CCACHE_DIR"
-          ccache --set-config=sloppiness=file_macro,time_macros,include_file_mtime,include_file_ctime
-          ccache --set-config=hash_dir=false
-
-          # Export for subsequent steps
-          echo "CCACHE_DIR=$CCACHE_DIR" >> $GITHUB_ENV
-          echo "PATH=/usr/lib/ccache:$PATH" >> $GITHUB_ENV
+      # FIXME: Enable when ggml-org/ccache-action works on riscv64
+      # - name: ccache
+      #   uses: ggml-org/ccache-action@v1.2.21
+      #   with:
+      #     key: ubuntu-riscv64-native-sanitizer-${{ matrix.sanitizer }}-${{ matrix.build_type }}
+      #     save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}

       - name: Build
         id: cmake_build
.github/workflows/build-self-hosted.yml (vendored, 21 changed lines)
@@ -213,6 +213,27 @@
           vulkaninfo --summary
           GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp

+  ggml-ci-win-intel-vulkan:
+    runs-on: [self-hosted, Windows, X64, Intel]
+
+    steps:
+      - name: Clone
+        id: checkout
+        uses: actions/checkout@v6
+
+      - name: Test
+        id: ggml-ci
+        shell: C:\msys64\usr\bin\bash.exe --noprofile --norc -eo pipefail "{0}"
+        env:
+          MSYSTEM: UCRT64
+          CHERE_INVOKING: 1
+          PATH: C:\msys64\ucrt64\bin;C:\msys64\usr\bin;C:\Windows\System32;${{ env.PATH }}
+        run: |
+          vulkaninfo --summary
+          # Skip python related tests with GG_BUILD_LOW_PERF=1 since Windows MSYS2 UCRT64 currently fails to create
+          # a valid python environment for testing
+          LLAMA_FATAL_WARNINGS=OFF GG_BUILD_NINJA=1 GG_BUILD_VULKAN=1 GG_BUILD_LOW_PERF=1 ./ci/run.sh ./results/llama.cpp ./mnt/llama.cpp
+
   ggml-ci-intel-openvino-gpu-low-perf:
     runs-on: [self-hosted, Linux, Intel, OpenVINO]
.github/workflows/build-vulkan.yml (vendored, 2 changed lines)
@@ -72,7 +72,7 @@

       - name: Setup Vulkan SDK
         if: steps.cache-sdk.outputs.cache-hit != 'true'
-        uses: ./.github/actions/linux-setup-vulkan-llvmpipe
+        uses: ./.github/actions/linux-setup-vulkan
         with:
           path: ./vulkan_sdk
           version: ${{ env.VULKAN_SDK_VERSION }}
.github/workflows/build.yml (vendored, 85 changed lines)
@@ -150,16 +150,15 @@
       - name: Dawn Dependency
         id: dawn-depends
         run: |
-          DAWN_VERSION="v2.0.0"
-          DAWN_OWNER="reeselevine"
+          DAWN_VERSION="v20260317.182325"
+          DAWN_OWNER="google"
           DAWN_REPO="dawn"
-          DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release"
-          echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
-          curl -L -o artifact.zip \
-            "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
+          DAWN_ASSET_NAME="Dawn-18eb229ef5f707c1464cc581252e7603c73a3ef0-macos-latest-Release"
+          echo "Fetching release asset from https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz"
+          curl -L -o artifact.tar.gz \
+            "https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz"
           mkdir dawn
-          unzip artifact.zip
-          tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1
+          tar -xvf artifact.tar.gz -C dawn --strip-components=1

       - name: Build
         id: cmake_build

@@ -384,16 +383,15 @@
         id: dawn-depends
         run: |
           sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev
-          DAWN_VERSION="v2.0.0"
-          DAWN_OWNER="reeselevine"
+          DAWN_VERSION="v20260317.182325"
+          DAWN_OWNER="google"
           DAWN_REPO="dawn"
-          DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release"
-          echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
-          curl -L -o artifact.zip \
-            "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
+          DAWN_ASSET_NAME="Dawn-18eb229ef5f707c1464cc581252e7603c73a3ef0-ubuntu-latest-Release"
+          echo "Fetching release asset from https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz"
+          curl -L -o artifact.tar.gz \
+            "https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz"
           mkdir dawn
-          unzip artifact.zip
-          tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1
+          tar -xvf artifact.tar.gz -C dawn --strip-components=1

       - name: Build
         id: cmake_build

@@ -427,7 +425,7 @@

       - name: Fetch emdawnwebgpu
         run: |
-          DAWN_TAG="v20251027.212519"
+          DAWN_TAG="v20260317.182325"
           EMDAWN_PKG="emdawnwebgpu_pkg-${DAWN_TAG}.zip"
           echo "Downloading ${EMDAWN_PKG}"
           curl -L -o emdawn.zip \

@@ -474,6 +472,7 @@
           cmake -B build -S . \
             -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
             -DGGML_HIP_ROCWMMA_FATTN=ON \
+            -DGPU_TARGETS="gfx1030" \
             -DGGML_HIP=ON
           cmake --build build --config Release -j $(nproc)
@@ -943,7 +942,7 @@
       - name: Grab rocWMMA package
         id: grab_rocwmma
         run: |
-          curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70200-43~24.04_amd64.deb"
+          curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2.1/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70201-81~24.04_amd64.deb"
           7z x rocwmma.deb
           7z x data.tar

@@ -986,17 +985,18 @@
           cmake -G "Unix Makefiles" -B build -S . `
             -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
             -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
-            -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.0/include/" `
+            -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.1/include/" `
             -DCMAKE_BUILD_TYPE=Release `
+            -DLLAMA_BUILD_BORINGSSL=ON `
             -DROCM_DIR="${env:HIP_PATH}" `
             -DGGML_HIP=ON `
             -DGGML_HIP_ROCWMMA_FATTN=ON `
             -DGPU_TARGETS="gfx1100" `
             -DGGML_RPC=ON
           cmake --build build -j ${env:NUMBER_OF_PROCESSORS}

   ubuntu-cpu-riscv64-native:
-    runs-on: RISCV64
+    runs-on: ubuntu-24.04-riscv

     steps:
       - name: Install dependencies

@@ -1004,24 +1004,21 @@
         sudo apt-get update

         # Install necessary packages
-        sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache git-lfs
+        sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 cmake build-essential libssl-dev wget git-lfs

         # Set gcc-14 and g++-14 as the default compilers
-        sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100
-        sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100
+        sudo ln -sf /usr/bin/gcc-14 /usr/bin/gcc
+        sudo ln -sf /usr/bin/g++-14 /usr/bin/g++

-        # Install Rust stable version
-        rustup install stable
-        rustup default stable
+        if ! which rustc; then
+            # Install Rust stable version
+            sudo apt-get install -y rustup
+            rustup install stable
+            rustup default stable
+        fi

         git lfs install

-      - name: Clone
-        id: checkout
-        uses: actions/checkout@v6
-
       - name: Check environment
         run: |
           uname -a
@@ -1031,25 +1028,17 @@
           cmake --version
           rustc --version

-      - name: Setup ccache
-        run: |
-          # Set unique cache directory for this job
-          export CCACHE_DIR="$HOME/.ccache/cpu-cmake-rv64-native"
-          mkdir -p "$CCACHE_DIR"
-
-          # Configure ccache for optimal performance
-          ccache --set-config=max_size=5G
-          ccache --set-config=compression=true
-          ccache --set-config=compression_level=6
-          ccache --set-config=cache_dir="$CCACHE_DIR"
-
-          # Enable more aggressive caching
-          ccache --set-config=sloppiness=file_macro,time_macros,include_file_mtime,include_file_ctime
-          ccache --set-config=hash_dir=false
-
-          # Export for subsequent steps
-          echo "CCACHE_DIR=$CCACHE_DIR" >> $GITHUB_ENV
-          echo "PATH=/usr/lib/ccache:$PATH" >> $GITHUB_ENV
+      - name: Clone
+        id: checkout
+        uses: actions/checkout@v6
+
+      # FIXME: Enable when ggml-org/ccache-action works on riscv64
+      # - name: ccache
+      #   uses: ggml-org/ccache-action@v1.2.21
+      #   with:
+      #     key: ubuntu-cpu-riscv64-native
+      #     evict-old-files: 1d
+      #     save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}

       - name: Build
         id: cmake_build
.github/workflows/docker.yml (vendored, 8 changed lines)
@@ -73,10 +73,10 @@
             { "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
             { "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-arm" },
             { "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x" },
-            { "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
-            { "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
-            { "tag": "cuda13", "dockerfile": ".devops/cuda-new.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
-            { "tag": "cuda13", "dockerfile": ".devops/cuda-new.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
+            { "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
+            { "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
+            { "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.1.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
+            { "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.1.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
             { "tag": "musa", "dockerfile": ".devops/musa.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
             { "tag": "intel", "dockerfile": ".devops/intel.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
             { "tag": "vulkan", "dockerfile": ".devops/vulkan.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
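With `cuda-new.Dockerfile` gone, both CUDA tags point at the same `.devops/cuda.Dockerfile` and differ only in the new `cuda_version` matrix field. How the workflow forwards that field is not visible in this diff, but the obvious wiring is a build-arg; a sketch under that assumption:

```sh
# One Dockerfile, two CUDA toolchains: pick the toolchain per tag.
for v in 12.8.1 13.1.1; do
    docker build -f .devops/cuda.Dockerfile \
        --build-arg CUDA_VERSION="$v" \
        -t "llama-cpp:cuda-${v%%.*}" .   # -> llama-cpp:cuda-12, llama-cpp:cuda-13
done
```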
.github/workflows/hip-quality-check.yml (vendored, 4 changed lines)
@@ -35,7 +35,7 @@ env:
 jobs:
   ubuntu-22-hip-quality-check:
     runs-on: ubuntu-22.04
-    container: rocm/dev-ubuntu-22.04:7.2
+    container: rocm/dev-ubuntu-22.04:7.2.1
     steps:
       - name: Clone
         id: checkout

@@ -59,7 +59,7 @@
         run: |
           cmake -B build -S . \
             -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-            -DGPU_TARGETS=gfx908 \
+            -DGPU_TARGETS=gfx942 \
             -DGGML_HIP=ON \
             -DGGML_HIP_EXPORT_METRICS=Off \
             -DCMAKE_HIP_FLAGS="-Werror -Wno-tautological-compare" \
.github/workflows/release.yml (vendored, 22 changed lines)
@@ -639,8 +639,8 @@
     strategy:
       matrix:
         include:
-          - ROCM_VERSION: "7.2"
-            gpu_targets: "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201"
+          - ROCM_VERSION: "7.2.1"
+            gpu_targets: "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1150;gfx1200;gfx1201"
            build: 'x64'

     steps:

@@ -662,7 +662,7 @@
         sudo apt install -y build-essential git cmake wget

       - name: Setup Legacy ROCm
-        if: matrix.ROCM_VERSION == '7.2'
+        if: matrix.ROCM_VERSION == '7.2.1'
         id: legacy_env
         run: |
           sudo mkdir --parents --mode=0755 /etc/apt/keyrings

@@ -683,7 +683,7 @@
           sudo apt-get install -y libssl-dev rocm-hip-sdk

       - name: Setup TheRock
-        if: matrix.ROCM_VERSION != '7.2'
+        if: matrix.ROCM_VERSION != '7.2.1'
         id: therock_env
         run: |
           wget https://repo.amd.com/rocm/tarball/therock-dist-linux-gfx1151-${{ matrix.ROCM_VERSION }}.tar.gz

@@ -699,7 +699,6 @@
         run: |
           cmake -B build -S . \
             -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-            -DCMAKE_HIP_FLAGS="-mllvm --amdgpu-unroll-threshold-local=600" \
             -DCMAKE_BUILD_TYPE=Release \
             -DGGML_BACKEND_DL=ON \
             -DGGML_NATIVE=OFF \

@@ -717,17 +716,20 @@
         id: tag
         uses: ./.github/actions/get-tag-name

+      - name: Get ROCm short version
+        run: echo "ROCM_VERSION_SHORT=$(echo '${{ matrix.ROCM_VERSION }}' | cut -d '.' -f 1,2)" >> $GITHUB_ENV
+
       - name: Pack artifacts
         id: pack_artifacts
         run: |
           cp LICENSE ./build/bin/
-          tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
+          tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ env.ROCM_VERSION_SHORT }}-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .

       - name: Upload artifacts
         uses: actions/upload-artifact@v6
         with:
-          path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz
-          name: llama-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz
+          path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ env.ROCM_VERSION_SHORT }}-${{ matrix.build }}.tar.gz
+          name: llama-bin-ubuntu-rocm-${{ env.ROCM_VERSION_SHORT }}-${{ matrix.build }}.tar.gz

   windows-hip:
     runs-on: windows-2022

@@ -749,7 +751,7 @@
       - name: Grab rocWMMA package
         id: grab_rocwmma
         run: |
-          curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70200-43~24.04_amd64.deb"
+          curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2.1/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70201-81~24.04_amd64.deb"
           7z x rocwmma.deb
           7z x data.tar

@@ -806,7 +808,7 @@
           cmake -G "Unix Makefiles" -B build -S . `
             -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
             -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
-            -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.0/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
+            -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.1/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
             -DCMAKE_BUILD_TYPE=Release `
             -DGGML_BACKEND_DL=ON `
             -DGGML_NATIVE=OFF `
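The new `Get ROCm short version` step above decouples artifact names from the patch release, so bumping the matrix to a new patch version no longer renames the published tarballs. Run by hand:

```sh
# ROCM_VERSION_SHORT keeps only major.minor of the matrix version.
echo '7.2.1' | cut -d '.' -f 1,2   # prints: 7.2
echo '7.2'   | cut -d '.' -f 1,2   # prints: 7.2 (already short)
```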
AGENTS.md (120 changed lines)
@@ -5,78 +5,106 @@
 >
 > Read more: [CONTRIBUTING.md](CONTRIBUTING.md)

-AI assistance is permissible only when the majority of the code is authored by a human contributor, with AI employed exclusively for corrections or to expand on verbose modifications that the contributor has already conceptualized (see examples below)
+AI assistance is permissible only when the majority of the code is authored by a human contributor, with AI employed exclusively for corrections or to expand on verbose modifications that the contributor has already conceptualized (see examples below).

 ---

-## Guidelines for Contributors Using AI
-
-These use cases are **permitted** when making a contribution with the help of AI:
-
-- Using it to ask about the structure of the codebase
-- Learning about specific techniques used in the project
-- Pointing out documents, links, and parts of the code that are worth your time
-- Reviewing human-written code and providing suggestions for improvements
-- Expanding on verbose modifications that the contributor has already conceptualized. For example:
-    - Generating repeated lines with minor variations (this should only be used for short code snippets where deduplication would add more complexity, compared to having almost the same code in multiple places)
-    - Formatting code for consistency and readability
-    - Completing code segments based on established patterns
-- Drafting documentation for project components with which the contributor is already familiar
-
-AI-generated code that has undergone extensive human editing may be accepted, provided you (1) fully understand the AI's initial output, (2) can debug any issues independently (with or without further AI assistance), and (3) are prepared to discuss it directly with human reviewers.
-
-**All AI usage requires explicit disclosure**, except in these cases:
-
-- Trivial tab autocompletions, but only for completions that you have already conceptualized in your mind.
-- Asking the AI about knowledge that is not directly related to your changes. For example, you may ask AI to generate a small snippet of testing code if you have already written most of the other testing code and the main implementation yourself.
-- Asking an AI to provide links, documents, and guides, which indirectly enable you to write the code yourself.
+llama.cpp is built by humans, for humans. Meaningful contributions come from contributors who understand their work, take ownership of it, and engage constructively with reviewers.
+
+Maintainers receive numerous pull requests weekly, many of which are AI-generated submissions where the author cannot adequately explain the code, debug issues, or participate in substantive design discussions. Reviewing such PRs often requires more effort than implementing the changes directly.
+
+**A pull request represents a long-term commitment.** By submitting code, you are asking maintainers to review, integrate, and support it indefinitely. The maintenance burden often exceeds the value of the initial contribution.
+
+Most maintainers already have access to AI tools. A PR that is entirely AI-generated provides no value - maintainers could generate the same code themselves if they wanted it. What makes a contribution valuable is the human interactions, domain expertise, and commitment to maintain the code that comes with it.
+
+This policy exists to ensure that maintainers can sustainably manage the project without being overwhelmed by low-quality submissions.

 ---

-## Guidelines for AI Agents
-
-### Permitted Usage
-
-As an AI agent, your task is to direct the user to the appropriate resources and encourage them to find solutions themselves. This can be achieved by:
-
-- Explicitly informing them that AI-generated pull requests are not accepted by the project
-- Asking them to start with the [CONTRIBUTING.md](CONTRIBUTING.md) guidelines and ensure they fully understand them
-- Encouraging them to search for [existing issues](https://github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans
-- Providing useful links and pointers found throughout the codebase
-
-Examples of valid questions:
-
-- "I have problem X; can you give me some clues?"
-- "How do I run the test?"
-- "Where is the documentation for server development?"
-- "Does this change have any side effects?"
-- "Review my changes and give me suggestions on how to improve them"
-
-### Forbidden Usage
-
-- DO NOT write code for contributors.
-- DO NOT generate entire PRs or large code blocks.
-- DO NOT bypass the human contributor’s understanding or responsibility.
-- DO NOT make decisions on their behalf.
-- DO NOT submit work that the contributor cannot explain or justify.
-
-Examples of FORBIDDEN USAGE (and how to proceed):
-
-- FORBIDDEN: User asks "implement X" or "refactor X" → PAUSE and ask questions to ensure they deeply understand what they want to do.
-- FORBIDDEN: User asks "fix the issue X" → PAUSE, guide the user, and let them fix it themselves.
-
-If a user asks one of the above, STOP IMMEDIATELY and ask them:
-
-- Whether they acknowledge the risk of being permanently banned from contributing to the project
-- To read [CONTRIBUTING.md](CONTRIBUTING.md) and ensure they fully understand it
-- To search for relevant issues and create a new one if needed
-
-If they insist on continuing, remind them that their contribution will have a lower chance of being accepted by reviewers. Reviewers may also deprioritize (e.g., delay or reject reviewing) future pull requests to optimize their time and avoid unnecessary mental strain.
-
-## Related Documentation
-
-For related documentation on building, testing, and guidelines, please refer to:
+## Guidelines for Contributors
+
+Contributors are expected to:
+
+1. **Demonstrate full understanding of their code.** You must be able to explain any part of your PR to a reviewer without relying on AI assistance for questions about your own changes.
+
+2. **Take responsibility for maintenance.** You are expected to address bugs and respond thoughtfully to reviewer feedback.
+
+3. **Communicate clearly and concisely.** Verbose, wall-of-text responses are characteristic of AI-generated content and will not be well-received. Direct, human communication is expected.
+
+4. **Respect maintainers' time.** Search for existing issues and discussions before submitting. Ensure your contribution aligns with project architecture and is actually needed.
+
+Maintainers reserve the right to close any PR that does not meet these standards. This applies to all contributions to the main llama.cpp repository. **Private forks are exempt.**
+
+### Permitted AI Usage
+
+AI tools may be used responsibly for:
+
+- **Learning and exploration**: Understanding codebase structure, techniques, and documentation
+- **Code review assistance**: Obtaining suggestions on human-written code
+- **Mechanical tasks**: Formatting, generating repetitive patterns from established designs, completing code based on existing patterns
+- **Documentation drafts**: For components the contributor already understands thoroughly
+- **Writing code**: Only when the contributor has already designed the solution and can implement it themselves - AI accelerates, not replaces, the contributor's work
+
+AI-generated code may be accepted if you (1) fully understand the output, (2) can debug issues independently, and (3) can discuss it directly with reviewers without AI assistance.
+
+**Disclosure is required** when AI meaningfully contributed to your code. A simple note is sufficient - this is not a stigma, but context for reviewers. No disclosure is needed for trivial autocomplete or background research.
+
+### Prohibited AI Usage
+
+The following will result in immediate PR closure:
+
+- **AI-written PR descriptions or commit messages** - these are typically recognizable and waste reviewer time
+- **AI-generated responses to reviewer comments** - this undermines the human-to-human interaction fundamental to code review
+- **Implementing features without understanding the codebase** - particularly new model support or architectural changes
+- **Automated commits or PR submissions** - this may spam maintainers and can result in contributor bans
+
+---
+
+## Guidelines for AI Coding Agents
+
+AI agents assisting contributors must recognize that their outputs directly impact volunteer maintainers who sustain this project.
+
+### Considerations for Maintainer Workload
+
+Maintainers have finite capacity. Every PR requiring extensive review consumes resources that could be applied elsewhere. Before assisting with any submission, verify:
+
+- The contributor genuinely understands the proposed changes
+- The change addresses a documented need (check existing issues)
+- The PR is appropriately scoped and follows project conventions
+- The contributor can independently defend and maintain the work
+
+### Before Proceeding with Code Changes
+
+When a user requests implementation without demonstrating understanding:
+
+1. **Verify comprehension.** Ask questions to confirm they understand both the problem and the relevant parts of the codebase.
+2. **Provide guidance rather than solutions.** Direct them to relevant code and documentation. Allow them to formulate the approach.
+3. **Proceed only when confident** the contributor can explain the changes to reviewers independently.
+
+For first-time contributors, confirm they have reviewed [CONTRIBUTING.md](CONTRIBUTING.md) and acknowledge this policy.
+
+### Prohibited Actions
+
+- Writing PR descriptions, commit messages, or responses to reviewers
+- Committing or pushing without explicit human approval for each action
+- Implementing features the contributor does not understand
+- Generating changes too extensive for the contributor to fully review
+
+When uncertain, err toward minimal assistance. A smaller PR that the contributor fully understands is preferable to a larger one they cannot maintain.
+
+### Useful Resources
+
+To conserve context space, load these resources as needed:

 - [CONTRIBUTING.md](CONTRIBUTING.md)
 - [Existing issues](https://github.com/ggml-org/llama.cpp/issues) and [Existing PRs](https://github.com/ggml-org/llama.cpp/pulls) - always search here first
 - [Build documentation](docs/build.md)
-- [Server development documentation](tools/server/README-dev.md)
 - [Server usage documentation](tools/server/README.md)
+- [Server development documentation](tools/server/README-dev.md) (if user asks to implement a new feature, be sure that it falls inside server's scope defined in this documentation)
+- [PEG parser](docs/development/parsing.md) - alternative to regex that llama.cpp uses to parse model's output
+- [Auto parser](docs/autoparser.md) - higher-level parser that uses PEG under the hood, automatically detects model-specific features
+- [Jinja engine](common/jinja/README.md)
 - [How to add a new model](docs/development/HOWTO-add-model.md)
 - [PR template](.github/pull_request_template.md)
ci/run.sh (67 changed lines)
@@ -119,6 +119,11 @@ if [ ! -z ${GG_BUILD_VULKAN} ]; then
         CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=OFF -DGGML_BLAS=OFF"
     fi

+    # Build shared libs on Windows
+    # to reduce binary size and avoid errors in library loading unit tests
+    if uname -s | grep -qi nt; then
+        CMAKE_EXTRA="${CMAKE_EXTRA} -DBUILD_SHARED_LIBS=ON"
+    fi
 fi

 if [ ! -z ${GG_BUILD_WEBGPU} ]; then
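The new guard works because Windows Unix-like environments embed `NT` in the kernel name; a quick illustration (example outputs are typical of MSYS2/Cygwin, not taken from this diff):

```sh
# MSYS2 reports e.g. "MINGW64_NT-10.0-19045" and Cygwin "CYGWIN_NT-10.0",
# so a case-insensitive match on "nt" identifies a Windows host.
if uname -s | grep -qi nt; then
    echo "Windows-like environment: enabling shared libs"
fi
```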
@@ -151,35 +156,7 @@
 if [ -n "${GG_BUILD_KLEIDIAI}" ]; then
     echo ">>===== Enabling KleidiAI support"

-    CANDIDATES=(
-        "armv9-a+dotprod+i8mm+sve2"
-        "armv9-a+dotprod+i8mm"
-        "armv8.6-a+dotprod+i8mm"
-        "armv8.2-a+dotprod"
-    )
-    CPU=""
-
-    for cpu in "${CANDIDATES[@]}"; do
-        if echo 'int main(){}' | ${CXX:-c++} -march="$cpu" -x c++ - -c -o /dev/null >/dev/null 2>&1; then
-            CPU="$cpu"
-            break
-        fi
-    done
-
-    if [ -z "$CPU" ]; then
-        echo "ERROR: None of the required ARM baselines (armv9/armv8.6/armv8.2 + dotprod) are supported by this compiler."
-        exit 1
-    fi
-
-    echo ">>===== Using ARM baseline: ${CPU}"
-
-    CMAKE_EXTRA="${CMAKE_EXTRA:+$CMAKE_EXTRA } \
-        -DGGML_NATIVE=OFF \
-        -DGGML_CPU_KLEIDIAI=ON \
-        -DGGML_CPU_AARCH64=ON \
-        -DGGML_CPU_ARM_ARCH=${CPU} \
-        -DBUILD_SHARED_LIBS=OFF"
+    CMAKE_EXTRA="${CMAKE_EXTRA:+$CMAKE_EXTRA } -DGGML_CPU_KLEIDIAI=ON"
 fi

 if [ ! -z ${GG_BUILD_BLAS} ]; then

@@ -249,7 +226,7 @@ function gg_run_ctest_debug {

     set -e

-    # Check cmake and ctest are installed
+    # Check required binaries are installed
     gg_check_build_requirements

     (cmake -G "${CMAKE_GENERATOR}" -DCMAKE_BUILD_TYPE=Debug ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log

@@ -280,7 +257,7 @@ function gg_run_ctest_release {

     set -e

-    # Check cmake and ctest are installed
+    # Check required binaries are installed
     gg_check_build_requirements

     (cmake -G "${CMAKE_GENERATOR}" -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log

@@ -655,10 +632,38 @@ function gg_sum_rerank_tiny {
 }

 function gg_check_build_requirements {
+    if ! command -v git &> /dev/null; then
+        gg_printf 'git not found, please install'
+    fi
+
+    if ! command -v git-lfs &> /dev/null; then
+        gg_printf 'git-lfs not found, please install'
+    fi
+
+    if ! command -v wget &> /dev/null; then
+        gg_printf 'wget not found, please install'
+    fi
+
+    if ! command -v python3 &> /dev/null; then
+        gg_printf 'python3 not found, please install'
+    fi
+
+    if ! command -v pip3 &> /dev/null; then
+        gg_printf 'pip3 not found, please install'
+    fi
+
+    if ! python3 -m ensurepip --help &> /dev/null; then
+        gg_printf 'ensurepip not found, please install python3-venv package'
+    fi
+
     if ! command -v cmake &> /dev/null; then
         gg_printf 'cmake not found, please install'
     fi

     if ! command -v ccache &> /dev/null; then
         gg_printf 'ccache not found, please consider installing for faster builds'
     fi

     if ! command -v ctest &> /dev/null; then
         gg_printf 'ctest not found, please install'
     fi
@@ -537,9 +537,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
     } catch (const std::exception & e) {
         LOG_WRN("HF cache migration failed: %s\n", e.what());
     }
+    // export_graph_ops loads only metadata
+    const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;

     // maybe handle remote preset
-    if (!params.model.hf_repo.empty()) {
+    if (!params.model.hf_repo.empty() && !skip_model_download) {
         std::string cli_hf_repo = params.model.hf_repo;
         bool has_preset = common_params_handle_remote_preset(params, ctx_arg.ex);

@@ -570,7 +572,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
     }

     // handle model and download
-    {
+    if (!skip_model_download) {
         auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
         if (params.no_mmproj) {
             params.mmproj = {};

@@ -591,7 +593,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context

     // model is required (except for server)
     // TODO @ngxson : maybe show a list of available models in CLI in this case
-    if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage && !params.completion) {
+    if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !skip_model_download && !params.usage && !params.completion) {
         throw std::invalid_argument("error: --model is required\n");
     }

@@ -1309,6 +1311,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.kv_unified = value;
         }
     ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
+    add_opt(common_arg(
+        {"--clear-idle"},
+        {"--no-clear-idle"},
+        "save and clear idle slots on new task (default: enabled, requires unified KV and cache-ram)",
+        [](common_params & params, bool value) {
+            params.clear_idle = value;
+        }
+    ).set_env("LLAMA_ARG_CLEAR_IDLE").set_examples({LLAMA_EXAMPLE_SERVER}));
    add_opt(common_arg(
        {"--context-shift"},
        {"--no-context-shift"},
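A usage sketch for the new server flag (the flag and env names come from the diff above; the help text says it requires unified KV and cache-ram, and the companion flag spelling and model path here are assumptions):

```sh
# Opt in to saving and clearing idle slots when a new task arrives:
llama-server -m model.gguf --kv-unified --clear-idle

# Or via the environment variable registered above:
LLAMA_ARG_CLEAR_IDLE=1 llama-server -m model.gguf
```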
@@ -6,6 +6,7 @@
 #include "json-schema-to-grammar.h"
 #include "log.h"
 #include "nlohmann/json.hpp"
+#include "peg-parser.h"

 #include <stdexcept>
 #include <string>

@@ -92,6 +93,7 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons
     ctx.extracting_reasoning = extract_reasoning && reasoning.mode != reasoning_mode::NONE;
     ctx.content = &content;
+    ctx.reasoning = &reasoning;

     // Build reasoning parser
     ctx.reasoning_parser = reasoning.build_parser(ctx);

@@ -100,6 +102,7 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons
     bool has_tools = inputs.tools.is_array() && !inputs.tools.empty();
     bool has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
+    bool pure_content = reasoning.mode == reasoning_mode::NONE;

     if (has_response_format) {
         auto response_format = p.rule("response-format", p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));

@@ -107,12 +110,14 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons
             p.literal("```json") + p.space() + response_format + p.space() + p.literal("```"),
             response_format
         }) + p.end();
+        pure_content = false;
     } else if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
         parser = tools.build_parser(ctx);
+        pure_content = false;
     } else {
         parser = content.build_parser(ctx);
     }
-    return p.prefix(inputs.generation_prompt, reasoning.start) + parser;
+    return pure_content ? p.prefix(inputs.generation_prompt, reasoning.start) + parser : p.prefix(inputs.generation_prompt, reasoning.start) << parser;
 });
 }

@@ -211,6 +216,44 @@ common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_cont
         p.end();
 }

+common_peg_parser analyze_tools::build_func_parser(common_chat_peg_builder & p, const std::string & name,
+                                                   const common_peg_parser & call_id_section, bool have_call_id,
+                                                   const common_peg_parser & args,
+                                                   std::optional<common_peg_parser> atomic_peek) const {
+    auto open = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix);
+    bool matched_atomic = false;
+    common_peg_parser func_parser = p.eps();
+
+    if (!function.name_suffix.empty()) {
+        func_parser = open + call_id_section + p.space() + args;
+        matched_atomic = true;
+    } else if (have_call_id) {
+        func_parser = p.atomic(open + call_id_section) + p.space() + args;
+        matched_atomic = true;
+    } else if (atomic_peek.has_value()) {
+        func_parser = p.atomic(open + call_id_section + p.space() + *atomic_peek) + args;
+        matched_atomic = true;
+    } else {
+        func_parser = open + call_id_section + p.space() + args;
+    }
+
+    if (!function.close.empty()) {
+        func_parser = func_parser + p.space() + p.tool_close(p.literal(function.close));
+    } else if (!format.per_call_end.empty()) {
+        // When there's no func_close but there is a per_call_end marker, use peek() to ensure
+        // we only emit tool_close when we can actually see the closing marker. This prevents
+        // premature closing during partial parsing when we've seen e.g. "</" which could be
+        // either "</tool_call>" (end) or "<arg_key>" prefix that failed to match.
+        func_parser = func_parser + p.tool_close(p.peek(p.literal(format.per_call_end)));
+    } else {
+        func_parser = func_parser + p.tool_close(p.space()); // force this to process tool closing callbacks in mapper
+    }
+    if (!matched_atomic) {
+        func_parser = p.atomic(func_parser);
+    }
+    return func_parser;
+}
+
 common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context & ctx) const {
     auto & p = ctx.p;
     const auto & inputs = ctx.inputs;

@@ -224,17 +267,27 @@ common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context
         const auto & schema = func.contains("parameters") ? func.at("parameters") : json::object();

         // Build call_id parser based on position (if supported)
+        bool have_call_id = false;
         common_peg_parser call_id_section = p.eps();
         if (call_id.pos == call_id_position::BETWEEN_FUNC_AND_ARGS && !call_id.prefix.empty() &&
-            !call_id.suffix.empty()) {
-            call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix))) + call_id.suffix;
+            (!call_id.suffix.empty() || !arguments.start.empty())) {
+            if (!call_id.suffix.empty()) {
+                call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix))) + call_id.suffix;
+            } else {
+                call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(arguments.start)));
+            }
+            have_call_id = true;
         }
+        auto args_parser = p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema));
+        if (!arguments.start.empty()) {
+            args_parser = p.literal(arguments.start) + args_parser;
+        }
+        if (!arguments.end.empty()) {
+            args_parser = args_parser + p.literal(arguments.end);
+        }

-        auto func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
-            call_id_section + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema));
-        if (!function.close.empty()) {
-            func_parser = func_parser + function.close;
-        }
+        auto atomic_peek = !arguments.start.empty() ? std::optional(p.peek(p.literal(arguments.start))) : std::nullopt;
+        auto func_parser = build_func_parser(p, name, call_id_section, have_call_id, args_parser, atomic_peek);
         tool_choice |= p.rule("tool-" + name, func_parser);
     });
@@ -294,12 +347,34 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
         for (const auto & [param_name, param_schema] : properties.items()) {
             bool is_required = required.find(param_name) != required.end();
             std::string type = "object";
-            auto type_obj = param_schema.contains("type") ? param_schema.at("type") : json::object();
-            if (type_obj.is_string()) {
-                type_obj.get_to(type);
-            } else if (type_obj.is_object()) {
-                if (type_obj.contains("type") && type_obj.at("type").is_string()) {
-                    type_obj.at("type").get_to(type);
-                }
-            }
+            if (param_schema.contains("type")) {
+                const auto & type_obj = param_schema.at("type");
+                if (type_obj.is_string()) {
+                    type_obj.get_to(type);
+                } else if (type_obj.is_array()) {
+                    // Handle nullable types like ["string", "null"]
+                    for (const auto & t : type_obj) {
+                        if (t.is_string() && t.get<std::string>() != "null") {
+                            type = t.get<std::string>();
+                            break;
+                        }
+                    }
+                } else if (type_obj.is_object()) {
+                    if (type_obj.contains("type") && type_obj.at("type").is_string()) {
+                        type_obj.at("type").get_to(type);
+                    }
+                }
+            }
+            // Infer string type from enum values when type is unspecified
+            if (type == "object" && param_schema.contains("enum")) {
+                const auto & enum_vals = param_schema.at("enum");
+                if (enum_vals.is_array()) {
+                    for (const auto & v : enum_vals) {
+                        if (v.is_string()) {
+                            type = "string";
+                            break;
+                        }
+                    }
+                }
+            }
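To illustrate what the new branches accept, here is a hypothetical tool-parameter schema (not taken from the test suite): the `is_array()` branch resolves `["string", "null"]` to `string` by skipping the `"null"` entry, and the enum fallback infers `string` when `type` is absent entirely:

```sh
# Hypothetical schema exercising both new code paths above.
cat <<'EOF'
{
  "properties": {
    "label": { "type": ["string", "null"] },
    "mode":  { "enum": ["fast", "slow"] }
  }
}
EOF
```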
@@ -342,52 +417,31 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
             args_seq = args_seq + p.repeat(p.space() + any_opt, 0, (int) optional_parsers.size());
         }

         if (!arguments.start.empty()) {
             args_seq = p.literal(arguments.start) + args_seq;
         }
         if (!arguments.end.empty()) {
             args_seq = args_seq + p.literal(arguments.end);
         }

         // Build call_id parser based on position (if supported)
         common_peg_parser call_id_section = p.eps();
         bool have_call_id = false;
         if (call_id.pos == call_id_position::BETWEEN_FUNC_AND_ARGS && !call_id.prefix.empty() &&
-            !call_id.suffix.empty()) {
+            (!call_id.suffix.empty() || !arguments.start.empty())) {
             have_call_id = true;
-            call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix)) + call_id.suffix);
-        }
-
-        bool matched_atomic = false;
-        common_peg_parser func_parser = p.eps();
-        if (!function.name_suffix.empty()) {
-            func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
-                call_id_section + p.space() + args_seq;
-            matched_atomic = true;
-        } else if (have_call_id) {
-            func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
-                call_id_section) + p.space() + args_seq;
-            matched_atomic = true;
-        } else if (!arguments.name_prefix.empty() && !required_parsers.empty()) {
-            // Only peek for an arg tag when there are required args that must follow.
-            // When all args are optional, the model may emit no arg tags at all (#20650).
-            func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
-                call_id_section + p.space() + p.peek(p.literal(arguments.name_prefix))) + args_seq;
-            matched_atomic = true;
-        } else {
-            func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
-                call_id_section + p.space() + args_seq;
-        }
-
-        if (!function.close.empty()) {
-            func_parser = func_parser + p.space() + p.tool_close(p.literal(function.close));
-        } else if (!format.per_call_end.empty()) {
-            // When there's no func_close but there is a per_call_end marker, use peek() to ensure
-            // we only emit tool_close when we can actually see the closing marker. This prevents
-            // premature closing during partial parsing when we've seen e.g. "</" which could be
-            // either "</tool_call>" (end) or "<arg_key>" prefix that failed to match.
-            func_parser = func_parser + p.tool_close(p.peek(p.literal(format.per_call_end)));
-        } else {
-            func_parser =
-                func_parser + p.tool_close(p.space()); // force this to process tool closing callbacks in mapper
-        }
-        if (!matched_atomic) {
-            func_parser = p.atomic(func_parser);
+            if (!call_id.suffix.empty()) {
+                call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix)) + call_id.suffix);
+            } else {
+                call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(arguments.start)));
+            }
         }
+
+        // Only peek for an arg tag when there are required args that must follow.
+        // When all args are optional, the model may emit no arg tags at all (#20650).
+        auto atomic_peek = (!arguments.name_prefix.empty() && !required_parsers.empty()) ?
+            std::optional(p.peek(p.literal(arguments.name_prefix))) : std::nullopt;
+        auto func_parser = build_func_parser(p, name, call_id_section, have_call_id, args_seq, atomic_peek);
         tool_choice |= p.rule("tool-" + name, func_parser);
     });
@@ -1,7 +1,7 @@
 #pragma once

 #include "chat-auto-parser.h"
 #include "peg-parser.h"

 #include <functional>
 #include <optional>
 #include <string>

@@ -4,6 +4,7 @@
 #include "common.h"
 #include "jinja/caps.h"
+#include "peg-parser.h"
 #include "nlohmann/json.hpp"

 #include <chrono>
 #include <optional>

@@ -212,12 +213,14 @@ struct tool_id_analysis {
 // ============================================================================

 struct analyze_content;
+struct analyze_reasoning;

 struct parser_build_context {
     common_chat_peg_builder & p;
     const generation_params & inputs;
     common_peg_parser reasoning_parser;
     bool extracting_reasoning = false;
+    const analyze_reasoning * reasoning = nullptr;
     const analyze_content * content = nullptr;

     parser_build_context(common_chat_peg_builder & p, const generation_params & inputs);

@@ -350,6 +353,13 @@ struct analyze_tools : analyze_base {
     common_peg_parser build_tool_parser_json_native(parser_build_context & ctx) const;
     common_peg_parser build_tool_parser_tag_json(parser_build_context & ctx) const;
     common_peg_parser build_tool_parser_tag_tagged(parser_build_context & ctx) const;
+
+    // Shared helper: builds func_parser from open+call_id+args, handling atomic wrapping and close.
+    // atomic_peek: if present, used as the peek expression in the third atomicity branch.
+    common_peg_parser build_func_parser(common_chat_peg_builder & p, const std::string & name,
+                                        const common_peg_parser & call_id_section, bool have_call_id,
+                                        const common_peg_parser & args,
+                                        std::optional<common_peg_parser> atomic_peek) const;
 };

 // ============================================================================
@@ -25,6 +25,9 @@ static const std::string ARG_SECOND = "BB_ARG_SND_BB";
|
||||
static const std::string USER_MSG = "U_USER_MSG Hello END_U";
|
||||
static const std::string ASSISTANT_MSG = "A_ASST_MSG I can help END_A";
|
||||
static const std::string THINKING_CONTENT = "REASON_PART I am thinking END_R";
|
||||
static const std::string CALL_ID_001 = "call00001";
|
||||
static const std::string CALL_ID_002 = "call00002";
|
||||
static const std::string CALL_ID_999 = "call99999";
|
||||
|
||||
static std::vector<std::function<void(const common_chat_template & tmpl, autoparser &)>> workarounds(
|
||||
{ // Old reasoning Qwen templates - they don't really display reasoning content, but we still want to
|
||||
@@ -103,6 +106,7 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
|
||||
analysis.tools.function.name_prefix = "<|tool▁sep|>";
|
||||
analysis.tools.format.per_call_end = "<|tool▁call▁end|>";
|
||||
analysis.tools.function.close = "```";
|
||||
LOG_DBG(ANSI_ORANGE "[Patch: DeepSeek-R1-Distill-Qwen]\n" ANSI_RESET);
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -130,7 +134,7 @@ static json user_msg = json{
|
||||
{ "content", USER_MSG }
|
||||
};
|
||||
|
||||
static json build_tool_call(const std::string & name, const json & args, const std::string & id = "call00001") {
|
||||
static json build_tool_call(const std::string & name, const json & args, const std::string & id = CALL_ID_001) {
|
||||
return json{
|
||||
{ "id", id },
|
||||
{ "type", "function" },
|
||||
@@ -138,17 +142,17 @@ static json build_tool_call(const std::string & name, const json & args, const s
|
||||
};
|
||||
}
|
||||
|
||||
static json first_tool_call_zero_args = build_tool_call(FUN_FIRST, json::object(), "call00001");
|
||||
static json first_tool_call_one_arg = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "XXXX" }}, "call00001");
|
||||
static json first_tool_call_one_arg_other_val = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "YYYY" }}, "call00001");
|
||||
static json first_tool_call_other_arg = build_tool_call(FUN_FIRST, {{ ARG_SECOND, "YYYY" }}, "call00001");
|
||||
static json first_tool_call_zero_args = build_tool_call(FUN_FIRST, json::object(), CALL_ID_001);
|
||||
static json first_tool_call_one_arg = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "XXXX" }}, CALL_ID_001);
|
||||
static json first_tool_call_one_arg_other_val = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "YYYY" }}, CALL_ID_001);
|
||||
static json first_tool_call_other_arg = build_tool_call(FUN_FIRST, {{ ARG_SECOND, "YYYY" }}, CALL_ID_001);
|
||||
|
||||
static json first_tool_call =
|
||||
build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call00001");
|
||||
build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, CALL_ID_001);
|
||||
static json second_tool_call =
|
||||
build_tool_call(FUN_SECOND, json{ { ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call00002");
|
||||
build_tool_call(FUN_SECOND, json{ { ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, CALL_ID_002);
|
||||
static json first_tool_call_alt_id =
|
||||
build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call99999");
|
||||
build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, CALL_ID_999);
|
||||
|
||||
template <typename T>
|
||||
static std::string mode_to_str(T mode) {
|
||||
@@ -187,6 +191,11 @@ void autoparser::analyze_template(const common_chat_template & tmpl) {
    LOG_DBG("func_name_prefix: '%s'\n", tools.function.name_prefix.c_str());
    LOG_DBG("func_name_suffix: '%s'\n", tools.function.name_suffix.c_str());
    LOG_DBG("func_close: '%s'\n", tools.function.close.c_str());
    LOG_DBG("call_id_prefix: '%s'\n", tools.call_id.prefix.c_str());
    LOG_DBG("call_id_suffix: '%s'\n", tools.call_id.suffix.c_str());
    LOG_DBG("call_id_pos: '%s'\n", mode_to_str(tools.call_id.pos).c_str());
    LOG_DBG("args_start: '%s'\n", tools.arguments.start.c_str());
    LOG_DBG("args_end: '%s'\n", tools.arguments.end.c_str());
    LOG_DBG("arg_name_prefix: '%s'\n", tools.arguments.name_prefix.c_str());
    LOG_DBG("arg_name_suffix: '%s'\n", tools.arguments.name_suffix.c_str());
    LOG_DBG("arg_value_prefix: '%s'\n", tools.arguments.value_prefix.c_str());
@@ -555,12 +564,15 @@ analyze_tools::analyze_tools(const common_chat_template & tmpl,
        if (caps.supports_parallel_tool_calls) {
            check_per_call_markers();
        }
        LOG_DBG(ANSI_ORANGE "Phase 3a: Function call analysis\n" ANSI_RESET);
        extract_function_markers();
        LOG_DBG(ANSI_ORANGE "Phase 3b: Argument analysis\n" ANSI_RESET);
        if (format.mode == tool_format::TAG_WITH_TAGGED) {
            analyze_arguments();
        }
        extract_argument_separator();
        extract_args_markers();
        LOG_DBG(ANSI_ORANGE "Phase 3c: Call id analysis\n" ANSI_RESET);
        extract_call_id_markers();
    }
}
@@ -951,8 +963,6 @@ void analyze_tools::extract_function_markers() {
}

void analyze_tools::analyze_arguments() {
    LOG_DBG(ANSI_ORANGE "Phase 4: Argument analysis\n" ANSI_RESET);

    extract_argument_name_markers();
    extract_argument_value_markers();
}
@@ -1161,7 +1171,7 @@ void analyze_tools::extract_args_markers() {

    const auto & diff = comparison->diff;

    if (format.mode != tool_format::JSON_NATIVE) {
    if (format.mode == tool_format::JSON_NATIVE) {
        std::string prefix_marker = !format.section_start.empty() ? format.section_start : format.per_call_start;
        std::string suffix_marker = !format.section_end.empty() ? format.section_end : format.per_call_end;
        // these might happen earlier in the tools section as an example or somewhere else, so we need to find the closest ones
@@ -1183,6 +1193,10 @@ void analyze_tools::extract_args_markers() {
    if (find_fun != std::string::npos) {
        args_start = args_start.substr(find_fun + FUN_FIRST.size(), args_start.size() - find_fun - FUN_FIRST.size());
    }
    size_t find_call_id = args_start.find(CALL_ID_001);
    if (find_call_id != std::string::npos) {
        args_start = args_start.substr(find_call_id + CALL_ID_001.size(), args_start.size() - find_call_id - CALL_ID_001.size());
    }
    arguments.start = args_start;
    arguments.end = args_end;
}
@@ -1222,8 +1236,8 @@ void analyze_tools::extract_call_id_markers() {
        return;
    }

    std::string id_value_1 = "call00001";
    std::string id_value_2 = "call99999";
    std::string id_value_1 = CALL_ID_001;
    std::string id_value_2 = CALL_ID_999;

    size_t common_id_prefix_len = 0;
    for (size_t i = 0; i < std::min(id_value_1.length(), id_value_2.length()); i++) {
@@ -1322,6 +1336,14 @@ void analyze_tools::extract_call_id_markers() {
        call_id.suffix = find_first_marker(before_func);
    }

    if (call_id.prefix == arguments.end) {
        call_id.prefix = "";
    }

    if (call_id.suffix == arguments.start) {
        call_id.suffix = "";
    }

    // When call_id is detected, per_call_end may have been incorrectly set to include
    // the call_id_suffix and sample args. Clear it if it starts with call_id_suffix.
    if (call_id.pos != call_id_position::NONE && !call_id.suffix.empty() &&

@@ -214,6 +214,10 @@ std::string & common_chat_peg_mapper::args_target() {
    return (current_tool && !current_tool->name.empty()) ? current_tool->arguments : args_buffer;
}

std::string common_chat_peg_mapper::normalize_container_value(const std::string & input) {
    return normalize_quotes_to_json(input);
}

void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena,
                                      const common_peg_parse_result & parse_result_arg) {
    arena.visit(parse_result_arg, [this](const common_peg_ast_node & node) { map(node); });
@@ -352,7 +356,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
        // For potential containers, normalize Python-style single quotes to JSON double quotes
        bool is_potential_container = value_content[0] == '[' || value_content[0] == '{';
        if (is_potential_container) {
            value_content = normalize_quotes_to_json(value_content);
            value_content = normalize_container_value(value_content);
        }

        // Try to parse as JSON value (number, bool, null, object, array)
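The mapper above normalizes Python-style container text through normalize_quotes_to_json() before JSON parsing. As a hedged sketch of what such a pass must handle, here is a std-only version that flips single quotes to double quotes while leaving escape pairs and embedded quotes intact (illustrative, not the project's implementation):

#include <cassert>
#include <string>

// Illustrative only: convert Python-style quoting ({'a': 'b'}) to JSON ({"a": "b"}).
// Tracks the current string delimiter so quote characters inside values survive.
static std::string normalize_quotes_to_json_sketch(const std::string & in) {
    std::string out;
    out.reserve(in.size());
    char quote = 0; // current string delimiter, or 0 when outside a string
    for (size_t i = 0; i < in.size(); i++) {
        char c = in[i];
        if (quote == 0) {
            if (c == '\'' || c == '"') { quote = c; out += '"'; }
            else { out += c; }
        } else if (c == '\\' && i + 1 < in.size()) {
            out += c; out += in[++i]; // keep escape pairs as-is
        } else if (c == quote) {
            quote = 0; out += '"';
        } else if (c == '"') {
            out += "\\\""; // a raw '"' inside a single-quoted string must be escaped
        } else {
            out += c;
        }
    }
    return out;
}

int main() {
    assert(normalize_quotes_to_json_sketch("{'a': 'b'}") == "{\"a\": \"b\"}");
    assert(normalize_quotes_to_json_sketch("['x', 1, true]") == "[\"x\", 1, true]");
    return 0;
}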
@@ -861,3 +865,143 @@ common_peg_parser common_chat_peg_builder::standard_json_tools(

    return force_tool_calls ? section : optional(section);
}

void common_chat_peg_gemma4_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) {
    for (const auto & node : result.nodes) {
        visit(arena, node);
    }
}

static std::string gemma4_to_json(const common_peg_ast_arena & arena, common_peg_ast_id id) {
    const auto & node = arena.get(id);

    if (node.text.empty()) {
        return "";
    }

    if (node.rule == "gemma4-number" || node.rule == "gemma4-bool" || node.rule == "gemma4-null") {
        return std::string(node.text);
    }

    if (node.rule == "gemma4-string-content") {
        return escape_json_string_inner(std::string(node.text));
    }

    if (node.rule == "gemma4-string") {
        std::string result = "\"";
        if (!node.children.empty()) {
            result += gemma4_to_json(arena, node.children[0]);
            if (!node.is_partial) {
                result += "\"";
            }
        }
        return result;
    }

    if (node.rule == "gemma4-array") {
        std::string result = "[";

        bool add_comma = false;
        for (auto child_id : node.children) {
            if (add_comma) {
                result += ',';
            }
            add_comma = true;
            result += gemma4_to_json(arena, child_id);
        }

        if (!node.is_partial) {
            result += ']';
        }
        return result;
    }

    if (node.rule == "gemma4-dict-key-name") {
        return std::string(node.text);
    }

    if (node.rule == "gemma4-dict-key") {
        std::string result = "\"";
        if (!node.children.empty()) {
            result += escape_json_string_inner(gemma4_to_json(arena, node.children[0]));
        }
        if (!node.is_partial) {
            result += "\":";
        }
        return result;
    }

    if (node.rule == "gemma4-dict-kv") {
        std::string result;
        for (auto child_id : node.children) {
            result += gemma4_to_json(arena, child_id);
        }
        return result;
    }

    if (node.rule == "gemma4-dict") {
        std::string result = "{";

        bool add_comma = false;
        for (auto child_id : node.children) {
            if (add_comma) {
                result += ',';
            }
            add_comma = true;
            result += gemma4_to_json(arena, child_id);
        }

        if (!node.is_partial) {
            result += '}';
        }
        return result;
    }

    if (node.rule == "gemma4-value") {
        if (!node.children.empty()) {
            return gemma4_to_json(arena, node.children[0]);
        }
        return "";
    }

    return "";
}
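gemma4_to_json() leans on escape_json_string_inner() for the raw text between string delimiters. A minimal sketch of that kind of inner-string escaper, assuming it escapes the string body only without adding surrounding quotes (the real helper may differ):

#include <cassert>
#include <cstdio>
#include <string>

// Escape the *inside* of a JSON string: quotes, backslashes, and control chars.
static std::string escape_json_string_inner_sketch(const std::string & s) {
    std::string out;
    for (unsigned char c : s) {
        switch (c) {
            case '"':  out += "\\\""; break;
            case '\\': out += "\\\\"; break;
            case '\n': out += "\\n";  break;
            case '\r': out += "\\r";  break;
            case '\t': out += "\\t";  break;
            default:
                if (c < 0x20) { // remaining control characters become \u00XX
                    char buf[8];
                    std::snprintf(buf, sizeof(buf), "\\u%04x", c);
                    out += buf;
                } else {
                    out += (char) c;
                }
        }
    }
    return out;
}

int main() {
    assert(escape_json_string_inner_sketch("say \"hi\"\n") == "say \\\"hi\\\"\\n");
    return 0;
}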
void common_chat_peg_gemma4_mapper::visit(const common_peg_ast_arena & arena, common_peg_ast_id id) {
    const auto & node = arena.get(id);

    if (node.tag == "reasoning") {
        result.reasoning_content += std::string(node.text);
        return;
    }

    if (node.tag == "content") {
        result.content += std::string(node.text);
        return;
    }

    if (node.tag == "tool") {
        auto name_id = arena.find_by_tag(node, "tool-name");
        auto args_id = arena.find_by_tag(node, "tool-args");

        if (name_id != COMMON_PEG_INVALID_AST_ID && args_id != COMMON_PEG_INVALID_AST_ID) {
            const auto & name_node = arena.get(name_id);
            const auto & args_node = arena.get(args_id);

            if (!name_node.is_partial) {
                common_chat_tool_call call;
                call.name = std::string(name_node.text);
                if (!args_node.children.empty()) {
                    call.arguments = gemma4_to_json(arena, args_node.children[0]);
                }
                result.tool_calls.push_back(call);
            }
        }

        return;
    }

    for (auto child_id : node.children) {
        visit(arena, child_id);
    }
}

@@ -17,7 +17,9 @@ class common_chat_peg_mapper {

    virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result);
    virtual void map(const common_peg_ast_node & node);
private:
protected:
    virtual std::string normalize_container_value(const std::string & input);
private:
    // Tool call handling state
    std::optional<common_chat_tool_call> pending_tool_call; // Tool call waiting for name
    common_chat_tool_call * current_tool = nullptr;
@@ -30,6 +32,14 @@ class common_chat_peg_mapper {
    std::string & args_target();
};

class common_chat_peg_gemma4_mapper : public common_chat_peg_mapper {
public:
    common_chat_peg_gemma4_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
    virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result);
private:
    void visit(const common_peg_ast_arena & arena, common_peg_ast_id id);
};

struct content_structure;
struct tool_call_structure;
common/chat.cpp
@@ -13,6 +13,8 @@
#include "jinja/caps.h"
#include "peg-parser.h"

#include "nlohmann/json.hpp"

#include <cstdio>
#include <cstdlib>
#include <ctime>
@@ -694,6 +696,8 @@ const char * common_chat_format_name(common_chat_format format) {
            return "peg-simple";
        case COMMON_CHAT_FORMAT_PEG_NATIVE:
            return "peg-native";
        case COMMON_CHAT_FORMAT_PEG_GEMMA4:
            return "peg-gemma4";
        default:
            throw std::runtime_error("Unknown chat format");
    }
@@ -760,12 +764,12 @@ static void foreach_parameter(const json &
    }
}

std::string common_chat_template_direct_apply(
static std::string common_chat_template_direct_apply_impl(
    const common_chat_template & tmpl,
    const autoparser::generation_params & inputs,
    const std::optional<json> & messages_override,
    const std::optional<json> & tools_override,
    const std::optional<json> & additional_context) {
    const std::optional<json> & messages_override = std::nullopt,
    const std::optional<json> & tools_override = std::nullopt,
    const std::optional<json> & additional_context = std::nullopt) {
    jinja::context ctx(tmpl.source());

    nlohmann::ordered_json inp = nlohmann::ordered_json{
@@ -812,6 +816,12 @@ std::string common_chat_template_direct_apply(
    return result;
}

std::string common_chat_template_direct_apply(
    const common_chat_template & tmpl,
    const autoparser::generation_params & inputs) {
    return common_chat_template_direct_apply_impl(tmpl, inputs, std::nullopt, std::nullopt, std::nullopt);
}

static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl,
                                                              const autoparser::generation_params & inputs) {
    common_chat_params data;
@@ -862,7 +872,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
    data.supports_thinking = true;
    data.thinking_start_tag = "[THINK]";
    data.thinking_end_tag = "[/THINK]";
    data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
    data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override = */ adjusted_messages);
    data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
    data.preserved_tokens = {
        "[THINK]",
@@ -945,7 +955,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
        adjusted_messages.push_back(msg);
    }

    auto prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
    auto prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);

    // Check if we need to replace the return token with end token during
    // inference and without generation prompt. For more details see:
@@ -980,15 +990,19 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
    auto channel = p.literal("<|channel|>") + (p.literal("commentary") | p.literal("analysis"));
    auto constrain_type = p.chars("[A-Za-z0-9_-]", 1, -1);

    // Occasionally, gpt-oss-20b will prefix channels with this commentary
    auto stray_commentary = p.optional(p.literal("<|channel|>commentary") + p.optional(p.literal(" to=assistant")));
    auto start_analysis = stray_commentary + p.literal("<|channel|>analysis<|message|>");

    if (extract_reasoning) {
        p.rule("analysis", p.literal("<|channel|>analysis<|message|>") + p.reasoning(content) + end);
        p.rule("analysis", start_analysis + p.reasoning(content) + end);
    } else {
        p.rule("analysis", p.content(p.literal("<|channel|>analysis<|message|>") + content + end));
        p.rule("analysis", p.content(start_analysis + content + end));
    }

    auto analysis = p.ref("analysis");
    auto preamble = p.rule("preamble", p.literal("<|channel|>commentary<|message|>") + p.content(content) + end);
    auto final_msg = p.rule("final", p.literal("<|channel|>final<|message|>") + p.content(content));
    auto final_msg = p.rule("final", stray_commentary + p.literal("<|channel|>final<|message|>") + p.content(content));

    // Consume any unsolicited tool calls, e.g. builtin functions
    auto unsolicited = p.rule("unsolicited", p.atomic(p.optional(channel) + p.literal(" to=") + content + end));
@@ -996,7 +1010,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
    auto any = p.rule("any", preamble | analysis);

    if (has_response_format) {
        auto constraint = p.optional(p.space() + p.literal("<|constrain|>") + constrain_type);
        auto constraint = p.optional(p.space() + p.optional(p.literal("<|constrain|>")) + constrain_type);
        auto response_format = p.rule("response-format",
            p.literal("<|channel|>final") + constraint + p.literal("<|message|>") +
            p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
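        // Illustrative note (assumed examples, not from the diff): with the extra
        // optional() above, both header shapes should now parse, with or without
        // the <|constrain|> token before the type word:
        //   <|channel|>final <|constrain|>json<|message|>{...}
        //   <|channel|>final json<|message|>{...}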
@@ -1013,7 +1027,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
        const auto & params = function.at("parameters");

        auto func_name = p.literal(" to=functions.") + p.tool_name(p.literal(name));
        auto constraint = p.optional(p.space() + p.literal("<|constrain|>") + constrain_type);
        auto constraint = p.optional(p.space() + p.optional(p.literal("<|constrain|>")) + constrain_type);
        auto args = p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", params));

        // recipient in role header
@@ -1054,6 +1068,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp

    data.grammar_triggers = {
        { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "^\\s+to$" },
        { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "^<\\|channel\\|>(?:commentary|analysis)\\s+to=functions$" },
        { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "<\\|start\\|>assistant(\\s+to)" },
        { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "<\\|start\\|>assistant(<\\|channel\\|>(?:commentary|analysis)\\s+to)" }
    };
@@ -1062,12 +1077,137 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
    return data;
}

static common_chat_params common_chat_params_init_gemma4(const common_chat_template & tmpl,
                                                         const autoparser::generation_params & inputs) {
    common_chat_params data;

    data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
    data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
    data.supports_thinking = true;

    data.preserved_tokens = {
        "<|channel>",
        "<channel|>",
        "<|tool_call>",
        "<tool_call|>",
        "<|turn>",
    };

    auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
    auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
    auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
    auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;

    auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
        auto start = p.rule("start", p.prefix(inputs.generation_prompt, "<|channel>"));

        if (extract_reasoning) {
            p.rule("thought", p.literal("<|channel>thought\n") + p.reasoning(p.until("<channel|>")) + p.literal("<channel|>"));
        } else {
            p.rule("thought", p.content(p.literal("<|channel>thought\n") + p.until("<channel|>") + p.literal("<channel|>")));
        }

        auto thought = (p.peek(p.literal("<|channel>")) + p.ref("thought")) | p.negate(p.literal("<|channel>"));

        if (has_response_format) {
            auto response_format = p.literal("```json") <<
                p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)) <<
                p.literal("```");
            return start + p.optional(thought) + response_format;
        }

        if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
            // Gemma4 tool calling syntax
            // Rules should match traversal logic in gemma4_to_json()
            p.rule("gemma4-string-content", p.until("<|\"|>"));
            p.rule("gemma4-string", p.literal("<|\"|>") + p.ref("gemma4-string-content") + p.literal("<|\"|>"));
            p.rule("gemma4-bool", p.json_bool());
            p.rule("gemma4-null", p.json_null());
            p.rule("gemma4-number", p.json_number());
            p.rule("gemma4-dict-key", p.rule("gemma4-dict-key-name", p.until(":")) + p.literal(":"));
            p.rule("gemma4-dict-kv", p.ref("gemma4-dict-key") + p.space() + p.ref("gemma4-value"));
            p.rule("gemma4-dict", [&]() {
                auto ws = p.space();
                auto member = p.ref("gemma4-dict-kv");
                auto members = p.sequence({member, p.zero_or_more(p.sequence({p.literal(","), ws, member}))});
                return p.sequence({
                    p.literal("{"), ws,
                    p.choice({p.literal("}"), p.sequence({members, ws, p.literal("}")})})
                });
            });
            p.rule("gemma4-array", [&]() {
                auto ws = p.space();
                auto value = p.ref("gemma4-value");
                auto elements = p.sequence({value, p.zero_or_more(p.sequence({p.literal(","), ws, value}))});
                return p.sequence({
                    p.literal("["), ws,
                    p.choice({p.literal("]"), p.sequence({elements, ws, p.literal("]")})})
                });
            });
            p.rule("gemma4-value", [&]() {
                return p.choice({
                    p.ref("gemma4-string"), p.ref("gemma4-dict"), p.ref("gemma4-array"),
                    p.ref("gemma4-number"), p.ref("gemma4-bool"), p.ref("gemma4-null")
                });
            });

            auto tool_choice = p.choice();

            foreach_function(inputs.tools, [&](const json & tool) {
                const auto & function = tool.at("function");
                std::string name = function.at("name");
                // TODO @aldehir : need to extend json-schema-to-grammar to produce more than JSON rules
                // const auto & params = function.at("parameters");

                tool_choice |= p.rule("tool-" + name, p.tool(p.sequence({
                    p.tool_open(p.tool_name(p.literal(name)) + p.peek(p.literal("{"))),
                    p.tool_args(p.ref("gemma4-dict")),
                })));
            });

            auto tool_call = p.trigger_rule("tool-call", p.repeat(
                "<|tool_call>call:" + tool_choice + "<tool_call|>",
                /* min = */ inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0,
                /* max = */ inputs.parallel_tool_calls ? -1 : 1
            ));

            auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<|tool_call>"})));
            auto message = p.rule("message", thought + content);
            return start + p.zero_or_more(message) + tool_call;
        }

        auto content = p.rule("content", p.content(p.until("<|channel>")));
        auto message = p.rule("message", thought + content);
        return start + p.one_or_more(message);
    });

    data.parser = parser.save();

    if (include_grammar) {
        data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
            foreach_function(inputs.tools, [&](const json & tool) {
                const auto & function = tool.at("function");
                auto schema = function.at("parameters");
                builder.resolve_refs(schema);
            });
            parser.build_grammar(builder, data.grammar_lazy);
        });

        data.grammar_triggers = {
            { COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_call>" },
        };
    }

    return data;
}
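For orientation, a hedged sketch of the surface form the rules above accept and the mapping the Gemma4 mapper should produce; the call and its values are invented, and real model output may differ in spacing:

// Illustrative Gemma4 tool-call output (values invented):
//   <|tool_call>call:get_weather{location: <|"|>Paris<|"|>, units: <|"|>metric<|"|>}<tool_call|>
// Expected mapping (name via the "tool-name" tag, arguments via gemma4_to_json):
//   name      = "get_weather"
//   arguments = {"location":"Paris","units":"metric"}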

// Functionary v3.2 - uses recipient-based format: >>>recipient\n{content}
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl,
                                                                   const autoparser::generation_params & inputs) {
    common_chat_params data;

    data.prompt = common_chat_template_direct_apply(tmpl, inputs);
    data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
    data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
    data.preserved_tokens = {
        ">>>all",
@@ -1161,7 +1301,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
                                                           const autoparser::generation_params & inputs) {
    common_chat_params data;

    data.prompt = common_chat_template_direct_apply(tmpl, inputs);
    data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
    data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
    data.supports_thinking = true;
    data.preserved_tokens = {
@@ -1274,16 +1414,17 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
    return data;
}

// LFM2 format:
// - Reasoning: <think>{reasoning}</think> (optional, only if enable_thinking is true)
// - Content: text after reasoning (optional)
// - Tool calls: <|tool_call_start|>[function_name(arg1="value1", arg2="value2")]<|tool_call_end|>
// Tool calls can appear multiple times (parallel tool calls)
// LFM2 format: uses <|tool_list_start|>[...]<|tool_list_end|> in system prompt
// and <|tool_call_start|>[name(arg="val")]<|tool_call_end|> for tool calls.
// - Reasoning: <think>{reasoning}</think> (optional)
// - Content: text before a tool call (optional)
// - Tool calls: Python-style, e.g. [function_name(arg1="value1", arg2="value2")]
// Tool calls can appear multiple times (parallel tool calls supported)
static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl,
                                                        const autoparser::generation_params & inputs) {
    common_chat_params data;

    data.prompt = common_chat_template_direct_apply(tmpl, inputs);
    data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
    data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
    data.supports_thinking = true;
    data.preserved_tokens = {
@@ -1319,9 +1460,9 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
        if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
            return generation_prompt + reasoning + p.content(p.rest()) + end;
        }

        auto tool_calls = p.rule("tool-calls",
            p.trigger_rule("tool-call", p.literal(TOOL_CALL_START) +
            p.trigger_rule("tool-call",
                p.literal(TOOL_CALL_START) +
                p.python_style_tool_calls(inputs.tools, inputs.parallel_tool_calls) +
                p.literal(TOOL_CALL_END)
            )
@@ -1349,6 +1490,80 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
            { COMMON_GRAMMAR_TRIGGER_TYPE_WORD, TOOL_CALL_START }
        };
    }
    return data;
}
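p.python_style_tool_calls() is project machinery; as a standalone illustration of the flat call shape LFM2 emits, here is a small std::regex sketch that splits [name(args...)] into a name and raw argument text (illustrative only; the real parser is grammar-driven and handles nesting and partial input):

#include <cassert>
#include <regex>
#include <string>

struct simple_call {
    std::string name;
    std::string raw_args;
};

// Extract the name and the raw argument list from one [name(...)] call.
// Good enough for flat calls; real inputs can nest brackets and strings,
// which is why the project uses a PEG grammar instead of a regex.
static bool parse_simple_call(const std::string & text, simple_call & out) {
    static const std::regex re(R"(\[\s*([A-Za-z_][A-Za-z0-9_]*)\((.*)\)\s*\])");
    std::smatch m;
    if (!std::regex_match(text, m, re)) {
        return false;
    }
    out.name = m[1];
    out.raw_args = m[2];
    return true;
}

int main() {
    simple_call call;
    assert(parse_simple_call(R"([get_weather(location="Paris", units="metric")])", call));
    assert(call.name == "get_weather");
    assert(call.raw_args == R"(location="Paris", units="metric")");
    return 0;
}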

// LFM2.5 format: uses plain "List of tools: [...]" in system prompt, no wrapper tokens.
// Tool calls are bare [name(arg="val")], though model may optionally emit <|tool_call_start|>.
// - Reasoning: <think>{reasoning}</think> (optional)
// - Content: text before a tool call (optional)
// - Tool calls: Python-style, e.g. [function_name(arg1="value1", arg2="value2")]
// Tool calls can appear multiple times (parallel tool calls supported)
static common_chat_params common_chat_params_init_lfm2_5(const common_chat_template & tmpl,
                                                          const autoparser::generation_params & inputs) {
    common_chat_params data;

    data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
    data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
    data.supports_thinking = true;
    data.preserved_tokens = {
        "<|tool_call_start|>",
        "<|tool_call_end|>",
        "<think>",
        "</think>",
    };

    auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
    auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
    auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;

    const std::string THINK_START = "<think>";
    const std::string THINK_END = "</think>";

    data.thinking_start_tag = THINK_START;
    data.thinking_end_tag = THINK_END;

    auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
        auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
        auto end = p.end();

        auto reasoning = p.eps();
        if (extract_reasoning && inputs.enable_thinking) {
            reasoning = p.optional(THINK_START + p.reasoning(p.until(THINK_END)) + THINK_END);
        }

        if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
            return generation_prompt + reasoning + p.content(p.rest()) + end;
        }

        auto tool_calls = p.rule("tool-calls",
            p.trigger_rule("tool-call",
                p.python_style_tool_calls(inputs.tools, inputs.parallel_tool_calls)
            )
        );

        auto content = p.content(p.until_one_of({"<|tool_call_start|>", "["}));
        auto maybe_start = p.optional(p.literal("<|tool_call_start|>"));
        return generation_prompt + reasoning + content + maybe_start + tool_calls + end;
    });

    data.parser = parser.save();

    if (include_grammar) {
        data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
            foreach_function(inputs.tools, [&](const json & tool) {
                const auto & function = tool.at("function");
                auto schema = function.at("parameters");
                builder.resolve_refs(schema);
            });
            parser.build_grammar(builder, data.grammar_lazy);
        });
        foreach_function(inputs.tools, [&](const json & tool) {
            const std::string name = tool.at("function").at("name");
            data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[" + name + "(" });
        });
    }

    return data;
}
@@ -1359,7 +1574,7 @@ static common_chat_params common_chat_params_init_gigachat_v3(

    common_chat_params data;

    data.prompt = common_chat_template_direct_apply(tmpl, inputs);
    data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
    data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
    data.supports_thinking = false;
    data.preserved_tokens = {
@@ -1465,6 +1680,150 @@ static void requires_non_null_content(json & messages) {
    }
}

// Gemma4 uses a custom tool_responses field instead of role:tool messages.
//
// This will transform a sequence of messages:
//   assistant(tool_call+) -> tool+ -> assistant(content)
//
// Into a single assistant message containing a tool_responses field:
//   assistant(content + tool_call + tool_responses)
//
// This is necessary for the Gemma4 chat template to properly format the prompt.
// See https://ai.google.dev/gemma/docs/core/prompt-formatting-gemma4
struct gemma4_model_turn_builder {
    json & messages;
    size_t pos;
    json tool_calls = json::array();
    json tool_responses = json::array();
    json content;
    json reasoning_content;

    gemma4_model_turn_builder(json & msgs, size_t pos) : messages(msgs), pos(pos) {}

    void collect() {
        // Collect the first assistant message
        auto & msg = messages[pos];
        if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
            // According to the prompt formatting guide, we need to preserve reasoning_content
            // between function calls. The current chat templates do not support this, but we will do it anyway.
            reasoning_content = msg.at("reasoning_content");
        }
        for (auto & tc : msg.at("tool_calls")) {
            tool_calls.push_back(tc);
        }
        pos++;

        // Collect tool call results
        while (pos < messages.size() && messages[pos].value("role", "") == "tool") {
            collect_result(messages[pos]);
            pos++;
        }

        // Check if the next assistant message is the final message
        if (pos < messages.size() && messages[pos].value("role", "") == "assistant") {
            auto & next = messages[pos];
            if (!has_tool_calls(next) && has_content(next)) {
                content = next.at("content");
                pos++;
            }
        }
    }

    void collect_result(const json & curr) {
        json response;
        if (curr.contains("content")) {
            const auto & content = curr.at("content");
            if (content.is_string()) {
                // Try to parse the content as JSON; fall back to raw string
                try {
                    response = json::parse(content.get<std::string>());
                } catch (...) {
                    response = content;
                }
            } else {
                response = content;
            }
        }

        std::string name;

        // Match name with corresponding tool call
        size_t idx = tool_responses.size();
        if (idx < tool_calls.size()) {
            auto & tc = tool_calls[idx];
            if (tc.contains("function")) {
                name = tc.at("function").value("name", "");
            }
        }

        // Fallback to the tool call id
        if (name.empty()) {
            name = curr.value("tool_call_id", "");
        }

        tool_responses.push_back({{"name", name}, {"response", response}});
    }

    json build() {
        collect();

        json msg = {
            {"role", "assistant"},
            {"tool_calls", tool_calls},
        };
        if (!tool_responses.empty()) {
            msg["tool_responses"] = tool_responses;
        }
        if (!content.is_null()) {
            msg["content"] = content;
        }
        if (!reasoning_content.is_null()) {
            msg["reasoning_content"] = reasoning_content;
        }
        return msg;
    }

    static bool has_content(const json & msg) {
        if (!msg.contains("content") || msg.at("content").is_null()) {
            return false;
        }
        const auto & content = msg.at("content");
        if (content.is_string() && !content.get<std::string>().empty()) {
            return true;
        }
        if (content.is_array() && !content.empty()) {
            return true;
        }
        return false;
    }

    static bool has_tool_calls(const json & msg) {
        return msg.contains("tool_calls") && msg.at("tool_calls").is_array() && !msg.at("tool_calls").empty();
    }
};

static void convert_tool_responses_gemma4(json & messages) {
    json result = json::array();
    size_t i = 0;

    while (i < messages.size()) {
        auto & msg = messages[i];

        if (msg.value("role", "") != "assistant" || !msg.contains("tool_calls") ||
            !msg.at("tool_calls").is_array() || msg.at("tool_calls").empty()) {
            result.push_back(msg);
            i++;
            continue;
        }

        gemma4_model_turn_builder builder(messages, i);
        result.push_back(builder.build());
        i = builder.pos;
    }

    messages = result;
}
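A hedged usage sketch of the transformation performed above, with nlohmann::json literals (values invented; field order may differ from what build() emits):

#include <cassert>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

int main() {
    // Input shape: assistant(tool_call) -> tool -> assistant(content)
    json messages = json::array({
        json{
            {"role", "assistant"},
            {"tool_calls", json::array({
                json{
                    {"type", "function"},
                    {"function", json{{"name", "get_weather"}, {"arguments", "{}"}}}
                }
            })}
        },
        json{{"role", "tool"}, {"content", "{\"temp\": 21}"}},
        json{{"role", "assistant"}, {"content", "It is 21 degrees."}}
    });

    // Expected merged shape: one assistant turn carrying tool_calls,
    // tool_responses (with the tool result parsed back to JSON), and content.
    json merged = json{
        {"role", "assistant"},
        {"tool_calls", messages[0]["tool_calls"]},
        {"tool_responses", json::array({
            json{{"name", "get_weather"}, {"response", json{{"temp", 21}}}}
        })},
        {"content", "It is 21 degrees."}
    };

    assert(merged["tool_responses"][0]["name"] == "get_weather");
    return 0;
}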

static void func_args_not_string(json & messages) {
    GGML_ASSERT(messages.is_array());
    for (auto & message : messages) {
@@ -1497,10 +1856,10 @@ static json common_chat_extra_context() {
    return ctx;
}

static std::optional<common_chat_params> try_specialized_template(
std::optional<common_chat_params> common_chat_try_specialized_template(
    const common_chat_template & tmpl,
    const std::string & src,
    const autoparser::generation_params & params) {
    autoparser::generation_params & params) {
    // Ministral/Mistral Large 3 - uses special reasoning structure fixes, can't use autoparser
    // Note: Mistral Small 3.2 uses [CALL_ID] which Ministral doesn't have, so we can distinguish them
    if (src.find("[SYSTEM_PROMPT]") != std::string::npos && src.find("[TOOL_CALLS]") != std::string::npos &&
@@ -1530,14 +1889,21 @@ static std::optional<common_chat_params> try_specialized_template(
        return common_chat_params_init_kimi_k2(tmpl, params);
    }

    // LFM2 - uses <|tool_list_start|>/<|tool_list_end|> markers and <|tool_call_start|>[name(args)]<|tool_call_end|> format
    // Detection: template has "<|tool_list_start|>" and "<|tool_list_end|>" markers
    // LFM2 format detection: template uses <|tool_list_start|>[...]<|tool_list_end|> around the tool list
    // and <|tool_call_start|>[...]<|tool_call_end|> around each tool call
    if (src.find("<|tool_list_start|>") != std::string::npos &&
        src.find("<|tool_list_end|>") != std::string::npos) {
        LOG_DBG("Using specialized template: LFM2\n");
        return common_chat_params_init_lfm2(tmpl, params);
    }

    // LFM2.5 format detection: template uses plain "List of tools: [...]" with no special tokens
    if (src.find("List of tools: [") != std::string::npos &&
        src.find("<|tool_list_start|>") == std::string::npos) {
        LOG_DBG("Using specialized template: LFM2.5\n");
        return common_chat_params_init_lfm2_5(tmpl, params);
    }

    // GigaChatV3 format detection
    if (src.find("<|role_sep|>") != std::string::npos &&
        src.find("<|message_sep|>") != std::string::npos &&
@@ -1546,6 +1912,12 @@ static std::optional<common_chat_params> try_specialized_template(
        return common_chat_params_init_gigachat_v3(tmpl, params);
    }

    // Gemma4 format detection
    if (src.find("'<|tool_call>call:'") != std::string::npos) {
        workaround::convert_tool_responses_gemma4(params.messages);
        return common_chat_params_init_gemma4(tmpl, params);
    }

    return std::nullopt;
}

@@ -1587,9 +1959,9 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
    }

    params.add_generation_prompt = false;
    std::string no_gen_prompt = common_chat_template_direct_apply(tmpl, params);
    std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
    params.add_generation_prompt = true;
    std::string gen_prompt = common_chat_template_direct_apply(tmpl, params);
    std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
    auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
    params.generation_prompt = diff.right;

@@ -1623,17 +1995,17 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
        common_chat_params data;
        auto params_copy = params;
        params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
        data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
        data.prompt = common_chat_template_direct_apply_impl(tmpl, params_copy);
        data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
        data.generation_prompt = params.generation_prompt;
        auto parser = build_chat_peg_parser([&params](common_chat_peg_builder &p) {
            return p.prefix(params.generation_prompt) + p.content(p.rest());
            return p.prefix(params.generation_prompt) << p.content(p.rest());
        });
        data.parser = parser.save();
        return data;
    }

    if (auto result = try_specialized_template(tmpl, src, params)) {
    if (auto result = common_chat_try_specialized_template(tmpl, src, params)) {
        result->generation_prompt = params.generation_prompt;
        return *result;
    }
@@ -1770,8 +2142,13 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
    // Try to extract any partial results from what was successfully parsed
    common_chat_msg msg;
    msg.role = "assistant";
    auto mapper = common_chat_peg_mapper(msg);
    mapper.from_ast(ctx.ast, result);
    std::unique_ptr<common_chat_peg_mapper> mapper;
    if (params.format == COMMON_CHAT_FORMAT_PEG_GEMMA4) {
        mapper = std::make_unique<common_chat_peg_gemma4_mapper>(msg);
    } else {
        mapper = std::make_unique<common_chat_peg_mapper>(msg);
    }
    mapper->from_ast(ctx.ast, result);

    if (ctx.is_debug()) {
        fprintf(stderr, "\nAST for partial parse (fail):\n%s\n", ctx.ast.dump().c_str());
@@ -1786,8 +2163,13 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
    common_chat_msg msg;
    msg.role = "assistant";

    auto mapper = common_chat_peg_mapper(msg);
    mapper.from_ast(ctx.ast, result);
    std::unique_ptr<common_chat_peg_mapper> mapper;
    if (params.format == COMMON_CHAT_FORMAT_PEG_GEMMA4) {
        mapper = std::make_unique<common_chat_peg_gemma4_mapper>(msg);
    } else {
        mapper = std::make_unique<common_chat_peg_mapper>(msg);
    }
    mapper->from_ast(ctx.ast, result);

    if (ctx.is_debug()) {
        fprintf(stderr, "\nAST for %s parse:\n%s\n", is_partial ? "partial" : "full", ctx.ast.dump().c_str());

@@ -3,12 +3,12 @@
#pragma once

#include "common.h"
#include "jinja/parser.h"
#include "nlohmann/json_fwd.hpp"
#include "peg-parser.h"
#include "jinja/parser.h"
#include "jinja/runtime.h"
#include "jinja/caps.h"
#include "nlohmann/json.hpp"

#include "nlohmann/json_fwd.hpp"

#include <chrono>
#include <functional>
@@ -19,8 +19,6 @@
using chat_template_caps = jinja::caps;
using json = nlohmann::ordered_json;

#include <nlohmann/json_fwd.hpp>

struct common_chat_templates;

namespace autoparser {
@@ -75,41 +73,9 @@ struct common_chat_template {
    const std::string & bos_token() const { return bos_tok; }
    const std::string & eos_token() const { return eos_tok; }

    // TODO: this is ugly, refactor it somehow
    json add_system(const json & messages, const std::string & system_prompt) const {
        GGML_ASSERT(messages.is_array());
        auto msgs_copy = messages;
        if (!caps.supports_system_role) {
            if (msgs_copy.empty()) {
                msgs_copy.insert(msgs_copy.begin(), json{
                    {"role", "user"},
                    {"content", system_prompt}
                });
            } else {
                auto & first_msg = msgs_copy[0];
                if (!first_msg.contains("content")) {
                    first_msg["content"] = "";
                }
                first_msg["content"] = system_prompt + "\n\n"
                    + first_msg["content"].get<std::string>();
            }
        } else {
            if (msgs_copy.empty() || msgs_copy[0].at("role") != "system") {
                msgs_copy.insert(msgs_copy.begin(), json{
                    {"role", "system"},
                    {"content", system_prompt}
                });
            } else if (msgs_copy[0].at("role") == "system") {
                msgs_copy[0]["content"] = system_prompt;
            }
        }
        return msgs_copy;
    }

    chat_template_caps original_caps() const {
        return caps;
    }
};

struct common_chat_msg {
@@ -184,6 +150,7 @@ enum common_chat_format {
    // These are intended to be parsed by the PEG parser
    COMMON_CHAT_FORMAT_PEG_SIMPLE,
    COMMON_CHAT_FORMAT_PEG_NATIVE,
    COMMON_CHAT_FORMAT_PEG_GEMMA4,

    COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
@@ -256,8 +223,8 @@ common_chat_templates_ptr common_chat_templates_init(const struct llama_model *
    const std::string & bos_token_override = "",
    const std::string & eos_token_override = "");

bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");

struct common_chat_params common_chat_templates_apply(const struct common_chat_templates * tmpls,
                                                      const struct common_chat_templates_inputs & inputs);
@@ -274,9 +241,9 @@ std::string common_chat_format_example(const struct common_chat_templates *
    bool use_jinja,
    const std::map<std::string, std::string> & chat_template_kwargs);

const char * common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params);
common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params);
const char * common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params);
common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params);

// used by arg and server
const char * common_reasoning_format_name(common_reasoning_format format);
@@ -302,7 +269,9 @@ std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_tem

std::string common_chat_template_direct_apply(
    const common_chat_template & tmpl,
    const autoparser::generation_params & inputs,
    const std::optional<json> & messages_override = std::nullopt,
    const std::optional<json> & tools_override = std::nullopt,
    const std::optional<json> & additional_context = std::nullopt);
    const autoparser::generation_params & inputs);

std::optional<common_chat_params> common_chat_try_specialized_template(
    const common_chat_template & tmpl,
    const std::string & src,
    autoparser::generation_params & params);

@@ -1442,6 +1442,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {

    mparams.progress_callback = params.load_progress_callback;
    mparams.progress_callback_user_data = params.load_progress_callback_user_data;
    mparams.no_alloc = params.no_alloc;

    return mparams;
}

@@ -579,8 +579,9 @@ struct common_params {
    int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
    int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
    bool cache_prompt = true; // whether to enable prompt caching
    int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
    int32_t checkpoint_every_nt = 8192; // make a checkpoint every n tokens during prefill
    bool clear_idle = true; // save and clear idle slots upon starting a new task
    int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
    int32_t checkpoint_every_nt = 8192; // make a checkpoint every n tokens during prefill
    int32_t cache_ram_mib = 8192; // -1 = no limit, 0 = disable, 1 = 1 MiB, etc.

    std::string hostname = "127.0.0.1";
@@ -679,6 +680,7 @@ struct common_params {
    // return false from callback to abort model loading or true to continue
    llama_progress_callback load_progress_callback = NULL;
    void * load_progress_callback_user_data = NULL;
    bool no_alloc = false; // Don't allocate model buffers
};

// call once at the start of a program if it uses libcommon

@@ -596,9 +596,12 @@ static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
        }
    }

    for (const auto & f : files) {
        if (gguf_filename_is_model(f.path)) {
            return f;
    // fallback to first available model only if tag is empty
    if (tag.empty()) {
        for (const auto & f : files) {
            if (gguf_filename_is_model(f.path)) {
                return f;
            }
        }
    }

@@ -306,6 +306,19 @@ value filter_expression::execute_impl(context & ctx) {
        filter_id = "strip"; // alias
    }
    JJ_DEBUG("Applying filter '%s' to %s", filter_id.c_str(), input->type().c_str());
    // TODO: Refactor filters so this coercion can be done automatically
    if (!input->is_undefined() && !is_val<value_string>(input) && (
            filter_id == "capitalize" ||
            filter_id == "lower" ||
            filter_id == "replace" ||
            filter_id == "strip" ||
            filter_id == "title" ||
            filter_id == "upper" ||
            filter_id == "wordcount"
        )) {
        JJ_DEBUG("Coercing %s to String for '%s' filter", input->type().c_str(), filter_id.c_str());
        input = mk_val<value_string>(input->as_string());
    }
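    // Illustrative effect of the coercion above (assumed examples, not from the diff):
    // string filters now apply to the stringified value instead of erroring, e.g.
    //   {{ 42 | upper }}   -> "42"
    //   {{ true | title }} -> "True"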
    return try_builtin_func(ctx, filter_id, input)->invoke(func_args(ctx));

} else if (is_stmt<call_expression>(filter)) {

@@ -465,8 +465,9 @@ const func_builtins & value_int_t::get_builtins() const {
            double val = static_cast<double>(args.get_pos(0)->as_int());
            return mk_val<value_float>(val);
        }},
        {"tojson", tojson},
        {"safe", tojson},
        {"string", tojson},
        {"tojson", tojson},
    };
    return builtins;
}
@@ -485,8 +486,9 @@ const func_builtins & value_float_t::get_builtins() const {
            int64_t val = static_cast<int64_t>(args.get_pos(0)->as_float());
            return mk_val<value_int>(val);
        }},
        {"tojson", tojson},
        {"safe", tojson},
        {"string", tojson},
        {"tojson", tojson},
    };
    return builtins;
}
@@ -771,6 +773,11 @@ const func_builtins & value_string_t::get_builtins() const {

const func_builtins & value_bool_t::get_builtins() const {
    static const func_handler tostring = [](const func_args & args) -> value {
        args.ensure_vals<value_bool>();
        bool val = args.get_pos(0)->as_bool();
        return mk_val<value_string>(val ? "True" : "False");
    };
    static const func_builtins builtins = {
        {"default", default_value},
        {"int", [](const func_args & args) -> value {
@@ -783,11 +790,8 @@ const func_builtins & value_bool_t::get_builtins() const {
            bool val = args.get_pos(0)->as_bool();
            return mk_val<value_float>(val ? 1.0 : 0.0);
        }},
        {"string", [](const func_args & args) -> value {
            args.ensure_vals<value_bool>();
            bool val = args.get_pos(0)->as_bool();
            return mk_val<value_string>(val ? "True" : "False");
        }},
        {"safe", tostring},
        {"string", tostring},
        {"tojson", tojson},
    };
    return builtins;
@@ -1100,18 +1104,14 @@ const func_builtins & value_object_t::get_builtins() const {
}

const func_builtins & value_none_t::get_builtins() const {
    static const func_handler tostring = [](const func_args &) -> value {
        return mk_val<value_string>("None");
    };
    static const func_builtins builtins = {
        {"default", default_value},
        {"tojson", tojson},
        {"string", [](const func_args &) -> value {
            return mk_val<value_string>("None");
        }},
        {"safe", [](const func_args &) -> value {
            return mk_val<value_string>("None");
        }},
        {"strip", [](const func_args &) -> value {
            return mk_val<value_string>("None");
        }},
        {"string", tostring},
        {"safe", tostring},
        {"items", empty_value_fn<value_array>},
        {"map", empty_value_fn<value_array>},
        {"reject", empty_value_fn<value_array>},

@@ -256,6 +256,38 @@ static std::pair<std::vector<common_peg_chars_parser::char_range>, bool> parse_c
    return {ranges, negated};
}

common_peg_ast_id common_peg_ast_arena::find_by_tag(const common_peg_ast_node & parent, const std::string & tag, int max_depth) const {
    for (auto child_id : parent.children) {
        const auto & child = get(child_id);
        if (child.tag == tag) {
            return child_id;
        }
        if (max_depth > 1) {
            auto result = find_by_tag(child, tag, max_depth - 1);
            if (result != COMMON_PEG_INVALID_AST_ID) {
                return result;
            }
        }
    }
    return COMMON_PEG_INVALID_AST_ID;
}

common_peg_ast_id common_peg_ast_arena::find_by_rule(const common_peg_ast_node & parent, const std::string & rule, int max_depth) const {
    for (auto child_id : parent.children) {
        const auto & child = get(child_id);
        if (child.rule == rule) {
            return child_id;
        }
        if (max_depth > 1) {
            auto result = find_by_rule(child, rule, max_depth - 1);
            if (result != COMMON_PEG_INVALID_AST_ID) {
                return result;
            }
        }
    }
    return COMMON_PEG_INVALID_AST_ID;
}
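A standalone model of the depth-limited search above, over a toy node type with the arena/id indirection dropped for clarity (illustrative):

#include <cassert>
#include <string>
#include <vector>

struct toy_node {
    std::string tag;
    std::vector<toy_node> children;
};

// Depth-first and depth-limited: checks each child before recursing into it,
// mirroring find_by_tag's max_depth handling (depth 1 = direct children only).
static const toy_node * find_by_tag_sketch(const toy_node & parent, const std::string & tag, int max_depth = 3) {
    for (const auto & child : parent.children) {
        if (child.tag == tag) {
            return &child;
        }
        if (max_depth > 1) {
            if (const toy_node * r = find_by_tag_sketch(child, tag, max_depth - 1)) {
                return r;
            }
        }
    }
    return nullptr;
}

int main() {
    toy_node root{"tool", { {"tool-open", { {"tool-name", {}} }}, {"tool-args", {}} }};
    assert(find_by_tag_sketch(root, "tool-args") != nullptr);
    assert(find_by_tag_sketch(root, "tool-name", /*max_depth=*/1) == nullptr); // too deep
    assert(find_by_tag_sketch(root, "tool-name", /*max_depth=*/2) != nullptr);
    return 0;
}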

void common_peg_ast_arena::visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const {
    if (id == COMMON_PEG_INVALID_AST_ID) {
        return;
@@ -1557,6 +1589,52 @@ static std::unordered_set<std::string> collect_reachable_rules(

// GBNF generation implementation
void common_peg_arena::build_grammar(const common_grammar_builder & builder, bool lazy) const {
    auto schema_delegates = [](const common_peg_schema_parser & s) -> bool {
        if (!s.schema) {
            return true;
        }
        if (s.raw && s.schema->contains("type")) {
            const auto & type_val = s.schema->at("type");
            if (type_val.is_string() && type_val == "string") {
                return true;
            }
            // Handle nullable types like ["string", "null"] - delegate when the
            // non-null type is string, since the tagged format uses raw text
            if (type_val.is_array()) {
                for (const auto & t : type_val) {
                    if (t.is_string() && t.get<std::string>() != "null") {
                        return t.get<std::string>() == "string";
                    }
                }
            }
        }
        // Delegate for enum schemas in raw mode - enum values are literal strings
        if (s.raw && !s.schema->contains("type") && s.schema->contains("enum")) {
            return true;
        }
        return false;
    };
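    // Worked examples of the delegation decision above (schemas illustrative):
    //   raw, {"type": "string"}           -> true  (raw text; use the child parser's grammar)
    //   raw, {"type": ["string", "null"]} -> true  (the non-null side is string)
    //   raw, {"enum": ["a", "b"]}         -> true  (enum values appear as literal strings)
    //   {"type": "object", ...}           -> false (emit a JSON schema grammar instead)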
|
||||
|
||||
// Unwrap the parser so we can properly check if it's a sequence or choice
|
||||
auto effective_parser = [&](common_peg_parser_id id) -> const common_peg_parser_variant & {
|
||||
while (true) {
|
||||
const auto & p = parsers_.at(id);
|
||||
if (const auto * tag = std::get_if<common_peg_tag_parser>(&p)) {
|
||||
id = tag->child;
|
||||
} else if (const auto * atomic = std::get_if<common_peg_atomic_parser>(&p)) {
|
||||
id = atomic->child;
|
||||
} else if (const auto * schema = std::get_if<common_peg_schema_parser>(&p)) {
|
||||
if (schema_delegates(*schema)) {
|
||||
id = schema->child;
|
||||
} else {
|
||||
return p;
|
||||
}
|
||||
} else {
|
||||
return p;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Generate GBNF for a parser
|
||||
std::function<std::string(common_peg_parser_id)> to_gbnf = [&](common_peg_parser_id id) -> std::string {
|
||||
const auto & parser = parsers_.at(id);
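schema_delegates decides, per schema parser, whether GBNF generation should fall through to the child parser's grammar rather than to builder.add_schema. A hedged Python sketch of the same predicate applied to a few example schemas (the dict shapes are assumptions mirroring the JSON cases handled above):

```python
def schema_delegates(schema, raw):
    """True means: emit the child parser's grammar instead of a schema grammar."""
    if schema is None:
        return True
    if raw and "type" in schema:
        t = schema["type"]
        if t == "string":
            return True
        if isinstance(t, list):  # nullable types like ["string", "null"]
            non_null = [x for x in t if x != "null"]
            if non_null:
                return non_null[0] == "string"
    if raw and "type" not in schema and "enum" in schema:
        return True  # enum values are literal strings in raw mode
    return False

assert schema_delegates({"type": "string"}, raw=True)
assert schema_delegates({"type": ["string", "null"]}, raw=True)
assert schema_delegates({"enum": ["yes", "no"]}, raw=True)
assert not schema_delegates({"type": "object"}, raw=True)
assert not schema_delegates({"type": "string"}, raw=False)
```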
@@ -1577,7 +1655,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
                s += " ";
            }
            auto child_gbnf = to_gbnf(child);
            const auto & child_parser = parsers_.at(child);
            const auto & child_parser = effective_parser(child);
            if (std::holds_alternative<common_peg_choice_parser>(child_parser) ||
                std::holds_alternative<common_peg_sequence_parser>(child_parser)) {
                s += "(" + child_gbnf + ")";
@@ -1593,7 +1671,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
                s += " | ";
            }
            auto child_gbnf = to_gbnf(child);
            const auto & child_parser = parsers_.at(child);
            const auto & child_parser = effective_parser(child);
            if (std::holds_alternative<common_peg_choice_parser>(child_parser)) {
                s += "(" + child_gbnf + ")";
            } else {
@@ -1603,7 +1681,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
            return s;
        } else if constexpr (std::is_same_v<T, common_peg_repetition_parser>) {
            auto child_gbnf = to_gbnf(p.child);
            const auto & child_parser = parsers_.at(p.child);
            const auto & child_parser = effective_parser(p.child);
            if (std::holds_alternative<common_peg_choice_parser>(child_parser) ||
                std::holds_alternative<common_peg_sequence_parser>(child_parser)) {
                child_gbnf = "(" + child_gbnf + ")";
@@ -1663,15 +1741,10 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
            }
            return gbnf_excluding_pattern(p.delimiters);
        } else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
            if (p.schema) {
                if (p.raw && p.schema->contains("type") && p.schema->at("type").is_string() && p.schema->at("type") == "string") {
                    // TODO: Implement more comprehensive grammar generation for raw strings.
                    // For now, use the grammar emitted from the underlying parser.
                    return to_gbnf(p.child);
                }
                return builder.add_schema(p.name, *p.schema);
            if (schema_delegates(p)) {
                return to_gbnf(p.child);
            }
            return to_gbnf(p.child);
            return builder.add_schema(p.name, *p.schema);
        } else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
            return p.name;
        } else if constexpr (std::is_same_v<T, common_peg_ref_parser>) {

@@ -106,6 +106,9 @@ class common_peg_ast_arena {

    const common_peg_ast_node & get(common_peg_ast_id id) const { return nodes_.at(id); }

    common_peg_ast_id find_by_tag(const common_peg_ast_node & parent, const std::string & tag, int max_depth = 3) const;
    common_peg_ast_id find_by_rule(const common_peg_ast_node & parent, const std::string & tag, int max_depth = 3) const;

    size_t size() const { return nodes_.size(); }

    void clear() { nodes_.clear(); }

@@ -1164,7 +1164,7 @@ class TextModel(ModelBase):
        if (n_experts := self.find_hparam(["num_local_experts", "num_experts"], optional=True)) is not None:
            self.gguf_writer.add_expert_count(n_experts)
            logger.info(f"gguf: expert count = {n_experts}")
        if (n_experts_used := self.find_hparam(["num_experts_per_tok", "num_experts_per_token"], optional=True)) is not None:
        if (n_experts_used := self.find_hparam(["num_experts_per_tok", "num_experts_per_token", "top_k_experts"], optional=True)) is not None:
            self.gguf_writer.add_expert_used_count(n_experts_used)
            logger.info(f"gguf: experts used count = {n_experts_used}")
        if (n_expert_groups := self.hparams.get("n_group")) is not None:
@@ -6878,7 +6878,9 @@ class Gemma2Model(TextModel):
@ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
class Gemma3Model(TextModel):
    model_arch = gguf.MODEL_ARCH.GEMMA3
    norm_shift = 1.0  # Gemma3RMSNorm adds 1.0 to the norm value

    def norm_shift(self, name: str) -> float:
        return 1.0 if name.endswith("norm.weight") else 0.0  # Gemma3RMSNorm adds 1.0 to the norm value

    def set_vocab(self):
        if (self.dir_model / "tokenizer.model").is_file():
@@ -6916,17 +6918,22 @@ class Gemma3Model(TextModel):

        # remove OOV (out-of-vocabulary) rows in token_embd
        if "embed_tokens.weight" in name:
            n_vocab_real = -1
            if (self.dir_model / "tokenizer.model").is_file():
                tokens = self._create_vocab_sentencepiece()[0]
                n_vocab_real = len(tokens)
            else:
                tokens = self.get_vocab_base()[0]
            data_torch = data_torch[:len(tokens)]
                with open(self.dir_model / "tokenizer.json", "r", encoding="utf-8") as f:
                    tokenizer_json = json.load(f)
                n_vocab_real = len(tokenizer_json["model"]["vocab"]) + len(tokenizer_json["added_tokens"])
            data_torch = data_torch[:n_vocab_real]

        # ref code in Gemma3RMSNorm
        # output = output * (1.0 + self.weight.float())
        # note: this is not the case on gemma3n
        if name.endswith("norm.weight"):
            data_torch = data_torch + self.norm_shift
            f_shift = self.norm_shift(name)
            if f_shift != 0.0:
                data_torch = data_torch + f_shift

        yield from super().modify_tensors(data_torch, name, bid)
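Two pieces of reasoning above are easy to miss: the real vocab size of an HF fast tokenizer is the base vocab plus added tokens, and Gemma3's RMSNorm stores its weight offset by one (the forward pass multiplies by 1.0 + weight), so conversion folds the shift into the tensor. A small illustrative sketch with a toy tokenizer layout:

```python
import torch

# Toy stand-in for tokenizer.json; real files have the same two sections.
tokenizer_json = {
    "model": {"vocab": {f"tok{i}": i for i in range(10)}},
    "added_tokens": [{"id": 10}, {"id": 11}],
}
n_vocab_real = len(tokenizer_json["model"]["vocab"]) + len(tokenizer_json["added_tokens"])  # 12

embd = torch.randn(16, 4)       # embedding table padded with OOV rows
embd = embd[:n_vocab_real]      # -> (12, 4): padding rows dropped

w = torch.zeros(4)              # checkpoint stores w; the norm computes x * (1.0 + w)
w_gguf = w + 1.0                # fold the +1 shift in, so inference can multiply by w directly
```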

@@ -7100,7 +7107,8 @@ class ConformerAudioModel(MmprojModel):
        assert data_torch.shape[2] == 1
        data_torch = data_torch.reshape(data_torch.shape[0], data_torch.shape[1])

        yield from super().modify_tensors(data_torch, name, bid)
        mapped_name = self.map_tensor_name(name, (".weight", ".bias", ".input_max", ".input_min", ".output_max", ".output_min"))
        yield (mapped_name, data_torch)


@ModelBase.register("DeepseekOCRForCausalLM")
@@ -7289,7 +7297,6 @@ class Gemma3nVisionAudioModel(ConformerAudioModel):
@ModelBase.register("Gemma3nForCausalLM", "Gemma3nForConditionalGeneration")
class Gemma3NModel(Gemma3Model):
    model_arch = gguf.MODEL_ARCH.GEMMA3N
    norm_shift = 0.0  # same value as Gemma3p5RMSNorm scale_shift in the python code

    _altup_proj: list[Tensor] = []
    _altup_unembd: list[Tensor] = []
@@ -7308,6 +7315,10 @@ class Gemma3NModel(Gemma3Model):
            torch.Tensor(),  # to be replaced
        ]

    def norm_shift(self, name: str) -> float:
        del name
        return 0.0  # same value as Gemma3p5RMSNorm scale_shift in the python code

    def set_vocab(self):
        # For Gemma3n multimodal models, we need the FULL vocab_size (262400)
        # which includes special tokens from 262144-262399 for vision/audio.
@@ -7425,6 +7436,209 @@ class Gemma3NModel(Gemma3Model):
        yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Gemma4ForConditionalGeneration")
class Gemma4Model(Gemma3Model):
    model_arch = gguf.MODEL_ARCH.GEMMA4

    def norm_shift(self, name: str) -> float:
        del name  # unused
        return 0.0

    def set_vocab(self):
        vocab = gguf.LlamaHfVocab(self.dir_model)
        tokens = []
        scores = []
        toktypes = []
        visible_tokens = {"<|channel>", "<channel|>", "<|tool_call>", "<tool_call|>", "<|tool_response>", "<tool_response|>", "<|\"|>"}

        for text, score, toktype in vocab.all_tokens():
            tokens.append(text)
            scores.append(score)
            text_str = text.decode()
            if text_str in visible_tokens:
                # always render these tokens, so that the chat parser can read them
                toktypes.append(gguf.TokenType.USER_DEFINED)
                logger.info(f"Token '{text_str}' is set to USER_DEFINED")
            else:
                toktypes.append(toktype)

        assert len(tokens) == vocab.vocab_size

        self.gguf_writer.add_tokenizer_model("gemma4")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
        special_vocab.add_to_gguf(self.gguf_writer)
        self.gguf_writer.add_add_space_prefix(False)
        self.gguf_writer.add_add_bos_token(True)

    def set_gguf_parameters(self):
        super().set_gguf_parameters()

        num_kv_shared_layers = self.hparams["num_kv_shared_layers"]
        self.gguf_writer.add_shared_kv_layers(num_kv_shared_layers)

        # per-layer embedding is optional
        n_pl_embd = self.hparams.get("hidden_size_per_layer_input") or 0
        self.gguf_writer.add_embedding_length_per_layer_input(n_pl_embd)

        swa_layers = [t == "sliding_attention" for t in self.hparams["layer_types"]]
        self.gguf_writer.add_sliding_window_pattern(swa_layers)

        head_dim_full = self.hparams["global_head_dim"]
        head_dim_swa = self.hparams["head_dim"]
        # correct the head dim for global/swa layers
        self.gguf_writer.add_key_length(head_dim_full)
        self.gguf_writer.add_value_length(head_dim_full)
        self.gguf_writer.add_key_length_swa(head_dim_swa)
        self.gguf_writer.add_value_length_swa(head_dim_swa)

        expert_intermediate_size = self.find_hparam(["expert_intermediate_size", "moe_intermediate_size"])
        if expert_intermediate_size is not None:
            self.gguf_writer.add_expert_feed_forward_length(expert_intermediate_size)

        # if use_double_wide_mlp is set, we need to adjust the value for kv shared layers
        use_double_wide_mlp = self.hparams.get("use_double_wide_mlp", False)
        first_kv_shared_layer_idx = self.block_count - num_kv_shared_layers
        if use_double_wide_mlp:
            n_ff = self.hparams["intermediate_size"]
            n_ff_arr = [n_ff if il < first_kv_shared_layer_idx else n_ff * 2 for il in range(self.block_count)]
            self.gguf_writer.add_feed_forward_length(n_ff_arr)

        # handle num_global_key_value_heads
        num_key_value_heads_full = self.hparams.get("num_global_key_value_heads")
        num_key_value_heads_swa = self.hparams.get("num_key_value_heads")
        if num_key_value_heads_full is not None and num_key_value_heads_swa is not None:
            value_arr = [num_key_value_heads_swa if is_swa else num_key_value_heads_full for is_swa in swa_layers]
            self.gguf_writer.add_head_count_kv(value_arr)

        # handle n_rot differently for global vs swa layers
        partial_rotary_factor_swa = self.hparams.get("partial_rotary_factor", 1.0)
        n_rot_full = int(head_dim_full)  # "proportional" is used, see generate_extra_tensors
        n_rot_swa = int(head_dim_swa * partial_rotary_factor_swa)
        self.gguf_writer.add_rope_dimension_count(n_rot_full)
        self.gguf_writer.add_rope_dimension_count_swa(n_rot_swa)
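The per-layer arrays written above carry layer-dependent metadata: KV head counts differ between sliding-window and global-attention layers, and with use_double_wide_mlp the FFN width doubles on the trailing KV-shared layers. A worked toy example (all hyperparameter values invented):

```python
# Invented values, just to show the shape of the per-layer arrays.
block_count = 6
num_kv_shared_layers = 2
layer_types = ["sliding_attention", "sliding_attention", "full_attention",
               "sliding_attention", "full_attention", "full_attention"]
swa_layers = [t == "sliding_attention" for t in layer_types]

n_ff = 1024
first_shared = block_count - num_kv_shared_layers              # 4
n_ff_arr = [n_ff if il < first_shared else n_ff * 2 for il in range(block_count)]
assert n_ff_arr == [1024, 1024, 1024, 1024, 2048, 2048]

kv_full, kv_swa = 8, 4
head_count_kv = [kv_swa if is_swa else kv_full for is_swa in swa_layers]
assert head_count_kv == [4, 4, 8, 4, 8, 8]
```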

    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
        # full layer uses "proportional" rope with partial_rotary_factor=0.25
        # the expected ordering is cc000000ss000000 (c = cos, s = sin, 0 = unrotated),
        # but ggml neox only supports ccss000000000000, and we cannot rearrange the head because that will break use_alternative_attention
        # solution is to set specific freq_factors for the unrotated dims

        # IMPORTANT: this ROPE_FREQS tensor is ONLY used by the full_attention layers
        rope_params_full = self.hparams["rope_parameters"]["full_attention"]
        assert rope_params_full["rope_type"] == "proportional"
        head_dim_full = (self.hparams["global_head_dim"])
        partial_rotary_factor_full = rope_params_full["partial_rotary_factor"]
        n_rot_full = int(head_dim_full * partial_rotary_factor_full / 2)
        n_unrot_full = int(head_dim_full / 2) - n_rot_full
        values = [1.0] * n_rot_full + [1e30] * n_unrot_full
        rope_freqs_full = torch.tensor(values, dtype=torch.float32)
        yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), rope_freqs_full)
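The 1e30 entries act as frequency factors that effectively freeze the dims that must stay unrotated, emulating the cc000000ss000000 layout within neox ordering. A back-of-the-envelope check with an assumed head_dim of 128 (not taken from a real config):

```python
head_dim_full = 128
partial_rotary_factor_full = 0.25
n_rot_full = int(head_dim_full * partial_rotary_factor_full / 2)   # 16 rotated pairs
n_unrot_full = int(head_dim_full / 2) - n_rot_full                 # 48 frozen pairs
values = [1.0] * n_rot_full + [1e30] * n_unrot_full
assert len(values) == head_dim_full // 2
# theta / 1e30 is ~0 for the frozen pairs, so cos -> 1 and sin -> 0 there:
# those dims pass through the rotation unchanged.
```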

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        if name.endswith("per_dim_scale") or name.endswith("layer_scalar"):
            name = name + ".weight"

        if "language_model." not in name and "rope_freqs" not in name:
            return  # skip non-language model tensors

        name = name.replace("language_model.", "")
        if name.endswith("router.scale"):
            name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_INP, bid, ".scale")
            yield (name, data_torch)
            return
        if ".per_expert_scale" in name:
            # convert per-expert scale to FFN down scale
            name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN_EXP, bid, ".scale")
            yield (name, data_torch)
            return
        if ".experts." in name and not name.endswith(".weight"):
            name += ".weight"

        yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Gemma4ForConditionalGeneration")
class Gemma4VisionAudioModel(MmprojModel):
    has_audio_encoder = True
    has_vision_encoder = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.hparams_vision is not None
        self.hparams_vision["image_size"] = 224  # unused, but set to avoid error

        # remap audio hparams
        if self.hparams_audio:
            self.hparams_audio["feat_in"] = self.hparams_audio.get("input_feat_size", 128)
            self.hparams_audio["intermediate_size"] = self.hparams_audio["hidden_size"] * 4
        else:
            self.has_audio_encoder = False

    def set_gguf_parameters(self):
        super().set_gguf_parameters()

        # vision params
        self.gguf_writer.add_clip_vision_projector_type(gguf.VisionProjectorType.GEMMA4V)
        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6))

        # audio params
        if self.hparams_audio:
            self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.GEMMA4A)
            self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["feat_in"])
            self.gguf_writer.add_audio_attention_layernorm_eps(1e-5)

    def is_audio_tensor(self, name: str) -> bool:
        return "audio_tower" in name or "embed_audio" in name

    def tensor_force_quant(self, name, new_name, bid, n_dims):
        if self.is_audio_tensor(name):
            if ".conv" in name or "_conv" in name and ".weight" in name:
                return gguf.GGMLQuantizationType.F32
            if "position_embedding_table" in name:
                return gguf.GGMLQuantizationType.F32
        return super().tensor_force_quant(name, new_name, bid, n_dims)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        del bid  # unused

        if name.startswith("model.language_model."):
            return  # skip

        if len(data_torch.shape) == 0:
            # convert scalar tensors (input/output_mix/max) to 1D tensors
            data_torch = data_torch.unsqueeze(0)

        if self.is_audio_tensor(name):
            assert self.hparams_audio is not None
            name = name.replace("model.audio_tower.", "conformer.")
            name = name.replace(".linear.", ".")
            if name.endswith("per_dim_key_scale") or name.endswith("per_dim_scale"):
                name = name + ".weight"
                data_torch = torch.nn.functional.softplus(data_torch)
            if "lconv1d.depthwise_conv1d" in name and name.endswith(".weight"):
                assert data_torch.shape[1] == 1
                data_torch = data_torch.reshape(data_torch.shape[0], data_torch.shape[2])
            mapped_name = self.map_tensor_name(name, (".weight", ".bias", ".input_max", ".input_min", ".output_max", ".output_min"))
            yield (mapped_name, data_torch)

        else:
            name = name.replace("model.vision_tower.encoder.", "vision_model.model.")
            name = name.replace(".linear.weight", ".weight")
            if name.endswith("layer_scalar") or name.endswith("position_embedding_table"):
                name = name + ".weight"
            if name.endswith("patch_embedder.input_proj.weight"):
                n_embd, ksize_sq_c = data_torch.shape
                patch_size = int((ksize_sq_c // 3) ** 0.5)
                data_torch = data_torch.reshape(n_embd, patch_size, patch_size, 3)
                data_torch = data_torch.permute(0, 3, 1, 2).contiguous()
            mapped_name = self.map_tensor_name(name, (".weight", ".bias", ".input_max", ".input_min", ".output_max", ".output_min"))
            yield (mapped_name, data_torch)


@ModelBase.register("Starcoder2ForCausalLM")
class StarCoder2Model(TextModel):
    model_arch = gguf.MODEL_ARCH.STARCODER2
@@ -11307,13 +11521,50 @@ class LLaDAMoEModel(TextModel):
            raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("HunYuanDenseV1ForCausalLM")
@ModelBase.register("HunYuanDenseV1ForCausalLM", "HunYuanVLForConditionalGeneration")
class HunYuanModel(TextModel):
    model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE

    def _get_eod_token_id(self) -> int | None:
        """Get the actual end-of-generation token from config (eod_token_id)."""
        return self.hparams.get("eod_token_id")

    def _get_eot_token_id(self) -> int | None:
        """Get the end-of-turn token from generation_config.json.
        This is the first entry in eos_token_id when it's a list."""
        gen_cfg_path = self.dir_model / "generation_config.json"
        if gen_cfg_path.is_file():
            with open(gen_cfg_path, encoding="utf-8") as f:
                gen_cfg = json.load(f)
            eos = gen_cfg.get("eos_token_id")
            if isinstance(eos, list) and len(eos) >= 2:
                return eos[0]
        return None

    def _fix_special_tokens(self):
        """Fix EOS/EOT tokens that are incorrect in upstream configs."""
        eod_id = self._get_eod_token_id()
        if eod_id is not None:
            self.gguf_writer.add_eos_token_id(eod_id)
        eot_id = self._get_eot_token_id()
        if eot_id is not None:
            self.gguf_writer.add_eot_token_id(eot_id)

    def set_vocab(self):
        if (self.dir_model / "tokenizer.json").is_file():
            self._set_vocab_gpt2()
            tokens, toktypes, tokpre = self.get_vocab_base()
            self.gguf_writer.add_tokenizer_model("gpt2")
            self.gguf_writer.add_tokenizer_pre(tokpre)
            self.gguf_writer.add_token_list(tokens)
            self.gguf_writer.add_token_types(toktypes)

            # HunyuanOCR has pad_token_id=-1 in config.json; exclude pad from SpecialVocab
            token_types = None
            if (self.hparams.get("pad_token_id") or 0) < 0:
                token_types = ('bos', 'eos', 'unk', 'sep', 'cls', 'mask')
            special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True, special_token_types=token_types)
            special_vocab.add_to_gguf(self.gguf_writer)
            self._fix_special_tokens()
        else:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
@@ -11365,13 +11616,18 @@ class HunYuanModel(TextModel):
        # FIX for BOS token: Overwrite incorrect id read from config.json
        if self.hparams['hidden_size'] == 4096:
            self.gguf_writer.add_bos_token_id(127958)  # only for 7b dense, fix <|bos|> token
        self._fix_special_tokens()

    def set_gguf_parameters(self):
        # HunyuanOCR has num_experts=1, which is not MoE; prevent the parent from writing it
        saved_num_experts = self.hparams.pop("num_experts", None)
        super().set_gguf_parameters()
        if saved_num_experts is not None and saved_num_experts > 1:
            self.hparams["num_experts"] = saved_num_experts
        hparams = self.hparams

        # Rope
        if self.rope_parameters.get("rope_type") == "dynamic":
        if self.rope_parameters.get("rope_type") in ("dynamic", "xdrope"):
            # HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
            # 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
            alpha = self.rope_parameters.get("alpha", 50)
@@ -11381,13 +11637,14 @@ class HunYuanModel(TextModel):
            self.gguf_writer.add_rope_freq_base(scaled_base)
            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
            self.gguf_writer.add_rope_scaling_factor(1)
            # There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
            self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024)  # 256k context length
            self.gguf_writer.add_context_length(256 * 1024)  # 256k context length
            if self.rope_parameters.get("rope_type") == "dynamic":
                # There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
                self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024)  # 256k context length
                self.gguf_writer.add_context_length(256 * 1024)  # 256k context length

            # if any of our assumptions about the values are wrong, something has changed and this may need to be updated
            assert base == 10000.0 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024], \
                "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"
            # if any of our assumptions about the values are wrong, something has changed and this may need to be updated
            assert base == 10000.0 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024], \
                "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        if name == "lm_head.weight":
@@ -11395,9 +11652,48 @@ class HunYuanModel(TextModel):
            logger.info("Skipping tied output layer 'lm_head.weight'")
            return

        # skip vision tensors for HunyuanVL models
        if name.startswith("vit."):
            return

        yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("HunYuanVLForConditionalGeneration")
class HunyuanOCRVisionModel(MmprojModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.hparams_vision is not None
        # HunyuanOCR uses max_image_size instead of image_size
        if "image_size" not in self.hparams_vision:
            self.hparams_vision["image_size"] = self.hparams_vision.get("max_image_size", 2048)

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        assert self.hparams_vision is not None
        hparams = self.hparams_vision
        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANOCR)
        self.gguf_writer.add_vision_use_gelu(True)
        self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("rms_norm_eps", 1e-5))
        self.gguf_writer.add_vision_spatial_merge_size(hparams.get("spatial_merge_size", 2))
        self.gguf_writer.add_vision_min_pixels(self.preprocessor_config["min_pixels"])
        self.gguf_writer.add_vision_max_pixels(self.preprocessor_config["max_pixels"])

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        if not name.startswith("vit."):
            return  # skip text tensors
        # strip CLS token (row 0) from position embeddings so resize_position_embeddings works
        if "position_embedding" in name:
            data_torch = data_torch[1:]  # [n_patches+1, n_embd] -> [n_patches, n_embd]
        yield from super().modify_tensors(data_torch, name, bid)

    def tensor_force_quant(self, name, new_name, bid, n_dims):
        # force conv weights to F32 or F16 to avoid BF16 IM2COL issues on Metal
        if ("mm.0." in new_name or "mm.2." in new_name) and new_name.endswith(".weight"):
            return gguf.GGMLQuantizationType.F16 if self.ftype == gguf.LlamaFileType.MOSTLY_F16 else gguf.GGMLQuantizationType.F32
        return super().tensor_force_quant(name, new_name, bid, n_dims)


@ModelBase.register("SmolLM3ForCausalLM")
class SmolLM3Model(LlamaModel):
    model_arch = gguf.MODEL_ARCH.SMOLLM3
@@ -11522,10 +11818,8 @@ class LFM2Model(TextModel):
    model_arch = gguf.MODEL_ARCH.LFM2

    def _add_feed_forward_length(self):
        ff_dim = self.hparams["block_ff_dim"]

        ff_dim = self.find_hparam(["block_ff_dim", "intermediate_size"])
        auto_adjust_ff_dim = self.hparams["block_auto_adjust_ff_dim"]
        ff_dim = self.hparams["block_ff_dim"]
        ffn_dim_multiplier = self.hparams["block_ffn_dim_multiplier"]
        multiple_of = self.hparams["block_multiple_of"]


@@ -57,13 +57,14 @@ ZenDNN is optimized for AMD EPYC™ processors and AMD Ryzen™ processors based

## Supported Operations

The ZenDNN backend currently accelerates **matrix multiplication (MUL_MAT)** operations only. Other operations are handled by the standard CPU backend.
The ZenDNN backend accelerates **matrix multiplication (MUL_MAT)** and **expert-based matrix multiplication (MUL_MAT_ID)** operations. Other operations are handled by the standard CPU backend.

| Operation  | Status  | Notes                                      |
|:-----------|:-------:|:------------------------------------------:|
| MUL_MAT    | Support | Accelerated via ZenDNN LowOHA MatMul       |
| MUL_MAT_ID | Support | Accelerated via ZenDNN LowOHA MatMul (MoE) |

*Note:* Since only MUL_MAT is accelerated, models will benefit most from ZenDNN when matrix multiplications dominate the computational workload (which is typical for transformer-based LLMs).
*Note:* Since MUL_MAT and MUL_MAT_ID are accelerated, models will benefit most from ZenDNN when matrix multiplications dominate the computational workload (which is typical for transformer-based LLMs and Mixture-of-Experts models).
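For readers unfamiliar with the op, MUL_MAT_ID is the expert-routed variant of MUL_MAT used by MoE layers: an ids tensor selects which expert weight matrix multiplies each token. A rough NumPy sketch of the semantics (shapes simplified; this is not the actual ggml signature):

```python
import numpy as np

n_expert, n_ff, n_embd, n_tok = 4, 32, 16, 8
experts = np.random.randn(n_expert, n_ff, n_embd)  # one weight matrix per expert
x = np.random.randn(n_tok, n_embd)                 # token activations
ids = np.random.randint(0, n_expert, size=n_tok)   # router's expert choice per token

# mul_mat_id: each token is multiplied by its routed expert's matrix
out = np.stack([experts[ids[t]] @ x[t] for t in range(n_tok)])
assert out.shape == (n_tok, n_ff)
```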

## DataType Supports

@@ -181,7 +182,7 @@ For detailed profiling and logging options, refer to the [ZenDNN Logging Documen

## Known Issues

- **Limited operation support**: Currently only matrix multiplication (MUL_MAT) is accelerated via ZenDNN. Other operations fall back to the standard CPU backend.
- **Limited operation support**: Currently matrix multiplication (MUL_MAT) and expert-based matrix multiplication (MUL_MAT_ID) are accelerated via ZenDNN. Other operations fall back to the standard CPU backend. Future updates may expand supported operations.
- **BF16 support**: BF16 operations require AMD Zen 4 or Zen 5 architecture (EPYC 9004/9005 series). On older CPUs, operations will use FP32.
- **NUMA awareness**: For multi-socket systems, manual NUMA binding may be required for optimal performance.

@@ -216,4 +217,4 @@ Please add the **[ZenDNN]** prefix/tag in issues/PRs titles to help the ZenDNN-t

## TODO

- Expand operation support beyond MUL_MAT (attention operations, activations, etc.)
- Expand operation support beyond MUL_MAT and MUL_MAT_ID (attention operations, activations, etc.)

@@ -389,7 +389,7 @@ You can download it from your Linux distro's package manager or from here: [ROCm

The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used.
If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3.
If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3. Note that [`HSA_OVERRIDE_GFX_VERSION`] is [not supported on Windows](https://github.com/ROCm/ROCm/issues/2654).

### Unified Memory

@@ -728,7 +728,7 @@ To read documentation for how to build on Android, [click here](./android.md)

## WebGPU [In Progress]

The WebGPU backend relies on [Dawn](https://dawn.googlesource.com/dawn). Follow the instructions [here](https://dawn.googlesource.com/dawn/+/refs/heads/main/docs/quickstart-cmake.md) to install Dawn locally so that llama.cpp can find it using CMake. The current implementation is up-to-date with Dawn commit `bed1a61`.
The WebGPU backend relies on [Dawn](https://dawn.googlesource.com/dawn). Follow the instructions [here](https://dawn.googlesource.com/dawn/+/refs/heads/main/docs/quickstart-cmake.md) to install Dawn locally so that llama.cpp can find it using CMake. The current implementation is up-to-date with Dawn commit `18eb229`.

In the llama.cpp directory, build with CMake:

@@ -37,6 +37,7 @@ llama-server -hf ggml-org/gemma-3-4b-it-GGUF --no-mmproj-offload
> - PaddleOCR-VL: https://github.com/ggml-org/llama.cpp/pull/18825
> - GLM-OCR: https://github.com/ggml-org/llama.cpp/pull/19677
> - Deepseek-OCR: https://github.com/ggml-org/llama.cpp/pull/17400
> - HunyuanOCR: https://github.com/ggml-org/llama.cpp/pull/21395

## Pre-quantized models

@@ -68,7 +68,7 @@ Legend:
| MEAN           | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| MUL            | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| MUL_MAT        | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| MUL_MAT_ID     | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
| MUL_MAT_ID     | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | 🟡 | ❌ |
| NEG            | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| NORM           | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |

docs/ops/ZenDNN.csv: 9986 changed lines (file diff suppressed because it is too large)
@@ -15,13 +15,18 @@ static bool run(llama_context * ctx, const common_params & params) {

    const bool add_bos = llama_vocab_get_add_bos(vocab);

    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos, true);

    if (tokens.empty()) {
        LOG_ERR("%s : there are no input tokens to process - (try to provide a prompt with '-p')\n", __func__);
        return false;
    }

    LOG_INF("number of input tokens = %zu\n", tokens.size());
    for (size_t i = 0; i < tokens.size(); ++i) {
        LOG_INF("  %d\n", tokens[i]);
    }

    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
        LOG_ERR("%s : failed to eval\n", __func__);
        return false;

@@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
### GGML Version
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 9)
set(GGML_VERSION_PATCH 9)
set(GGML_VERSION_PATCH 11)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")

find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)

@@ -166,15 +166,16 @@ if (NOT MSVC)
    option(GGML_AMX_INT8 "ggml: enable AMX-INT8" OFF)
    option(GGML_AMX_BF16 "ggml: enable AMX-BF16" OFF)
endif()
option(GGML_LASX "ggml: enable lasx" ON)
option(GGML_LSX "ggml: enable lsx" ON)
option(GGML_RVV "ggml: enable rvv" ON)
option(GGML_RV_ZFH "ggml: enable riscv zfh" ON)
option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause " ON)
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE})
option(GGML_LASX "ggml: enable lasx" ON)
option(GGML_LSX "ggml: enable lsx" ON)
option(GGML_RVV "ggml: enable rvv" ON)
option(GGML_RV_ZFH "ggml: enable riscv zfh" ON)
option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause" ON)
option(GGML_RV_ZVFBFWMA "ggml: enable riscv zvfbfwma" OFF)
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE})

option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
@@ -2350,11 +2350,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        case GGML_OP_FLASH_ATTN_BACK:
        case GGML_OP_SSM_CONV:
        case GGML_OP_SSM_SCAN:
            {
                n_tasks = n_threads;
            } break;
        case GGML_OP_RWKV_WKV6:
        case GGML_OP_GATED_LINEAR_ATTN:
        case GGML_OP_RWKV_WKV7:
            {
                n_tasks = n_threads;
                const int64_t n_heads = node->src[1]->ne[1];
                n_tasks = MIN(n_threads, n_heads);
            } break;
        case GGML_OP_WIN_PART:
        case GGML_OP_WIN_UNPART:
@@ -180,44 +180,49 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
}
#endif

#if defined(__riscv_zvfh)
template <>
inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
    return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
    return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
    return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
    return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
#if defined(__riscv_v_intrinsic)
template <> inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
    return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
template <> inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
    return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
template <> inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
    return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
template <> inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
    return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
#endif

#if defined(__riscv_zvfh)
template <> inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
    return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
template <> inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
    return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
template <> inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
    return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
template <> inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
    return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
#endif

#if defined(__riscv_zvfbfwma)
inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
template <> inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
    return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
template <> inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
    return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
template <> inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
    return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
template <> inline vfloat32m8_t madd(vbfloat16m4_t a, vbfloat16m4_t b, vfloat32m8_t c) {
    return __riscv_vfwmaccbf16_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -272,7 +277,7 @@ inline float hsum(__m512 x) {
}
#endif // __AVX512F__

#if defined(__riscv_zvfh)
#if defined(__riscv_v_intrinsic)
inline float hsum(vfloat32m1_t x) {
    return __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1()));
@@ -379,19 +384,7 @@ template <> inline __m256bh load(const float *p) {
}
#endif

#if defined(__riscv_zvfh)
template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());
}
template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16m1(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m1());
}
template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16m2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m2());
}
template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
}
#if defined(__riscv_v_intrinsic)
template <> inline vfloat32m1_t load(const float *p) {
    return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1());
}
@@ -406,6 +399,21 @@ template <> inline vfloat32m8_t load(const float *p) {
}
#endif

#if defined(__riscv_zvfh)
template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());
}
template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16m1(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m1());
}
template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16m2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m2());
}
template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
}
#endif

#if defined(__riscv_zvfbfwma)
template <> inline vbfloat16mf2_t load(const ggml_bf16_t *p) {
    return __riscv_vle16_v_bf16mf2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16mf2());
@@ -416,23 +424,14 @@ template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) {
template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) {
    return __riscv_vle16_v_bf16m2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m2());
}
template <> inline vbfloat16m4_t load(const ggml_bf16_t *p) {
    return __riscv_vle16_v_bf16m4(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m4());
}
#endif

#if defined(__riscv_zvfh)
#if defined(__riscv_v_intrinsic)
template <typename T> T set_zero();

template <> inline vfloat16mf2_t set_zero() {
    return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2());
}
template <> inline vfloat16m1_t set_zero() {
    return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1());
}
template <> inline vfloat16m2_t set_zero() {
    return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2());
}
template <> inline vfloat16m4_t set_zero() {
    return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4());
}
template <> inline vfloat32m1_t set_zero() {
    return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1());
}
@@ -449,14 +448,22 @@ template <> inline vfloat32m8_t set_zero() {

#if defined(__riscv_v_intrinsic)
template <typename T> size_t vlmax() {
    if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
    else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
    else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
    else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
    else if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
    if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
    else if constexpr (std::is_same_v<T, vfloat32m2_t>) { return __riscv_vsetvlmax_e32m2(); }
    else if constexpr (std::is_same_v<T, vfloat32m4_t>) { return __riscv_vsetvlmax_e32m4(); }
    else if constexpr (std::is_same_v<T, vfloat32m8_t>) { return __riscv_vsetvlmax_e32m8(); }
#if defined (__riscv_zvfh)
    else if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
    else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
    else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
    else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
#endif
#if defined (__riscv_zvfbfwma)
    else if constexpr (std::is_same_v<T, vbfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
    else if constexpr (std::is_same_v<T, vbfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
    else if constexpr (std::is_same_v<T, vbfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
    else if constexpr (std::is_same_v<T, vbfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
#endif
    return 0;
}
#endif
@@ -3740,7 +3747,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
                     params->ith, params->nth};
    tb.matmul(m, n);
    return true;
#elif defined(__riscv_zvfh)
#elif defined(__riscv_v_intrinsic)
#if LMUL == 1
    tinyBLAS_RVV<vfloat32m1_t, vfloat32m1_t, float, float, float> tb{ params,
        k, (const float *)A, lda,
@@ -3804,23 +3811,25 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
        return true;
    }
#elif defined(__riscv_zvfbfwma)
#if LMUL == 1
    tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
        k, (const ggml_bf16_t *)A, lda,
        (const ggml_bf16_t *)B, ldb,
        (float *)C, ldc};
#elif LMUL == 2
    tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
        k, (const ggml_bf16_t *)A, lda,
        (const ggml_bf16_t *)B, ldb,
        (float *)C, ldc};
#else // LMUL = 4
    tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
        k, (const ggml_bf16_t *)A, lda,
        (const ggml_bf16_t *)B, ldb,
        (float *)C, ldc};
#endif
    return tb.matmul(m, n);
    if (Btype == GGML_TYPE_BF16) {
#if LMUL == 1
        tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
            k, (const ggml_bf16_t *)A, lda,
            (const ggml_bf16_t *)B, ldb,
            (float *)C, ldc};
#elif LMUL == 2
        tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
            k, (const ggml_bf16_t *)A, lda,
            (const ggml_bf16_t *)B, ldb,
            (float *)C, ldc};
#else // LMUL = 4
        tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
            k, (const ggml_bf16_t *)A, lda,
            (const ggml_bf16_t *)B, ldb,
            (float *)C, ldc};
#endif
        return tb.matmul(m, n);
    }
#endif
    return false;
}

@@ -9953,13 +9953,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
    const int ith = params->ith;
    const int nth = params->nth;

    if (ith >= HEADS) {
        return;
    }

    const int h_start = (HEADS * ith) / nth;
    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
                      (HEADS * (ith + 1)) / nth : HEADS;
    const int h_start = (HEADS * (ith    )) / nth;
    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
                      (HEADS * (ith + 1)) / nth : HEADS;

    float * k = (float *) dst->src[0]->data;
    float * v = (float *) dst->src[1]->data;
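The removed early-return guard is redundant once ggml_get_n_tasks clamps n_tasks to MIN(n_threads, n_heads) (see the scheduler hunk earlier): every participating thread now owns a non-empty, contiguous head range. A quick sketch of the partition arithmetic:

```python
def head_range(ith, nth, heads):
    h_start = (heads * ith) // nth
    h_end = min((heads * (ith + 1)) // nth, heads)
    return h_start, h_end

# heads = 8 split across nth = 3 threads -> [0,2), [2,5), [5,8):
# every head is covered exactly once, with no empty ranges when nth <= heads.
assert [head_range(i, 3, 8) for i in range(3)] == [(0, 2), (2, 5), (5, 8)]
```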
@@ -10170,13 +10166,9 @@ static void ggml_compute_forward_gla_f32(
    const int ith = params->ith;
    const int nth = params->nth;

    if (ith >= HEADS) {
        return;
    }

    const int h_start = (HEADS * ith) / nth;
    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
                      (HEADS * (ith + 1)) / nth : HEADS;
    const int h_start = (HEADS * (ith    )) / nth;
    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
                      (HEADS * (ith + 1)) / nth : HEADS;

    float * k = (float *) dst->src[0]->data;
    float * v = (float *) dst->src[1]->data;
@@ -10633,13 +10625,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
    const int ith = params->ith;
    const int nth = params->nth;

    if (ith >= HEADS) {
        return;
    }

    const int h_start = (HEADS * ith) / nth;
    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
                      (HEADS * (ith + 1)) / nth : HEADS;
    const int h_start = (HEADS * (ith    )) / nth;
    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
                      (HEADS * (ith + 1)) / nth : HEADS;

    float * r = (float *) dst->src[0]->data;
    float * w = (float *) dst->src[1]->data;
@@ -126,7 +126,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
    const int ggml_f16_epr = sve_register_length / 16;  // running when 16
    const int ggml_f16_step = 8 * ggml_f16_epr;         // choose 8 SVE registers

    const int np = (n & ~(ggml_f16_step - 1));
    int np = (n & ~(ggml_f16_step - 1));

    svfloat16_t sum_00 = svdup_n_f16(0.0f);
    svfloat16_t sum_01 = svdup_n_f16(0.0f);
@@ -224,71 +224,75 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
    }
    GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
    GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
    np = n;
#elif defined(__riscv_v_intrinsic)
#if defined(__riscv_zvfh)
    size_t vl = __riscv_vsetvlmax_e32m4();

#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    size_t vl = __riscv_vsetvlmax_e32m4();
    // initialize accumulators to all zeroes
    vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
    vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
    vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
    vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);

    // initialize accumulators to all zeroes
    vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
    vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
    vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
    vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
    // calculate step size
    const size_t epr = __riscv_vsetvlmax_e16m2();
    const size_t step = epr * 2;
    int np = (n & ~(step - 1));

    // calculate step size
    const size_t epr = __riscv_vsetvlmax_e16m2();
    const size_t step = epr * 2;
    const int np = (n & ~(step - 1));
    // unroll by 2 along the row dimension
    for (int i = 0; i < np; i += step) {
        vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr);
        vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr);
        vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr);
        vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr);
        vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr);

    // unroll by 2 along the row dimension
    for (int i = 0; i < np; i += step) {
        vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr);
        vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr);
        vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr);
        vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr);
        vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr);
        vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr);
        vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr);
        vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr);
        vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr);
        vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr);
    }

        vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr);
        vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr);
        vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr);
        vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr);
        vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr);
    }
    vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl);
    vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl);

    vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl);
    vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl);
    // leftovers
    for (int i = np; i < n; i += vl) {
        vl = __riscv_vsetvl_e16m2(n - i);
        vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl);
        vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl);
        vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl);

    // leftovers
    for (int i = np; i < n; i += vl) {
        vl = __riscv_vsetvl_e16m2(n - i);
        vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl);
        vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl);
        vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl);
        vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl);
        vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl);
    }

        vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl);
        vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl);
    }

    // reduce
    vl = __riscv_vsetvlmax_e32m2();
    vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0),
                                                 __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl);
    vl = __riscv_vsetvlmax_e32m1();
    vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0),
                                                 __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl);
    vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1(
        acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);

    vl = __riscv_vsetvlmax_e32m2();
    vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0),
                                                 __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl);
    vl = __riscv_vsetvlmax_e32m1();
    vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0),
                                                 __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl);
    vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1(
        acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
    sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0);
    sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1);
    // reduce
    vl = __riscv_vsetvlmax_e32m2();
    vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0),
                                                 __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl);
    vl = __riscv_vsetvlmax_e32m1();
    vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0),
                                                 __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl);
    vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1(
        acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);

    vl = __riscv_vsetvlmax_e32m2();
    vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0),
                                                 __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl);
    vl = __riscv_vsetvlmax_e32m1();
    vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0),
                                                 __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl);
    vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1(
        acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
    sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0);
    sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1);
    np = n;
#else
    const int np = 0;
#endif
#else
    const int np = (n & ~(GGML_F16_STEP - 1));
@@ -313,21 +317,17 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
    for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
        GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
    }

    // leftovers
    for (int i = np; i < n; ++i) {
        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
            sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
        }
    }
#endif
#else
    for (int i = 0; i < n; ++i) {
    // scalar path
    const int np = 0;
#endif
    // scalar and leftovers
    for (int i = np; i < n; ++i) {
        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
            sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
        }
    }
#endif

    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
        s[i] = (float)sumf[i];
@@ -532,40 +532,45 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
svst1_f16(pg, (__fp16 *)(y + np2), hy);
}
np = n;
#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
const _Float16 scale = *(const _Float16*)(&s);
#elif defined(__riscv_v_intrinsic) // implies __riscv_v_intrinsic
#if defined (__riscv_zvfh)
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
const _Float16 scale = *(const _Float16*)(&s);

// calculate step size
const int epr = __riscv_vsetvlmax_e16m4();
const int step = epr * 2;
int np = (n & ~(step - 1));
// calculate step size
const int epr = __riscv_vsetvlmax_e16m4();
const int step = epr * 2;
int np = (n & ~(step - 1));

// unroll by 2
for (int i = 0; i < np; i += step) {
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr);
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
__asm__ __volatile__ ("" ::: "memory");
// unroll by 2
for (int i = 0; i < np; i += step) {
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr);
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
__asm__ __volatile__ ("" ::: "memory");

vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr);
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
__asm__ __volatile__ ("" ::: "memory");
}
vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr);
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
__asm__ __volatile__ ("" ::: "memory");
}

// leftovers
int vl;
for (int i = np; i < n; i += vl) {
vl = __riscv_vsetvl_e16m4(n - i);
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl);
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
}
np = n;
// leftovers
int vl;
for (int i = np; i < n; i += vl) {
vl = __riscv_vsetvl_e16m4(n - i);
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl);
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
}
np = n;
#else
// fall to scalar path
const int np = 0;
#endif
#elif defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));

@@ -584,10 +589,11 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
}
}
#else
// scalar path
const int np = 0;
#endif

// leftovers
// scalar and leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
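The RVV ggml_vec_mad_f16 path above follows a common shape: a 2x-unrolled main loop over a multiple of the vector length, then a variable-length tail. That shape can be modeled in plain C++; EPR below is a stand-in constant for __riscv_vsetvlmax_e16m4(), and this is a sketch of the loop structure, not ggml's implementation:

#include <algorithm>

// Scalar model of the strip-mined FMA loop: a fixed "vector length" EPR,
// an unroll-by-2 main loop over np = n rounded down to 2*EPR, and a tail
// whose length shrinks on the last pass (the RVV code gets it from vsetvl).
constexpr int EPR = 8;

void mad_like_rvv(int n, float * y, const float * x, float v) {
    const int step = EPR * 2;
    const int np   = n & ~(step - 1);
    for (int i = 0; i < np; i += step) {
        for (int k = 0; k < EPR; ++k) y[i + k]       += x[i + k]       * v; // first half
        for (int k = 0; k < EPR; ++k) y[i + EPR + k] += x[i + EPR + k] * v; // second half
    }
    for (int i = np; i < n; i += EPR) { // leftovers
        const int vl = std::min(EPR, n - i);
        for (int k = 0; k < vl; ++k) y[i + k] += x[i + k] * v;
    }
}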
@@ -785,7 +791,7 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
const int ggml_f16_step = 2 * ggml_f16_epr;

GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
const int np = (n & ~(ggml_f16_step - 1));
int np = (n & ~(ggml_f16_step - 1));
svfloat16_t ay1, ay2;

for (int i = 0; i < np; i += ggml_f16_step) {
@@ -805,36 +811,43 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
svfloat16_t out = svmul_f16_m(pg, hy, vx);
svst1_f16(pg, (__fp16 *)(y + np), out);
}
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
const _Float16 scale = *(const _Float16*)(&s);
np = n;
#elif defined(__riscv_v_intrinsic)
#if defined(__riscv_zvfh)
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
const _Float16 scale = *(const _Float16*)(&s);

// calculate step size
const int epr = __riscv_vsetvlmax_e16m4();
const int step = epr * 2;
const int np = (n & ~(step - 1));
// calculate step size
const int epr = __riscv_vsetvlmax_e16m4();
const int step = epr * 2;
int np = (n & ~(step - 1));

// unroll by 2
for (int i = 0; i < np; i += step) {
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
__asm__ __volatile__ ("" ::: "memory");
// unroll by 2
for (int i = 0; i < np; i += step) {
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
__asm__ __volatile__ ("" ::: "memory");

vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
__asm__ __volatile__ ("" ::: "memory");
}
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
__asm__ __volatile__ ("" ::: "memory");
}

// leftovers
int vl;
for (int i = np; i < n; i += vl) {
vl = __riscv_vsetvl_e16m4(n - i);
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
}
// leftovers
int vl;
for (int i = np; i < n; i += vl) {
vl = __riscv_vsetvl_e16m4(n - i);
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
}
np = n;
#else
// fall to scalar path
const int np = 0;
#endif
#elif defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));

@@ -850,17 +863,14 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
}

// leftovers
#else
// scalar path
const int np = 0;
#endif
// scalar and leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#else
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#endif
}

inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
@@ -800,19 +800,32 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
}

static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) {
#ifdef FP8_AVAILABLE
const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
#if defined(GGML_USE_HIP) && defined(CDNA3)
// ROCm dose not support fp8 in software on devices with fp8 hardware,
#if defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
// ROCm does not support fp8 in software on devices with fp8 hardware,
// but CDNA3 supports only e4m3_fnuz (no inf).
const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
const __hip_fp8_e4m3_fnuz xf = *reinterpret_cast<const __hip_fp8_e4m3_fnuz *>(&bits);
#else
const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
#endif // defined(GGML_USE_HIP) && defined(GGML_USE_HIP)
return static_cast<float>(xf) / 2;
#else
NO_DEVICE_CODE;
#endif // FP8_AVAILABLE
#if defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP)
const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
return static_cast<float>(xf) / 2;
#else
if (x == 0 || (x == 0x7F && x != 0xFF)) { // Convert NaN to 0.0f
return 0.0f;
}
const int exp = (x >> 3) & 0xF;
const int man = x & 0x7;
float raw;
if (exp == 0) {
raw = ldexpf((float) man, -9);
} else {
raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7);
}
return static_cast<float>(raw / 2);
#endif // defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP)
#endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
}

__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
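For reference, the scalar fallback above decodes e4m3 with exponent bias 7 and 3 mantissa bits, halving the result as the kernel does. A host-side sketch (it folds the 0x7F/0xFF NaN encodings to zero, as the kernel comments describe; illustrative only, not the device code):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Host model of the e4m3 fallback decode: subnormals are man * 2^-9,
// normals are (1 + man/8) * 2^(exp - 7); NaN encodings map to 0.0f.
static float ue4m3_to_fp32_ref(uint8_t x) {
    if (x == 0 || x == 0x7F || x == 0xFF) {
        return 0.0f;
    }
    const int exp = (x >> 3) & 0xF;
    const int man = x & 0x7;
    const float raw = (exp == 0)
        ? std::ldexp((float) man, -9)
        : std::ldexp(1.0f + (float) man / 8.0f, exp - 7);
    return raw / 2; // halved, matching the kernel's convention
}

int main() {
    std::printf("%f\n", ue4m3_to_fp32_ref(0x38)); // exp=7, man=0 -> 1.0, halved to 0.5
}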
@@ -676,9 +676,96 @@ static __global__ void flash_attn_mask_to_KV_max(

template<int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
const int ne11, const int ne12, const int nbatch_fa) {
static __global__ void flash_attn_stream_k_fixup_uniform(
float * __restrict__ dst,
const float2 * __restrict__ dst_fixup,
const int ne01, const int ne02,
const int ne12, const int nblocks_stream_k,
const int gqa_ratio,
const int blocks_per_tile,
const uint3 fd_iter_j_z_ne12,
const uint3 fd_iter_j_z,
const uint3 fd_iter_j) {
constexpr int ncols = ncols1*ncols2;

const int tile_idx = blockIdx.x; // One block per output tile.
const int j = blockIdx.y;
const int c = blockIdx.z;
const int jc = j*ncols2 + c;
const int tid = threadIdx.x;

// nblocks_stream_k is a multiple of ntiles_dst (== gridDim.x), so each tile gets the same number of blocks.
const int b_first = tile_idx * blocks_per_tile;
const int b_last = b_first + blocks_per_tile - 1;

const float * dst_fixup_data = ((const float *) dst_fixup) + nblocks_stream_k*(2*2*ncols);

// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
const uint2 dm0 = fast_div_modulo(tile_idx, fd_iter_j_z_ne12);
const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_j_z);
const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_j);

const int sequence = dm0.x;
const int z_KV = dm1.x;
const int zt_gqa = dm2.x;
const int jt = dm2.y;

const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.

if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
return;
}

dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;

// Load the partial result that needs a fixup
float dst_val = *dst;
float max_val;
float rowsum;
{
const float2 tmp = dst_fixup[b_last*ncols + jc];
max_val = tmp.x;
rowsum = tmp.y;
}

// Combine with all previous blocks in this tile.
for (int bidx = b_last - 1; bidx >= b_first; --bidx) {
const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];

const float2 tmp = dst_fixup[(nblocks_stream_k + bidx)*ncols + jc];

const float max_val_new = fmaxf(max_val, tmp.x);

const float diff_val = max_val - max_val_new;
const float diff_add = tmp.x - max_val_new;

const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;

dst_val = scale_val*dst_val + scale_add*dst_add;
rowsum = scale_val*rowsum + scale_add*tmp.y;

max_val = max_val_new;
}

// Write back final result:
*dst = dst_val / rowsum;
}
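The combine loop above is the standard online-softmax merge: each partial result over a disjoint K/V range carries a running maximum, a rowsum, and an accumulated value, and results are merged by rescaling with exp(old_max - new_max). A minimal host-side model (the kernel additionally flushes tiny scales to zero via SOFTMAX_FTZ_THRESHOLD, omitted here for brevity):

#include <cmath>

struct Partial { float max_val, rowsum, val; };

// Merge two partial flash-attention results; both accumulators are
// rescaled to the new common maximum before being added.
static Partial combine(Partial a, Partial b) {
    const float m  = std::fmax(a.max_val, b.max_val);
    const float sa = std::exp(a.max_val - m);
    const float sb = std::exp(b.max_val - m);
    return { m, sa*a.rowsum + sb*b.rowsum, sa*a.val + sb*b.val };
}
// The final output per element is combined.val / combined.rowsum,
// which is exactly the division in the write-back above.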
// General fixup kernel for the case where the number of blocks per tile is not uniform across tiles
// (blocks_num.x not a multiple of ntiles_dst)
template <int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup_general(
float * __restrict__ dst,
const float2 * __restrict__ dst_fixup,
const int ne01, const int ne02,
const int gqa_ratio,
const int total_work,
const uint3 fd_iter_k_j_z_ne12,
const uint3 fd_iter_k_j_z,
const uint3 fd_iter_k_j,
const uint3 fd_iter_k) {
constexpr int ncols = ncols1*ncols2;

const int bidx0 = blockIdx.x;
@@ -689,27 +776,26 @@ static __global__ void flash_attn_stream_k_fixup(

const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);

const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.

const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;

const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
const int kbc0 = int64_t(bidx0 + 0)*total_work / gridDim.x;
const int kbc0_stop = int64_t(bidx0 + 1)*total_work / gridDim.x;

const bool did_not_have_any_data = kbc0 == kbc0_stop;
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
const bool wrote_beginning_of_tile = fastmodulo(kbc0, fd_iter_k) == 0;
const bool did_not_write_last = fastdiv(kbc0, fd_iter_k) == fastdiv(kbc0_stop, fd_iter_k) && fastmodulo(kbc0_stop, fd_iter_k) != 0;
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
return;
}

// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
const uint2 dm0 = fast_div_modulo(kbc0, fd_iter_k_j_z_ne12);
const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_k_j_z);
const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_k_j);
const uint2 dm3 = fast_div_modulo(dm2.y, fd_iter_k);

const int sequence = dm0.x;
const int z_KV = dm1.x;
const int zt_gqa = dm2.x;
const int jt = dm3.x;

const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.

@@ -733,10 +819,11 @@ static __global__ void flash_attn_stream_k_fixup(

// Iterate over previous blocks and compute the combined results.
// All CUDA blocks that get here must have a previous block that needs a fixup.
const int tile_kbc0 = fastdiv(kbc0, fd_iter_k);
int bidx = bidx0 - 1;
int kbc_stop = kbc0;
while(true) {
const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
const int kbc = int64_t(bidx)*total_work / gridDim.x;
if (kbc == kbc_stop) { // Did not have any data.
bidx--;
kbc_stop = kbc;
@@ -762,7 +849,7 @@ static __global__ void flash_attn_stream_k_fixup(
max_val = max_val_new;

// If this block started in a previous tile we are done and don't need to combine additional partial results.
if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
if (fastmodulo(kbc, fd_iter_k) == 0 || fastdiv(kbc, fd_iter_k) < tile_kbc0) {
break;
}
bidx--;
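fastdiv and fastmodulo above replace per-thread integer division by a loop-invariant divisor with a precomputed reciprocal and a multiply-high; init_fastdiv_values packs the magic values into a uint3. The sketch below uses Lemire's 64-bit fastmod construction rather than ggml's exact encoding; it is exact for all 32-bit n when 2 <= d < 2^32 and needs the GCC/Clang 128-bit integer extension (illustrative only):

#include <cstdint>

struct FastDiv {
    uint64_t M; // ceil(2^64 / d)
    uint32_t d;
};

static FastDiv init_fastdiv(uint32_t d) { // requires d >= 2
    return { UINT64_MAX / d + 1, d };
}

static uint32_t fastdiv_u32(uint32_t n, FastDiv f) {
    return (uint32_t) (((unsigned __int128) f.M * n) >> 64); // mulhi(M, n)
}

static uint32_t fastmod_u32(uint32_t n, FastDiv f) {
    const uint64_t lowbits = f.M * n;
    return (uint32_t) (((unsigned __int128) lowbits * f.d) >> 64);
}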
@@ -976,14 +1063,28 @@ void launch_fattn(
const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks;
const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves);

const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst);

const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;

blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst;
blocks_num.x = ntiles_dst;
blocks_num.y = 1;
blocks_num.z = 1;

if(use_stream_k) {
const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst);
// Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks (avoids fixup).
// Only do this if the occupancy loss from rounding is acceptable.
const int nblocks_stream_k_rounded = (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst;
const int max_efficiency_loss_percent = 5;
const int efficiency_loss_percent = nblocks_stream_k_rounded > 0
? 100 * (nblocks_stream_k_raw - nblocks_stream_k_rounded) / nblocks_stream_k_raw
: 100;
const int nblocks_stream_k = efficiency_loss_percent <= max_efficiency_loss_percent
? nblocks_stream_k_rounded
: nblocks_stream_k_raw;

blocks_num.x = nblocks_stream_k;
}

if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
}
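The rounding policy above, extracted into a standalone helper to make the trade-off explicit (the 5% cap mirrors max_efficiency_loss_percent; a sketch for illustration):

// Round nblocks down to a multiple of ntiles so every output tile gets the
// same number of stream-k blocks (no fixup kernel needed), but only if at
// most 5% of the launched blocks would be given up.
static int round_stream_k_blocks(int nblocks_raw, int ntiles) {
    const int rounded = (nblocks_raw / ntiles) * ntiles;
    const int loss_percent = rounded > 0
        ? 100 * (nblocks_raw - rounded) / nblocks_raw
        : 100;
    return loss_percent <= 5 ? rounded : nblocks_raw;
}
// e.g. round_stream_k_blocks(132, 16) == 128 (3% loss, accepted),
//      round_stream_k_blocks(20, 16)  == 20  (20% loss, keep the raw count).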
@@ -1063,13 +1164,40 @@ void launch_fattn(
CUDA_CHECK(cudaGetLastError());

if (stream_k) {
if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
if ((int)blocks_num.x % ntiles_dst == 0 && (int)blocks_num.x > ntiles_dst) {
// Optimized fixup: nblocks_stream_k is a multiple of ntiles_dst, launch one block per tile.
const int nblocks_sk = (int)blocks_num.x;
const int bpt = nblocks_sk / ntiles_dst;

const uint3 fd0 = init_fastdiv_values(ntiles_x * ntiles_z_gqa * K->ne[2]);
const uint3 fd1 = init_fastdiv_values(ntiles_x * ntiles_z_gqa);
const uint3 fd2 = init_fastdiv_values(ntiles_x);

const dim3 block_dim_combine(DV, 1, 1);
const dim3 blocks_num_combine = {(unsigned)ntiles_dst, ncols1, ncols2};

flash_attn_stream_k_fixup_uniform<DV, ncols1, ncols2>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr,
Q->ne[1], Q->ne[2], K->ne[2], nblocks_sk,
gqa_ratio, bpt, fd0, fd1, fd2);
} else if (ntiles_dst % blocks_num.x != 0) {
// General fixup for the cases where nblocks_stream_k < ntiles_dst.
const int total_work = ntiles_KV * ntiles_dst;

const uint3 fd_k_j_z_ne12 = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa * K->ne[2]);
const uint3 fd_k_j_z = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa);
const uint3 fd_k_j = init_fastdiv_values(ntiles_KV * ntiles_x);
const uint3 fd_k = init_fastdiv_values(ntiles_KV);

const dim3 block_dim_combine(DV, 1, 1);
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};

flash_attn_stream_k_fixup<DV, ncols1, ncols2>
flash_attn_stream_k_fixup_general<DV, ncols1, ncols2>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
((float *) KQV->data, dst_tmp_meta.ptr,
Q->ne[1], Q->ne[2], gqa_ratio, total_work,
fd_k_j_z_ne12, fd_k_j_z, fd_k_j, fd_k);
}
} else if (parallel_blocks > 1) {
const dim3 block_dim_combine(DV, 1, 1);
@@ -66,6 +66,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);

GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);

GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
@@ -80,6 +85,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);

GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);

GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
@@ -89,6 +99,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
}

static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 64, 1, false);

GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
@@ -103,6 +118,10 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);

GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);

GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
@@ -1552,7 +1571,7 @@ static __global__ void flash_attn_ext_f16(
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))

// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) {
NO_DEVICE_CODE;
return;
}
@@ -1815,6 +1834,15 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)

extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8);

// The number of viable configurations for Deepseek is very limited:
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
@@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
} break;
case 512: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst);
} break;
case 576: {
GGML_ASSERT(V->ne[0] == 512);
ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);

@@ -68,6 +68,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
@@ -124,6 +128,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
@@ -187,6 +195,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 512, 1, 128, 64)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
@@ -251,6 +264,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 128, 64)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
@@ -767,7 +785,7 @@ static __global__ void flash_attn_tile(
#ifdef GGML_USE_WMMA_FATTN
(ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||
#endif // GGML_USE_WMMA_FATTN
(use_logit_softcap && !(DV == 128 || DV == 256))
(use_logit_softcap && !(DV == 128 || DV == 256 || DV == 512))
) {
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
@@ -1192,7 +1210,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;

if constexpr (DV == 512) {
if constexpr (DKQ == 576) {
if (use_gqa_opt && gqa_ratio % 16 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
return;
@@ -1203,7 +1221,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
}
}

if constexpr (DV <= 256) {
if constexpr (DKQ <= 512) {
if (use_gqa_opt && gqa_ratio % 8 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
return;
@@ -1214,13 +1232,15 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
return;
}

if (use_gqa_opt && gqa_ratio % 2 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
if constexpr (DV <= 256) {
if (use_gqa_opt && gqa_ratio % 2 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
return;
}

launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
return;
}

launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
return;
}
GGML_ABORT("fatal error");
}
@@ -1255,4 +1275,5 @@ extern DECL_FATTN_TILE_CASE( 96, 96);
extern DECL_FATTN_TILE_CASE(112, 112);
extern DECL_FATTN_TILE_CASE(128, 128);
extern DECL_FATTN_TILE_CASE(256, 256);
extern DECL_FATTN_TILE_CASE(512, 512);
extern DECL_FATTN_TILE_CASE(576, 512);
@@ -135,6 +135,10 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
GGML_ASSERT(V->ne[0] == 256);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
break;
case 512:
GGML_ASSERT(V->ne[0] == 512);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst);
break;
case 576: {
// For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
GGML_ASSERT(V->ne[0] == 512);
@@ -340,6 +344,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
return BEST_FATTN_KERNEL_NONE;
}
break;
case 512:
if (V->ne[0] != K->ne[0]) {
return BEST_FATTN_KERNEL_NONE;
}
if (!gqa_opt_applies) {
return BEST_FATTN_KERNEL_NONE;
}
break;
case 576:
if (V->ne[0] != 512) {
return BEST_FATTN_KERNEL_NONE;
@@ -424,7 +436,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}

// Use the WMMA kernel if possible:
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) {
if (can_use_vector_kernel && Q->ne[1] <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
@@ -457,7 +469,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}

// Use MFMA flash attention for CDNA (MI100+):
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) {
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) {
const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
// MMA vs tile crossover benchmarked on MI300X @ d32768:
// hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%)

@@ -4791,9 +4791,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
#ifdef FP8_AVAILABLE
case GGML_TYPE_NVFP4:
#endif // FP8_AVAILABLE
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -23,6 +23,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
case GGML_TYPE_MXFP4:
mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
break;
case GGML_TYPE_NVFP4:
mul_mat_q_case<GGML_TYPE_NVFP4>(ctx, args, stream);
break;
case GGML_TYPE_Q2_K:
mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
break;
@@ -273,6 +276,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -362,5 +366,4 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
}

return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;

}

@@ -68,6 +68,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_MXFP4:
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_NVFP4:
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_Q2_K:
return MMQ_Q8_1_DS_LAYOUT_D2S6;
case GGML_TYPE_Q3_K:
@@ -189,6 +191,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
case GGML_TYPE_NVFP4: return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
@@ -206,12 +209,13 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
}
}

#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4)
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4
#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)

static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
@@ -220,6 +224,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");

static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
switch (type) {
@@ -230,6 +236,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
// tile sizes are the same for Q8_1 and FP4 for blackwell
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4;
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
@@ -826,6 +833,65 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
}
}

template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x,
int * __restrict__ x_tile,
const int kb0,
const int i_max,
const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_NVFP4, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

constexpr int threads_per_row = MMQ_ITER_K / QK_NVFP4;
constexpr int rows_per_warp = warp_size / threads_per_row;
const int kbx = threadIdx.x % threads_per_row;
const int row_in_warp = threadIdx.x / threads_per_row;

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;

if constexpr (need_check) {
i = min(i, i_max);
}

const block_nvfp4 * bxi = (const block_nvfp4 *) x + kb0 + i * stride + kbx;
const uint32_t * __restrict__ src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
const int kqs = 16 * kbx;
const int ksc = 4 * kbx;

#pragma unroll
for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
const int2 q0 = get_int_from_table_16(src_qs[2 * sub + 0], kvalues_mxfp4);
const int2 q1 = get_int_from_table_16(src_qs[2 * sub + 1], kvalues_mxfp4);

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 0] = q0.x;
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 1] = q1.x;
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 2] = q0.y;
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 3] = q1.y;
x_df[i * MMQ_MMA_TILE_X_K_NVFP4 + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
#else
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 0] = q0.x;
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 1] = q1.x;
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 2] = q0.y;
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 3] = q1.y;
x_df[i * (2 * MMQ_TILE_NE_K * 2 / QI_NVFP4) + i / (QK_NVFP4_SUB / QI_NVFP4) + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
}

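get_int_from_table_16 above expands a 32-bit word holding eight packed 4-bit codes through a 16-entry value table, producing two output words: the low nibble of each byte feeds the first, the high nibble the second. A host-side model of that expansion (a sketch; the device version does this with bit tricks rather than a byte loop):

#include <cstdint>

static void expand_fp4_word(uint32_t q, const int8_t table[16], int32_t out[2]) {
    uint32_t lo = 0, hi = 0;
    for (int b = 0; b < 4; ++b) {
        lo |= (uint32_t) (uint8_t) table[(q >> (8*b))     & 0xF] << (8*b); // low nibbles
        hi |= (uint32_t) (uint8_t) table[(q >> (8*b + 4)) & 0xF] << (8*b); // high nibbles
    }
    out[0] = (int32_t) lo;
    out[1] = (int32_t) hi;
}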
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -1229,7 +1295,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}

// Used for Q3_K, IQ2_S, and IQ2_XS
// Used for NVFP4, Q3_K, IQ2_S, and IQ2_XS
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -3261,6 +3327,14 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};

template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_NVFP4> {
static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
};

template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
@@ -4069,6 +4143,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
extern DECL_MMQ_CASE(GGML_TYPE_NVFP4);
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);

@@ -235,30 +235,33 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type
// Host function: returns the max batch size for the current arch+type at runtime.
int get_mmvq_mmid_max_batch(ggml_type type, int cc) {
// NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID.
if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return MMVQ_MAX_BATCH_SIZE;
}
if (cc >= GGML_CUDA_CC_TURING) {
return get_mmvq_mmid_max_batch_turing_plus(type);
}
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return MMVQ_MAX_BATCH_SIZE;
}
if (cc >= GGML_CUDA_CC_TURING) {
return get_mmvq_mmid_max_batch_turing_plus(type);
}
return get_mmvq_mmid_max_batch_pascal_older(type);
}

// AMD
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return get_mmvq_mmid_max_batch_rdna4(type);
}
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
return get_mmvq_mmid_max_batch_rdna3(type);
}
if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) {
return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
}
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return get_mmvq_mmid_max_batch_cdna(type);
}
if (GGML_CUDA_CC_IS_GCN(cc)) {
return get_mmvq_mmid_max_batch_gcn(type);
if (GGML_CUDA_CC_IS_AMD(cc)) {
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return get_mmvq_mmid_max_batch_rdna4(type);
}
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
return get_mmvq_mmid_max_batch_rdna3(type);
}
if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) {
return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
}
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return get_mmvq_mmid_max_batch_cdna(type);
}
if (GGML_CUDA_CC_IS_GCN(cc)) {
return get_mmvq_mmid_max_batch_gcn(type);
}
}
return MMVQ_MAX_BATCH_SIZE;
}
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8);

@@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);

@@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);

@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8);

@@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);

@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8);

@@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);

@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8);

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../fattn-tile.cuh"

DECL_FATTN_TILE_CASE(512, 512);
@@ -3,7 +3,7 @@
from glob import glob
import os

HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576]

TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]

@@ -35,7 +35,7 @@ TYPES_MMQ = [
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4"
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4", "GGML_TYPE_NVFP4"
]

SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
@@ -83,6 +83,8 @@ for ncols in [8, 16, 32, 64]:
continue
if head_size_kq == 72:
continue
if head_size_kq == 512 and ncols2 not in (4, 8):
continue
if head_size_kq != 576 and ncols2 in (16, 32):
continue
if head_size_kq == 576 and ncols2 not in (4, 16, 32):

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_NVFP4);
@@ -2231,6 +2231,22 @@ static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session *
return true;
}

static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * dst = op;

if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
return false;
}

if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
return false;
}

GGML_UNUSED(sess);
return true;
}

enum dspqbuf_type {
DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0,
DSPQBUF_TYPE_CPU_WRITE_DSP_READ,
@@ -2399,6 +2415,16 @@ static inline size_t init_repeat_req(htp_general_req * req, dspqueue_buffer * bu
return n_bufs;
}

static inline size_t init_cumsum_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
req->op = HTP_OP_CUMSUM;

size_t n_bufs = 0;
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);

return n_bufs;
}

static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
req->op = HTP_OP_GET_ROWS;

@@ -2780,6 +2806,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
ggml_hexagon_dispatch_op<init_ssm_conv_req>(sess, node, flags);
break;

case GGML_OP_CUMSUM:
ggml_hexagon_dispatch_op<init_cumsum_req>(sess, node, flags);
break;

default:
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
}
@@ -3254,6 +3284,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
supp = ggml_hexagon_supported_ssm_conv(sess, op);
break;

case GGML_OP_CUMSUM:
supp = ggml_hexagon_supported_cumsum(sess, op);
break;

default:
break;
}

@@ -33,6 +33,7 @@ add_library(${HTP_LIB} SHARED
repeat-ops.c
argsort-ops.c
ssm-conv.c
cumsum-ops.c
)

target_compile_definitions(${HTP_LIB} PRIVATE

@@ -164,6 +164,12 @@ static void quicksort_values_indices_desc(float * values, int32_t * indices, int
if (i < right) quicksort_values_indices_desc(values, indices, i, right);
}

// LUT for ramp initialization of argsort output (first 32 members)
int32_t argosrt_ramp_lut[32] __attribute__((aligned(VLEN))) = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
};

static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
struct htp_argsort_context * actx = (struct htp_argsort_context *)data;
struct htp_ops_context * octx = actx->octx;
@@ -205,8 +211,12 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
// Padded to 128 bytes.

size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
size_t num_vec_ind_values = hmx_ceil_div(ne00, VLEN/(sizeof(int32_t)));
float * values_buf = (float *) spad;
int32_t * indices_buf = (int32_t *) (spad + values_size);
HVX_Vector * indices_buf_vec = (HVX_Vector *) (spad + values_size);
const HVX_Vector ind_init_vec = *(HVX_Vector *)argosrt_ramp_lut;
const HVX_Vector ind_diff_vec = Q6_V_vsplat_R(32);

for (uint32_t r = start_row; r < end_row; r++) {
uint32_t src_offset = r * nb01;
@@ -218,9 +228,11 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);
hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00);

// Initialize indices
for (uint32_t j = 0; j < ne00; j++) {
indices_buf[j] = j;
// Initialize indices - Start with values 0..31, add 32 for additional vec iterations
HVX_Vector curr_ind_vec = ind_init_vec;
for (uint32_t j_vec = 0; j_vec < num_vec_ind_values; j_vec++) {
indices_buf_vec[j_vec] = curr_ind_vec;
curr_ind_vec = Q6_Vw_vadd_VwVw(curr_ind_vec, ind_diff_vec);
}

// Sort values and mirror swaps to indices
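The argsort change above replaces a scalar index-initialization loop with a vector ramp: load 0..31 once from the LUT, store it, then add a splatted 32 for each further vector of indices. The same shape in scalar C++ (an illustrative stand-in for the HVX intrinsics):

static void init_indices(int32_t * indices, unsigned n_vecs) {
    int32_t ramp[32];
    for (int i = 0; i < 32; ++i) ramp[i] = i;                 // argosrt_ramp_lut
    for (unsigned v = 0; v < n_vecs; ++v) {
        for (int i = 0; i < 32; ++i) {
            indices[v*32 + i] = ramp[i] + (int32_t) (v * 32); // ramp + v * splat(32)
        }
    }
}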
ggml/src/ggml-hexagon/htp/cumsum-ops.c (new file, 267 lines)
@@ -0,0 +1,267 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"

#include <HAP_farf.h>
#include <HAP_perf.h>

#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-ops.h"
#include "hvx-types.h"
#include "hvx-utils.h"
#include "hex-dma.h"

#define htp_cumsum_tensors_preamble \
struct htp_tensor * restrict src0 = &octx->src0; \
struct htp_tensor * restrict dst = &octx->dst; \
\
const uint32_t ne00 = src0->ne[0]; \
const uint32_t ne01 = src0->ne[1]; \
const uint32_t ne02 = src0->ne[2]; \
const uint32_t ne03 = src0->ne[3]; \
\
const uint32_t ne0 = dst->ne[0]; \
const uint32_t ne1 = dst->ne[1]; \
const uint32_t ne2 = dst->ne[2]; \
const uint32_t ne3 = dst->ne[3]; \
\
const uint32_t nb00 = src0->nb[0]; \
const uint32_t nb01 = src0->nb[1]; \
const uint32_t nb02 = src0->nb[2]; \
const uint32_t nb03 = src0->nb[3]; \
\
const uint32_t nb0 = dst->nb[0]; \
const uint32_t nb1 = dst->nb[1]; \
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3];

struct htp_cumsum_context {
struct htp_ops_context * octx;
size_t src_row_size;
size_t dst_row_size;
size_t src_row_size_aligned;
size_t dst_row_size_aligned;
uint32_t rows_per_thread;
uint32_t total_rows;
};

#define htp_cumsum_preamble \
struct htp_cumsum_context * cctx = (struct htp_cumsum_context *) data; \
struct htp_ops_context * octx = cctx->octx; \
htp_cumsum_tensors_preamble; \
dma_queue * dma_queue = octx->ctx->dma[ith];

// ---------------------------------------------------------------------------
// HVX prefix scan helpers
// ---------------------------------------------------------------------------

#if __HVX_ARCH__ > 75
static inline HVX_Vector hvx_cumsum_vadd(HVX_Vector a, HVX_Vector b) {
return Q6_Vsf_vadd_VsfVsf(a, b);
}
#else
static inline HVX_Vector hvx_cumsum_vadd(HVX_Vector a, HVX_Vector b) {
return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b));
}
#endif // __HVX_ARCH__ > 75

static inline HVX_Vector hvx_prefix_scan_f32(HVX_Vector v, HVX_Vector carry_in) {
const HVX_Vector zero = Q6_V_vsplat_R(0);

v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 4));
v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 8));
v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 16));
v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 32));
v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 64));
v = hvx_cumsum_vadd(v, carry_in);

return v;
}

static inline HVX_Vector hvx_splat_last_f32(HVX_Vector v) {
return hvx_vec_repl4(Q6_V_vror_VR(v, 124));
}

static inline void hvx_cumsum_row_f32(const float * restrict src, float * restrict dst, uint32_t n) {
const uint32_t nvec = n / VLEN_FP32;
const uint32_t nloe = n % VLEN_FP32;

HVX_Vector carry = Q6_V_vsplat_R(0);

for (uint32_t i = 0; i < nvec; i++) {
HVX_Vector v = *((const HVX_UVector *) (src + i * VLEN_FP32));
v = hvx_prefix_scan_f32(v, carry);
hvx_vec_store_u(dst + i * VLEN_FP32, VLEN, v);
carry = hvx_splat_last_f32(v);
}

if (nloe) {
float acc = hvx_vec_get_f32(carry);
const float * src_tail = src + nvec * VLEN_FP32;
float * dst_tail = dst + nvec * VLEN_FP32;
for (uint32_t i = 0; i < nloe; i++) {
acc += src_tail[i];
dst_tail[i] = acc;
}
}
}

// ---------------------------------------------------------------------------
|
||||
// Per thread worker: Double-buffered DMA
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
static void cumsum_thread_f32_dma(unsigned int nth, unsigned int ith, void * data) {
|
||||
htp_cumsum_preamble;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const uint32_t ir0 = cctx->rows_per_thread * ith;
|
||||
const uint32_t ir1 = MIN(ir0 + cctx->rows_per_thread, cctx->total_rows);
|
||||
|
||||
if (ir0 >= ir1) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t src_row_size = cctx->src_row_size;
|
||||
const size_t dst_row_size = cctx->dst_row_size;
|
||||
const size_t src_row_size_aligned = cctx->src_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = cctx->dst_row_size_aligned;
|
||||
|
||||
const uint8_t * src_data = (const uint8_t *) src0->data;
|
||||
uint8_t * dst_data = (uint8_t *) dst->data;
|
||||
|
||||
uint8_t * src_spad = octx->src0_spad.data + (ith * src_row_size_aligned * 2);
|
||||
uint8_t * dst_spad = octx->dst_spad.data + (ith * dst_row_size_aligned * 2);
|
||||
|
||||
for (uint32_t ir = ir0, spad_idx = 0; ir < ir1 && spad_idx < 2; ir++, spad_idx++) {
|
||||
// Dummy dst writeback to establish queue ordering
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_make_ptr(dst_data, dst_spad + (spad_idx * dst_row_size_aligned)),
|
||||
dst_row_size, dst_row_size_aligned, 0);
|
||||
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src_spad + (spad_idx * src_row_size_aligned),
|
||||
src_data + (ir * src_row_size)),
|
||||
src_row_size_aligned, src_row_size, 1);
|
||||
}
|
||||
|
||||
for (uint32_t ir = ir0; ir < ir1; ir++) {
|
||||
float * dst_spad_row = (float *) dma_queue_pop(dma_queue).src;
|
||||
float * src_spad_row = (float *) dma_queue_pop(dma_queue).dst;
|
||||
|
||||
hvx_cumsum_row_f32(src_spad_row, dst_spad_row, ne00);
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_make_ptr(dst_data + (ir * dst_row_size), (uint8_t *) dst_spad_row),
|
||||
dst_row_size, dst_row_size_aligned, 1);
|
||||
|
||||
const uint32_t next_row = ir + 2;
|
||||
if (next_row < ir1) {
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr((uint8_t *) src_spad_row, src_data + (next_row * src_row_size)),
|
||||
src_row_size_aligned, src_row_size, 1);
|
||||
}
|
||||
}
|
||||
|
||||
dma_queue_flush(dma_queue);
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "cumsum-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n",
|
||||
ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1,
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
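The worker above keeps two rows in flight per thread: both slots are prefilled, and each iteration pops a finished transfer, computes, queues the writeback, and refills the freed slot with row ir + 2. A standalone sketch of that schedule, with hypothetical logging stand-ins for the queue operations:

#include <cstdint>
#include <cstdio>

// Stand-ins for the DMA queue ops; here they only log the schedule.
static void prefetch (uint32_t row, uint32_t slot) { std::printf("prefetch  row %u -> slot %u\n", row, slot); }
static void compute  (uint32_t row, uint32_t slot) { std::printf("compute   row %u in  slot %u\n", row, slot); }
static void writeback(uint32_t row, uint32_t slot) { std::printf("writeback row %u <- slot %u\n", row, slot); }

int main() {
    const uint32_t ir0 = 0, ir1 = 5;
    // Prefill: one row in flight per slot before the steady-state loop.
    for (uint32_t ir = ir0, slot = 0; ir < ir1 && slot < 2; ir++, slot++) {
        prefetch(ir, slot);
    }
    // Steady state: compute row ir while row ir+1 is in flight, then refill
    // the slot just freed with row ir+2.
    for (uint32_t ir = ir0; ir < ir1; ir++) {
        const uint32_t slot = (ir - ir0) % 2;
        compute(ir, slot);
        writeback(ir, slot);
        if (ir + 2 < ir1) {
            prefetch(ir + 2, slot);
        }
    }
    return 0;
}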
// ---------------------------------------------------------------------------
// Per thread worker: Direct HVX (no DMA)
// ---------------------------------------------------------------------------

static void cumsum_thread_f32(unsigned int nth, unsigned int ith, void * data) {
    htp_cumsum_preamble;

    uint64_t t1, t2;
    t1 = HAP_perf_get_qtimer_count();

    const uint8_t * src_data = (const uint8_t *) src0->data;
    uint8_t *       dst_data = (uint8_t *) dst->data;

    const uint32_t ir0 = cctx->rows_per_thread * ith;
    const uint32_t ir1 = MIN(ir0 + cctx->rows_per_thread, cctx->total_rows);

    for (uint32_t ir = ir0; ir < ir1; ir++) {
        const float * restrict src_row = (const float *) (src_data + ir * cctx->src_row_size);
        float * restrict       dst_row = (float *) (dst_data + ir * cctx->dst_row_size);
        hvx_cumsum_row_f32(src_row, dst_row, ne00);
    }

    t2 = HAP_perf_get_qtimer_count();

    FARF(HIGH, "cumsum-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n",
         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1,
         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}

int op_cumsum_f32(struct htp_ops_context * octx) {
    const struct htp_tensor * src0 = &octx->src0;
    const struct htp_tensor * dst  = &octx->dst;

    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
        return HTP_STATUS_OK;
    }

    const uint32_t total_rows = src0->ne[1] * src0->ne[2] * src0->ne[3];
    const uint32_t n_threads  = MIN(octx->n_threads, total_rows);

    const size_t src_row_size         = src0->nb[1];
    const size_t dst_row_size         = dst->nb[1];
    const size_t src_row_size_aligned = hex_round_up(src_row_size, VLEN);
    const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);

    // 2 ping-pong buffers per thread for src and dst
    const size_t spad_per_thread = 2 * (src_row_size_aligned + dst_row_size_aligned);

    octx->src0_spad.size_per_thread = src_row_size_aligned * 2;
    octx->dst_spad.size_per_thread  = dst_row_size_aligned * 2;
    octx->src0_spad.size            = n_threads * octx->src0_spad.size_per_thread;
    octx->dst_spad.size             = n_threads * octx->dst_spad.size_per_thread;
    octx->src0_spad.data            = octx->ctx->vtcm_base;
    octx->dst_spad.data             = octx->src0_spad.data + octx->src0_spad.size;

    struct htp_cumsum_context cctx = {
        .octx                 = octx,
        .src_row_size         = src_row_size,
        .dst_row_size         = dst_row_size,
        .src_row_size_aligned = src_row_size_aligned,
        .dst_row_size_aligned = dst_row_size_aligned,
        .rows_per_thread      = (total_rows + n_threads - 1) / n_threads,
        .total_rows           = total_rows,
    };

    if (octx->ctx->vtcm_size < spad_per_thread * n_threads) {
        // Not enough VTCM for the per-thread DMA scratchpads: use the direct (no DMA) path
        worker_pool_run_func(octx->ctx->worker_pool, cumsum_thread_f32, &cctx, n_threads);
    } else {
        worker_pool_run_func(octx->ctx->worker_pool, cumsum_thread_f32_dma, &cctx, n_threads);
    }

    return HTP_STATUS_OK;
}

int op_cumsum(struct htp_ops_context * octx) {
    int err = HTP_STATUS_OK;

    struct htp_tensor * dst = &octx->dst;

    switch (dst->type) {
        case HTP_TYPE_F32:
            err = op_cumsum_f32(octx);
            break;
        default:
            err = HTP_STATUS_NO_SUPPORT;
            break;
    }

    return err;
}
@@ -75,6 +75,7 @@ enum htp_op {
    HTP_OP_SUM_ROWS,
    HTP_OP_SSM_CONV,
    HTP_OP_REPEAT,
    HTP_OP_CUMSUM,
    INVALID
};
@@ -60,5 +60,6 @@ int op_cpy(struct htp_ops_context * octx);
int op_repeat(struct htp_ops_context * octx);
int op_argsort(struct htp_ops_context * octx);
int op_ssm_conv(struct htp_ops_context * octx);
int op_cumsum(struct htp_ops_context * octx);

#endif /* HTP_OPS_H */
@@ -16,8 +16,10 @@

#if __HVX_ARCH__ < 79
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#define HVX_OP_MUL_F16(a, b) Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b))
#else
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#define HVX_OP_MUL_F16(a, b) Q6_Vhf_vmpy_VhfVhf(a, b)
#endif
// Compute div by a scalar in f32: first expand the fp16 input to fp32, then convert the result back to fp16.
@@ -43,46 +45,67 @@ static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX
    return res;
}

#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \
    do { \
        dst_type * restrict vdst = (dst_type *) dst; \
        src_type * restrict vsrc = (src_type *) src; \
        HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
        \
        const uint32_t nvec = n / VLEN_FP16; \
        const uint32_t nloe = n % VLEN_FP16; \
        \
        uint32_t i = 0; \
        \
        _Pragma("unroll(4)") \
        for (; i < nvec; i++) { \
            HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
            vdst[i] = res; \
        } \
        if (nloe) { \
            HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
            vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
        } \
// Variant for <v79: Use pre-computed f16 reciprocal constant
static inline HVX_Vector hvx_div_mul_f16_const_using_f16(HVX_Vector vec1_hf, HVX_Vector const_inv_hf) {
    // Multiply by pre-computed f16 reciprocal constant
    return HVX_OP_MUL_F16(vec1_hf, const_inv_hf);
}

#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \
    do { \
        dst_type * restrict vdst = (dst_type *) dst; \
        src_type * restrict vsrc = (src_type *) src; \
        \
        HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
        \
        const uint32_t nvec = n / VLEN_FP16; \
        const uint32_t nloe = n % VLEN_FP16; \
        \
        uint32_t i = 0; \
        \
        _Pragma("unroll(4)") \
        for (; i < nvec; i++) { \
            HVX_Vector res; \
            if (__HVX_ARCH__ < 79) { \
                res = hvx_div_mul_f16_const_using_f16(vsrc[i], val_vec_f16); \
            } else { \
                res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
            } \
            vdst[i] = res; \
        } \
        if (nloe) { \
            HVX_Vector res; \
            if (__HVX_ARCH__ < 79) { \
                res = hvx_div_mul_f16_const_using_f16(vsrc[i], val_vec_f16); \
            } else { \
                res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
            } \
            vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
        } \
    } while(0)

static inline void hvx_div_scalar_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f / ((float) val));
    const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
    assert((uintptr_t) dst % 128 == 0);
    assert((uintptr_t) src % 128 == 0);
    hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}

static inline void hvx_div_scalar_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f / ((float) val));
    const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
    assert((uintptr_t) dst % 128 == 0);
    hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
}

static inline void hvx_div_scalar_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f / ((float) val));
    const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
    assert((uintptr_t) src % 128 == 0);
    hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}

static inline void hvx_div_scalar_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f / ((float) val));
    const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
    hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}

@@ -128,13 +151,25 @@ static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector v
    return recip;
}

// Hybrid approach: f16 reciprocal for <v79, f32 precision for >=v79
static inline HVX_Vector hvx_vec_hybrid_div_f16(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector f16_nan_inf_mask, HVX_Vector vec_hf_one_1_0) {
#if __HVX_ARCH__ < 79
    // For older architectures, use f16 reciprocal to avoid NaN/-inf issues
    HVX_Vector vec2_inv = hvx_vec_inverse_f16_guard(vec2, f16_nan_inf_mask);
    return HVX_OP_MUL_F16(vec1, vec2_inv);
#else
    return hvx_vec_div_f16_using_f32(vec1, vec2, f32_nan_inf_mask, vec_hf_one_1_0);
#endif
}
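Both branches implement division as multiplication by a reciprocal, with a guard so non-finite divisors still propagate NaN/Inf. A scalar model of the idea (illustrative only; the HVX guards operate on exponent bit patterns, not a library call):

#include <cmath>
#include <cstdio>

// Scalar model of reciprocal-based division: multiply by 1/b, but guard
// non-finite divisors so NaN/Inf behave like a true divide.
static float div_via_recip(float a, float b) {
    if (!std::isfinite(b)) {
        return a / b;          // NaN/Inf divisor: fall back to real division
    }
    return a * (1.0f / b);
}

int main() {
    std::printf("%f\n", div_via_recip(3.0f, 2.0f));  // 1.500000
    std::printf("%f\n", div_via_recip(1.0f, 0.0f));  // inf, as with a true divide
    return 0;
}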
#define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store) \
    do { \
        dst_type * restrict vdst = (dst_type *) dst; \
        src0_type * restrict vsrc0 = (src0_type *) src0; \
        src1_type * restrict vsrc1 = (src1_type *) src1; \
        \
        const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
        const HVX_Vector f32_nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
        const HVX_Vector f16_nan_inf_mask = Q6_Vh_vsplat_R(0x7c00); \
        const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
        \
        const uint32_t nvec = n / VLEN_FP16; \

@@ -144,11 +179,15 @@ static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector v
        \
        _Pragma("unroll(4)") \
        for (; i < nvec; i++) { \
            HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
            HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \
                                                    f32_nan_inf_mask, f16_nan_inf_mask, \
                                                    hf_one); \
            vdst[i] = res; \
        } \
        if (nloe) { \
            HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
            HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \
                                                    f32_nan_inf_mask, f16_nan_inf_mask, \
                                                    hf_one); \
            vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
        } \
    } while(0)

@@ -247,5 +286,6 @@ HVX_DIV_DISPATCHER(hvx_div_f32)
HVX_DIV_DISPATCHER(hvx_div_f16)

#undef HVX_OP_MUL_F32
#undef HVX_OP_MUL_F16

#endif // HVX_DIV_H
@@ -860,6 +860,41 @@ static void proc_ssm_conv_req(struct htp_context * ctx, struct htp_general_req *
    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}

static void proc_cumsum_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
    struct dspqueue_buffer rsp_bufs[1];

    // We've written to the output buffer, so we also need to flush it
    rsp_bufs[0].fd     = bufs[1].fd;
    rsp_bufs[0].ptr    = bufs[1].ptr;
    rsp_bufs[0].offset = bufs[1].offset;
    rsp_bufs[0].size   = bufs[1].size;
    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
                          DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU

    // Setup Op context
    struct htp_ops_context octx = { 0 };

    octx.ctx   = ctx;
    octx.src0  = req->src0;
    octx.dst   = req->dst;
    octx.flags = req->flags;
    octx.op    = req->op;

    octx.src0.data = (uint32_t) bufs[0].ptr;
    octx.dst.data  = (uint32_t) bufs[1].ptr;

    octx.n_threads = ctx->n_threads;

    struct profile_data prof;
    profile_start(&prof);

    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
        rsp_status = op_cumsum(&octx);
        vtcm_release(ctx);
    }

    profile_stop(&prof);
    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}

static void proc_activations_req(struct htp_context * ctx,
                                 struct htp_general_req * req,
                                 struct dspqueue_buffer * bufs,

@@ -1474,6 +1509,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
            proc_ssm_conv_req(ctx, &req, bufs);
            break;

        case HTP_OP_CUMSUM:
            if (n_bufs != 2) {
                FARF(ERROR, "Bad cumsum-req buffer list");
                continue;
            }
            proc_cumsum_req(ctx, &req, bufs);
            break;

        default:
            FARF(ERROR, "Unknown Op %u", req.op);
            break;
@@ -67,34 +67,61 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
                                  uint8_t * restrict pad,
                                  const int num_elems,
                                  float epsilon) {
    (void) pad;

    const HVX_Vector * restrict v_src = (HVX_Vector *) src;
    HVX_Vector * restrict       v_dst = (HVX_Vector *) dst;

    HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
    const int nvec = num_elems / VLEN_FP32; // number of full vectors
    const int nloe = num_elems % VLEN_FP32; // leftover elements

    // Compute sum of squares for full vectors
    HVX_Vector sum_v     = Q6_V_vsplat_R(0x00000000);
    HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);

    int step_of_1 = num_elems >> 5;
#pragma unroll(4)
    for (int i = 0; i < step_of_1; i++) {
    for (int i = 0; i < nvec; i++) {
        HVX_Vector v1 = v_src[i];
        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
        sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
        sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
    }

    sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes
    // Handle tail elements using vectorized ops with masking
    if (nloe > 0) {
        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
        HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
        sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
    }

    // Reduce HVX sum
    sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));

    HVX_Vector t_v     = hvx_vec_splat_f32((float) num_elems);
    HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
    HVX_Vector mean_v  = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
    HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);

    // Scale full vectors
    HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));

#pragma unroll(4)
    for (int i = 0; i < step_of_1; i++) {
    for (int i = 0; i < nvec; i++) {
        HVX_Vector v1 = v_src[i];
        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
        v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
        v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
    }

    // Handle tail elements using vectorized ops with masking
    if (nloe > 0) {
        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
        HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
        HVX_Vector result = Q6_Vsf_equals_Vqf32(v2);

        // Store with masking to avoid overwriting memory beyond the tensor
        hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
    }
}
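The tail handling in this change replaces a scalar remainder loop with a lane predicate: lanes at or beyond nloe are zeroed, so the last partial vector goes through the same full-width sum-of-squares step as the rest. A scalar model of the masking (illustrative, not from the source):

#include <cstdio>

// Scalar model of the masked-tail pattern: a lane predicate zeroes the lanes
// past `nloe`, so a full-width step can process the last partial vector
// without a separate scalar loop.
static float masked_tail_sumsq(const float * tail, int nloe, int lanes) {
    float sum = 0.0f;
    for (int lane = 0; lane < lanes; lane++) {
        const float v = lane < nloe ? tail[lane] : 0.0f;  // Q6_Q_vsetq_R + vand
        sum += v * v;
    }
    return sum;
}

int main() {
    const float tail[3] = { 1.0f, 2.0f, 3.0f };
    std::printf("%f\n", masked_tail_sumsq(tail, 3, 32));  // 14.000000
    return 0;
}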
@@ -9612,6 +9612,9 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t
    cl_mem B_image1d;
    cl_mem B_sub_buffer;
    cl_mem S_image1d;
    // for B transpose
    cl_mem B_image1d_trans = nullptr;
    cl_mem B_d = nullptr;

    cl_mem D_image1d;
    cl_mem D_sub_buffer;

@@ -9703,9 +9706,6 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t
        global_work_size[2] = 1;
    } else {
        cl_ulong offsetd = extrad->offset + dst->view_offs;
        cl_mem B_image1d_trans = nullptr;
        // for B transpose
        cl_mem B_d = nullptr;
        int padding;

        // how many extra elements beyond multiple of 8

@@ -9800,6 +9800,12 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t
    CL_CHECK(clReleaseMemObject(S_image1d));
    CL_CHECK(clReleaseMemObject(D_sub_buffer));
    CL_CHECK(clReleaseMemObject(D_image1d));
    if (B_image1d_trans) {
        CL_CHECK(clReleaseMemObject(B_image1d_trans));
    }
    if (B_d) {
        CL_CHECK(clReleaseMemObject(B_d));
    }
#else
    GGML_UNUSED(backend);
    GGML_UNUSED(src0);
@@ -1009,8 +1009,8 @@ public:
    bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);

    struct stored_graph {
        ggml_context_ptr     ctx_ptr;
        ggml_cgraph *        graph;
        std::vector<uint8_t> buffer;
        ggml_cgraph *        graph;
    };

private:

@@ -1518,10 +1518,12 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
    LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);

    size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);

    if (stored_graphs[device].buffer.size() < buf_size) {
        stored_graphs[device].buffer.resize(buf_size);
    }
    struct ggml_init_params params = {
        /*.mem_size   =*/ buf_size,
        /*.mem_buffer =*/ NULL,
        /*.mem_buffer =*/ stored_graphs[device].buffer.data(),
        /*.no_alloc   =*/ true,
    };
    ggml_context_ptr ctx_ptr { ggml_init(params) };

@@ -1551,7 +1553,6 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
    }
    ggml_status status = ggml_backend_graph_compute(backends[device], graph);
    GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
    stored_graphs[device].ctx_ptr.swap(ctx_ptr);
    stored_graphs[device].graph = graph;
    return true;
}
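The stored_graph change above amounts to a per-device grow-only scratch buffer: ggml_init now borrows memory that persists across graph_compute calls, so graphs of similar size allocate nothing after warm-up. A minimal standalone version of the pattern (names illustrative):

#include <cstddef>
#include <cstdint>
#include <vector>

// Grow-only scratch buffer: resize only when a request exceeds capacity, so
// the common case (repeated similar-sized requests) does no allocation.
struct scratch_buffer {
    std::vector<uint8_t> buf;

    void * get(size_t size) {
        if (buf.size() < size) {
            buf.resize(size);  // grows, never shrinks
        }
        return buf.data();
    }
};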
@@ -23,6 +23,7 @@
#include "ggml-impl.h"
#include "ggml-sycl.h"
#include "presets.hpp"
#include "type.hpp"
#include "sycl_hw.hpp"

namespace syclexp = sycl::ext::oneapi::experimental;

@@ -965,4 +966,10 @@ static T block_reduce(T val, T * shared_vals, int block_size_template) {
    return val;
}

static __dpct_inline__ float ggml_sycl_ue4m3_to_fp32(uint8_t x) {
    const uint32_t bits = x * (x != 0x7F && x != 0xFF);
    const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
    return static_cast<float>(xf) / 2;
}

#endif // GGML_SYCL_COMMON_HPP
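In ggml_sycl_ue4m3_to_fp32, the multiply by (x != 0x7F && x != 0xFF) branchlessly zeroes both E4M3 NaN encodings before the cast, and the reinterpret_cast picks up the low byte of bits, which holds x on little-endian targets; the trailing / 2 appears to compensate for the doubled entries of the kvalues_mxfp4 table used during dequantization (an interpretation, not stated in the diff). A quick standalone check of the guard:

#include <cassert>
#include <cstdint>

// The branchless guard from ggml_sycl_ue4m3_to_fp32: both E4M3 NaN encodings
// (0x7F and 0xFF) collapse to 0 before the cast; other codes pass through.
static uint32_t e4m3_nan_guard(uint8_t x) {
    return x * (x != 0x7F && x != 0xFF);
}

int main() {
    assert(e4m3_nan_guard(0x7F) == 0);
    assert(e4m3_nan_guard(0xFF) == 0);
    assert(e4m3_nan_guard(0x3A) == 0x3A);
    return 0;
}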
@@ -482,6 +482,18 @@ static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t
        });
}

template <typename dst_t>
static void dequantize_row_nvfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
    GGML_ASSERT(k % QK_NVFP4 == 0);
    const int nb = k / QK_NVFP4;
    stream->parallel_for(
        sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
        [=](sycl::nd_item<3> item_ct1) {
            dequantize_block_nvfp4(vx, y, k);
        });
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y,
                                const int64_t ne00, const int64_t ne01, const int64_t ne02,

@@ -641,6 +653,8 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
            return dequantize_row_iq4_nl_sycl;
        case GGML_TYPE_MXFP4:
            return dequantize_row_mxfp4_sycl;
        case GGML_TYPE_NVFP4:
            return dequantize_row_nvfp4_sycl;
        case GGML_TYPE_F32:
            return convert_unary_sycl<float>;
#ifdef GGML_SYCL_HAS_BF16

@@ -648,6 +662,7 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
            return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
#endif
        default:
            GGML_ABORT("fatal error: unsupported data type=%s\n", ggml_type_name(type));
            return nullptr;
    }
}

@@ -708,6 +723,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
            return dequantize_row_iq4_nl_sycl;
        case GGML_TYPE_MXFP4:
            return dequantize_row_mxfp4_sycl;
        case GGML_TYPE_NVFP4:
            return dequantize_row_nvfp4_sycl;
        case GGML_TYPE_F16:
            return convert_unary_sycl<sycl::half>;
#ifdef GGML_SYCL_HAS_BF16

@@ -715,6 +732,7 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
            return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
#endif
        default:
            GGML_ABORT("fatal error: unsupported data type=%s\n", ggml_type_name(type));
            return nullptr;
    }
}
@@ -838,4 +838,36 @@ static void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restr
    }
}

template <typename dst_t>
static void dequantize_block_nvfp4(
    const void * __restrict__ vx,
    dst_t * __restrict__ yy,
    const int64_t ne) {
    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
    const int64_t i   = item_ct1.get_group(2);
    const int     tid = item_ct1.get_local_id(2);

    const int64_t base = i * QK_NVFP4;
    if (base >= ne) {
        return;
    }

    const block_nvfp4 * x  = (const block_nvfp4 *) vx;
    const block_nvfp4 & xb = x[i];

    const int sub = tid / (QK_NVFP4_SUB / 2);
    const int j   = tid % (QK_NVFP4_SUB / 2);

    const float   d = ggml_sycl_ue4m3_to_fp32(xb.d[sub]);
    const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j];

    const int64_t y0 = base + sub * QK_NVFP4_SUB + j;
    const int64_t y1 = y0 + QK_NVFP4_SUB / 2;

    yy[y0] = ggml_sycl_cast<dst_t>(d * kvalues_mxfp4[q & 0x0F]);
    yy[y1] = ggml_sycl_cast<dst_t>(d * kvalues_mxfp4[q >> 4]);
}

#endif // GGML_SYCL_DEQUANTIZE_HPP
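dequantize_block_nvfp4 assumes each block stores QK_NVFP4 values as two 4-bit codes per byte, grouped into sub-blocks of QK_NVFP4_SUB values that share one ue4m3 scale; the low nibble maps to the lower half of a sub-block and the high nibble to the upper half. A scalar sketch of that layout under assumed sizes (QK, SUB, and the codebook below are placeholders, not the source's definitions):

#include <cstdint>

// Placeholder sizes and codebook; the real QK_NVFP4 / QK_NVFP4_SUB /
// kvalues_mxfp4 live in ggml's headers.
constexpr int QK  = 32;  // values per block (assumed)
constexpr int SUB = 16;  // values per sub-block sharing one scale (assumed)
static const int8_t kvalues[16] = { 0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12 };

struct block_sketch {
    uint8_t qs[QK / 2];  // two 4-bit codes per byte
};

static void decode_block(const block_sketch & xb, const float scales[QK / SUB], float out[QK]) {
    for (int sub = 0; sub < QK / SUB; sub++) {
        for (int j = 0; j < SUB / 2; j++) {
            const uint8_t q = xb.qs[sub * (SUB / 2) + j];
            out[sub * SUB + j]           = scales[sub] * kvalues[q & 0x0F];  // low nibble
            out[sub * SUB + j + SUB / 2] = scales[sub] * kvalues[q >> 4];    // high nibble
        }
    }
}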
@@ -1252,6 +1252,16 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm
        return;
    }

    {
        constexpr int cols_per_block = ncols2*2;
        const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
        const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
        launch_fattn<DV, cols_per_block/ncols2, ncols2,
                     flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
            (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
        return;
    }

    GGML_ABORT("fatal error");
}
@@ -569,9 +569,15 @@ static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer,
    SYCL_CHECK(
        CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw()));

    SYCL_CHECK(CHECK_TRY_ERROR((*stream)
                                   .memset(ctx->dev_ptr, value, buffer->size)
                                   .wait()));
    constexpr size_t MAX_CHUNK = 2ULL << 30; // 2 GiB
    for (size_t off = 0; off < buffer->size; off += MAX_CHUNK) {
        size_t chunk = std::min(buffer->size - off, MAX_CHUNK);
        SYCL_CHECK(CHECK_TRY_ERROR(
            (*stream)
                .memset(static_cast<char *>(ctx->dev_ptr) + off, value, chunk)
                .wait()));
    }
}
catch (sycl::exception const & exc) {
    std::cerr << exc.what() << "Exception caught at file:" << __FILE__
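Chunking the clear at 2 GiB keeps any single memset call bounded, presumably to stay clear of per-operation size limits further down the stack; a buffer of size S is cleared in ceil(S / MAX_CHUNK) passes. A standalone check of the chunk arithmetic:

#include <algorithm>
#include <cassert>
#include <cstddef>

int main() {
    constexpr size_t MAX_CHUNK = 2ULL << 30;  // 2 GiB, as in the patch
    const size_t size = (5ULL << 30) + 123;   // example: 5 GiB + 123 B buffer

    size_t cleared = 0, passes = 0;
    for (size_t off = 0; off < size; off += MAX_CHUNK) {
        cleared += std::min(size - off, MAX_CHUNK);
        passes++;
    }
    assert(cleared == size && passes == 3);   // ceil((5 GiB + 123 B) / 2 GiB) = 3
    return 0;
}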
@@ -613,6 +613,23 @@ static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float
    }
}

static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
                                        dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_NVFP4 == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);

    {
        stream->submit([&](sycl::handler & cgh) {
            cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
                             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                                 mul_mat_vec_q<QK_NVFP4, QI_NVFP4, block_nvfp4, VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1>(
                                     vx, vy, dst, ncols, nrows, item_ct1);
                             });
        });
    }
}

static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
                                       float *dst, const int ncols,

@@ -1145,8 +1162,11 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
        case GGML_TYPE_MXFP4:
            mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
            break;
        case GGML_TYPE_NVFP4:
            mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
            break;
        default:
            GGML_ABORT("fatal error");
            GGML_ABORT("fatal error: unsupported data type=%s\n", ggml_type_name(src0->type));
    }
}
GGML_UNUSED(src1);
ggml/src/ggml-sycl/type.hpp (new file, 112 lines)
@@ -0,0 +1,112 @@
#pragma once

#include <sycl/sycl.hpp>
#include <cstdint>
#include <limits>

inline uint8_t float_to_e4m3(float f)
{
    if (sycl::isnan(f)) {
        return 0x7F; // Canonical NaN (positive)
    }

    uint32_t bits = sycl::bit_cast<uint32_t>(f);
    uint32_t sign = (bits >> 31) & 0x1u;
    uint32_t exp  = (bits >> 23) & 0xFFu;
    uint32_t mant = bits & 0x7FFFFFu;

    // Zero
    if (exp == 0 && mant == 0) {
        return static_cast<uint8_t>(sign << 7);
    }

    // Extract biased exponent and mantissa for FP8
    int e = static_cast<int>(exp) - 127; // true exponent (IEEE bias 127)
    uint32_t m = mant;

    // Handle very large values → NaN (NVIDIA behavior for E4M3)
    if (e > 7) { // max exponent for E4M3 is 7 (biased 14)
        return static_cast<uint8_t>((sign << 7) | 0x7F);
    }

    // Handle subnormals and normal numbers
    if (e < -6) { // smallest normal exponent is -6
        // Subnormal in FP8: shift mantissa right
        int shift = -6 - e;
        m = (m | 0x800000u) >> (shift + 1); // +1 because we lose the implicit 1 position
        if (shift > 23) {
            m = 0;
        }
    } else {
        // Normal number: adjust exponent bias from 127 to 7
        int new_exp = e + 7;
        m = (m >> 20) & 0x7u; // take top 3 mantissa bits (after implicit 1)
        m |= (static_cast<uint32_t>(new_exp) << 3);
    }

    // Round-to-nearest-even (simple guard + round bit)
    // For better accuracy you could add a sticky bit, but this is sufficient for most use cases
    uint32_t round_bit = (mant >> 19) & 0x1u; // bit after the 3 mantissa bits
    if (round_bit) {
        m += 1;
        // Carry into exponent if mantissa overflows
        if ((m & 0x8u) != 0) {
            m = (m & 0x7u) | ((m & 0x38u) << 1); // simple carry handling
            // If exponent overflows after carry → NaN
            if ((m >> 3) > 14) {
                return static_cast<uint8_t>((sign << 7) | 0x7F);
            }
        }
    }

    uint8_t result = static_cast<uint8_t>((sign << 7) | (m & 0x7F));
    return result;
}
inline float e4m3_to_float(uint8_t x)
{
    if (x == 0) {
        return 0.0f;
    }

    uint8_t sign = (x >> 7) & 0x1u;
    uint8_t exp  = (x >> 3) & 0xFu;
    uint8_t mant = x & 0x7u;

    // NaN: the whole exp == 0xF row is treated as NaN here
    // (NVIDIA uses 0x7F / 0xFF as the canonical NaN encodings)
    if (exp == 0xF) {
        return std::numeric_limits<float>::quiet_NaN();
    }

    float val;

    if (exp == 0) {
        // Subnormal
        val = mant * (1.0f / 8.0f) * sycl::pow(2.0f, -6.0f);
    } else {
        // Normal: implicit leading 1 + bias 7
        val = (1.0f + mant / 8.0f) * sycl::pow(2.0f, static_cast<float>(exp) - 7.0f);
    }

    return sign ? -val : val;
}

// The actual type definition
struct __nv_fp8_e4m3 {
    uint8_t raw;

    __nv_fp8_e4m3() = default;

    explicit __nv_fp8_e4m3(float f) : raw(float_to_e4m3(f)) {}
    explicit __nv_fp8_e4m3(sycl::half h) : raw(float_to_e4m3(static_cast<float>(h))) {}

    operator float() const { return e4m3_to_float(raw); }
    operator sycl::half() const { return static_cast<sycl::half>(static_cast<float>(*this)); }

    // Allow direct access for vector loads/stores
    operator uint8_t &() { return raw; }
    operator uint8_t() const { return raw; }
};

using __nv_fp8x2_e4m3 = sycl::vec<__nv_fp8_e4m3, 2>;
using __nv_fp8x4_e4m3 = sycl::vec<__nv_fp8_e4m3, 4>;
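A standalone model of the decode above, useful for spot-checking encodings by hand. Note that this file treats the entire exp == 0xF row as NaN, so the largest finite value under this convention is (1 + 7/8) * 2^7 = 240:

#include <cassert>
#include <cmath>
#include <cstdint>

// Host-side model of e4m3_to_float (bias 7, 3-bit mantissa, exp == 0xF = NaN).
static float e4m3_decode(uint8_t x) {
    const uint8_t sign = (x >> 7) & 0x1u;
    const uint8_t exp  = (x >> 3) & 0xFu;
    const uint8_t mant = x & 0x7u;
    if (exp == 0xF) {
        return NAN;
    }
    const float val = (exp == 0)
        ? mant / 8.0f * std::pow(2.0f, -6.0f)                    // subnormal
        : (1.0f + mant / 8.0f) * std::pow(2.0f, exp - 7.0f);     // normal
    return sign ? -val : val;
}

int main() {
    // 1.5 = (1 + 4/8) * 2^0 -> sign 0, biased exp 7 (0b0111), mantissa 0b100
    assert(e4m3_decode(0x3C) == 1.5f);
    // Largest finite value under this convention: (1 + 7/8) * 2^7 = 240 -> 0x77
    assert(e4m3_decode(0x77) == 240.0f);
    return 0;
}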
@@ -15,6 +15,7 @@

#include "dpct/helper.hpp"
#include "ggml.h"
#include "type.hpp"
#include "quants.hpp"

typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,

@@ -31,6 +32,18 @@ static __dpct_inline__ int get_int_b1(const void * x, const int & i32) {
    return x32;
}

static __dpct_inline__ int get_int_b2(const void * x, const int & i32) {
    const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment

    int x32 = x16[2*i32 + 0] << 0;
    x32    |= x16[2*i32 + 1] << 16;

    return x32;
}

static __dpct_inline__ int get_int_b4(const void * x, const int & i32) {
    return ((const int *) x)[i32]; // assume at least 4 byte alignment
}

static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
    const uint16_t* x16 =

@@ -755,6 +768,35 @@ static __dpct_inline__ float vec_dot_mxfp4_q8_1(const void * __restrict__ vbq,
    return d * sumi;
}

#define VDR_NVFP4_Q8_1_MMVQ 4
#define VDR_NVFP4_Q8_1_MMQ  8

static __dpct_inline__ float vec_dot_nvfp4_q8_1(const void * __restrict__ vbq,
                                                const block_q8_1 * __restrict__ bq8_1,
                                                const int32_t & iqs) {
    const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq;

    float sum = 0.0f;
#pragma unroll
    for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) {
        const int32_t iqs0 = iqs + 2*i;
        const int32_t iqs1 = iqs0 + 1;
        const int32_t is   = iqs0 >> 1;

        const sycl::int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4);
        const sycl::int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4);

        const block_q8_1 * bq8 = bq8_1 + (is >> 1);
        const int32_t      i8  = ((is & 1) << 2);

        int sumi = ggml_sycl_dp4a(v0.x(), get_int_b4(bq8->qs, i8 + 0), 0);
        sumi     = ggml_sycl_dp4a(v0.y(), get_int_b4(bq8->qs, i8 + 2), sumi);
        sumi     = ggml_sycl_dp4a(v1.x(), get_int_b4(bq8->qs, i8 + 1), sumi);
        sumi     = ggml_sycl_dp4a(v1.y(), get_int_b4(bq8->qs, i8 + 3), sumi);

        const float d = ggml_sycl_ue4m3_to_fp32(bq4->d[is]) * (bq8->ds)[0];

        sum += d * float(sumi);
    }

    return sum;
}

static __dpct_inline__ float
vec_dot_q5_0_q8_1(const void *__restrict__ vbq,
@@ -95,6 +95,12 @@ struct ggml_webgpu_generic_shader_decisions {
    uint32_t wg_size = 0;
};

struct ggml_webgpu_processed_shader {
    std::string           wgsl;
    std::string           variant;
    std::shared_ptr<void> decisions;
};

struct ggml_webgpu_ssm_conv_shader_decisions {
    uint32_t block_size;
    uint32_t tokens_per_wg;

@@ -384,11 +390,12 @@ struct ggml_webgpu_flash_attn_pipeline_key {
    bool has_mask;
    bool has_sinks;
    bool uses_logit_softcap;
    bool use_vec;

    bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
        return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
               kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
               uses_logit_softcap == other.uses_logit_softcap;
               uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec;
    }
};

@@ -402,6 +409,7 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
        ggml_webgpu_hash_combine(seed, key.has_mask);
        ggml_webgpu_hash_combine(seed, key.has_sinks);
        ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
        ggml_webgpu_hash_combine(seed, key.use_vec);
        return seed;
    }
};

@@ -421,6 +429,121 @@ struct ggml_webgpu_flash_attn_shader_decisions {
    uint32_t wg_size = 0;
};

inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
    // Keep conservative defaults unless this is the f16 vec-split shape family.
    if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) {
        return 1u;
    }

    // Head-dim specializations used by the tuned vec f16 path.
    switch (key.head_dim_qk) {
        case 64:
            return 2u;
        case 96:
            return 4u;
        case 128:
            return 1u;
        case 192:
            return 2u;
        case 576:
            return 2u;
        default:
            return 1u;
    }
}
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
    uint32_t head_dim_v;
    uint32_t wg_size;
};

struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.head_dim_v);
        ggml_webgpu_hash_combine(seed, key.wg_size);
        return seed;
    }
};

inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs,
                       const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) {
    return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size;
}

struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context {
    ggml_webgpu_flash_attn_vec_reduce_pipeline_key key;
    uint32_t max_wg_size;
};

inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(
    pre_wgsl::Preprocessor & preprocessor,
    const char * shader_src,
    const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
    std::vector<std::string> defines;
    std::string variant = "flash_attn_vec_reduce";

    defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
    variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
    variant += std::string("_wg") + std::to_string(context.max_wg_size);

    ggml_webgpu_processed_shader result;
    result.wgsl    = preprocessor.preprocess(shader_src, defines);
    result.variant = variant;
    return result;
}

struct ggml_webgpu_flash_attn_blk_pipeline_key {
    uint32_t q_tile;
    uint32_t kv_tile;

    bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const {
        return q_tile == other.q_tile && kv_tile == other.kv_tile;
    }
};

struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.q_tile);
        ggml_webgpu_hash_combine(seed, key.kv_tile);
        return seed;
    }
};

struct ggml_webgpu_flash_attn_blk_shader_lib_context {
    ggml_webgpu_flash_attn_blk_pipeline_key key;
    uint32_t max_wg_size;
};

inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader(
    pre_wgsl::Preprocessor & preprocessor,
    const char * shader_src,
    const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
    std::vector<std::string> defines;
    std::string variant = "flash_attn_vec_blk";

    defines.push_back(std::string("Q_TILE=") + std::to_string(context.key.q_tile));
    variant += std::string("_qt") + std::to_string(context.key.q_tile);

    defines.push_back(std::string("KV_TILE=") + std::to_string(context.key.kv_tile));
    variant += std::string("_kvt") + std::to_string(context.key.kv_tile);

    uint32_t wg_size = 1;
    while ((wg_size << 1) <= context.max_wg_size) {
        wg_size <<= 1;
    }
    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
    variant += std::string("_wg") + std::to_string(wg_size);

    ggml_webgpu_processed_shader result;
    result.wgsl    = preprocessor.preprocess(shader_src, defines);
    result.variant = variant;
    return result;
}
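The WG_SIZE loop in ggml_webgpu_preprocess_flash_attn_blk_shader rounds max_wg_size down to a power of two; on C++20 the same computation is std::bit_floor. A tiny check (illustrative):

#include <bit>
#include <cassert>

int main() {
    // Matches the shift loop above: largest power of two <= max_wg_size.
    assert(std::bit_floor(256u) == 256u);
    assert(std::bit_floor(300u) == 256u);
    assert(std::bit_floor(1u)   == 1u);
    return 0;
}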
// This is exposed because it's necessary in supports_op
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
                                                  uint32_t kv_tile,

@@ -535,6 +658,95 @@ struct ggml_webgpu_mul_mat_shader_decisions {
    uint32_t mul_mat_wg_size;
};
/** Cpy **/

struct ggml_webgpu_cpy_pipeline_key {
    ggml_type src_type;
    ggml_type dst_type;

    bool operator==(const ggml_webgpu_cpy_pipeline_key & other) const {
        return src_type == other.src_type && dst_type == other.dst_type;
    }
};

struct ggml_webgpu_cpy_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_cpy_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.src_type);
        ggml_webgpu_hash_combine(seed, key.dst_type);
        return seed;
    }
};

/** Glu **/

struct ggml_webgpu_glu_pipeline_key {
    ggml_glu_op glu_op;
    ggml_type   type;
    bool        split;

    bool operator==(const ggml_webgpu_glu_pipeline_key & other) const {
        return glu_op == other.glu_op && type == other.type && split == other.split;
    }
};

struct ggml_webgpu_glu_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_glu_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.glu_op);
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.split);
        return seed;
    }
};

/** Rope **/

struct ggml_webgpu_rope_pipeline_key {
    ggml_type type;
    bool      inplace;
    bool      has_ff;

    bool operator==(const ggml_webgpu_rope_pipeline_key & other) const {
        return type == other.type && inplace == other.inplace && has_ff == other.has_ff;
    }
};

struct ggml_webgpu_rope_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_rope_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.inplace);
        ggml_webgpu_hash_combine(seed, key.has_ff);
        return seed;
    }
};

/** SoftMax **/

struct ggml_webgpu_soft_max_pipeline_key {
    ggml_type mask_type;
    bool      has_mask;
    bool      has_sink;
    bool      inplace;

    bool operator==(const ggml_webgpu_soft_max_pipeline_key & other) const {
        return mask_type == other.mask_type && has_mask == other.has_mask && has_sink == other.has_sink &&
               inplace == other.inplace;
    }
};

struct ggml_webgpu_soft_max_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_soft_max_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.mask_type);
        ggml_webgpu_hash_combine(seed, key.has_mask);
        ggml_webgpu_hash_combine(seed, key.has_sink);
        ggml_webgpu_hash_combine(seed, key.inplace);
        return seed;
    }
};
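All of these pipeline caches follow one pattern: a plain key struct with operator==, a hash functor that folds each field through ggml_webgpu_hash_combine, and an unordered_map from key to compiled pipeline. A minimal standalone version of the pattern (the hash_combine body below is the usual boost-style mix and is an assumption, not the source's definition):

#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

// Boost-style mix, as ggml_webgpu_hash_combine presumably is.
template <typename T>
static void hash_combine(size_t & seed, const T & v) {
    seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

struct pipeline_key {
    int  type;
    bool inplace;
    bool operator==(const pipeline_key & o) const { return type == o.type && inplace == o.inplace; }
};

struct pipeline_key_hash {
    size_t operator()(const pipeline_key & k) const {
        size_t seed = 0;
        hash_combine(seed, k.type);
        hash_combine(seed, k.inplace);
        return seed;
    }
};

// Compile once per distinct key, then serve from the cache.
std::unordered_map<pipeline_key, std::string, pipeline_key_hash> cache;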
class ggml_webgpu_shader_lib {
    wgpu::Device           device;
    pre_wgsl::Preprocessor preprocessor;

@@ -570,6 +782,14 @@ class ggml_webgpu_shader_lib {
        repeat_pipelines;  // type
    std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
        flash_attn_pipelines;
    std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
                       webgpu_pipeline,
                       ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
        flash_attn_vec_reduce_pipelines;
    std::unordered_map<ggml_webgpu_flash_attn_blk_pipeline_key,
                       webgpu_pipeline,
                       ggml_webgpu_flash_attn_blk_pipeline_key_hash>
        flash_attn_blk_pipelines;
    std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
                       webgpu_pipeline,
                       ggml_webgpu_legacy_mul_mat_pipeline_key_hash>

@@ -582,6 +802,12 @@ class ggml_webgpu_shader_lib {
    std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
        set_rows_pipelines;
    std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
    std::unordered_map<ggml_webgpu_cpy_pipeline_key, webgpu_pipeline, ggml_webgpu_cpy_pipeline_key_hash> cpy_pipelines;
    std::unordered_map<ggml_webgpu_glu_pipeline_key, webgpu_pipeline, ggml_webgpu_glu_pipeline_key_hash> glu_pipelines;
    std::unordered_map<ggml_webgpu_rope_pipeline_key, webgpu_pipeline, ggml_webgpu_rope_pipeline_key_hash>
        rope_pipelines;
    std::unordered_map<ggml_webgpu_soft_max_pipeline_key, webgpu_pipeline, ggml_webgpu_soft_max_pipeline_key_hash>
        soft_max_pipelines;

  public:
    ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
@@ -1124,9 +1350,8 @@ class ggml_webgpu_shader_lib {

                defines.push_back("BYTE_HELPERS");
                defines.push_back("MUL_ACC_" + type_upper);

                // For fast path we always dequantize from f16 inside the shader
                defines.push_back("SRC0_INNER_TYPE=f16");
                defines.push_back("U32_DEQUANT_HELPERS");
                defines.push_back("SRC0_INNER_TYPE=u32");
                break;
            }
        }
@@ -1239,9 +1464,8 @@ class ggml_webgpu_shader_lib {
                defines.push_back("MUL_ACC_" + type_upper);
                defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
                defines.push_back("INIT_SRC1_SHMEM_FLOAT");

                // Use f16 inside the shader for quantized types
                defines.push_back("SRC0_INNER_TYPE=f16");
                defines.push_back("U32_DEQUANT_HELPERS");
                defines.push_back("SRC0_INNER_TYPE=u32");

                variant += std::string("_") + src0_name;
                break;
@@ -1580,24 +1804,8 @@
        return repeat_pipelines[key];
    }

    webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
        const bool has_mask  = context.src3 != nullptr;
        const bool has_sinks = context.src4 != nullptr;

        bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
                         (context.src1->ne[1] % context.sg_mat_n == 0);

        ggml_webgpu_flash_attn_pipeline_key key = {
            .kv_type            = context.src1->type,
            .head_dim_qk        = (uint32_t) context.src0->ne[0],
            .head_dim_v         = (uint32_t) context.src2->ne[0],
            .kv_direct          = kv_direct,
            .has_mask           = has_mask,
            .has_sinks          = has_sinks,
            .uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f,
        };

        auto it = flash_attn_pipelines.find(key);
    webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) {
        auto it = flash_attn_pipelines.find(context.key);
        if (it != flash_attn_pipelines.end()) {
            return it->second;
        }

@@ -1605,7 +1813,7 @@
        std::vector<std::string> defines;
        std::string variant = "flash_attn";

        switch (key.kv_type) {
        switch (context.key.kv_type) {
            case GGML_TYPE_F32:
                defines.push_back("KV_F32");
                break;

@@ -1621,41 +1829,51 @@
            default:
                GGML_ABORT("Unsupported KV type for flash attention shader");
        }
        variant += std::string("_") + ggml_type_name(key.kv_type);
        variant += std::string("_") + ggml_type_name(context.key.kv_type);

        if (key.has_mask) {
        if (context.key.has_mask) {
            defines.push_back("MASK");
            variant += "_mask";
        }
        if (key.has_sinks) {
        if (context.key.has_sinks) {
            defines.push_back("SINKS");
            variant += "_sinks";
        }
        if (key.uses_logit_softcap) {
        if (context.key.uses_logit_softcap) {
            defines.push_back("LOGIT_SOFTCAP");
            variant += "_lgsc";
        }
        if (key.kv_direct) {
        if (context.key.kv_direct) {
            defines.push_back("KV_DIRECT");
            variant += "_kvdirect";
        }
        if (context.key.has_mask && context.key.use_vec) {
            defines.push_back("BLK");
            variant += "_blk";
        }

        defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
        variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
        defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
        variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);

        defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
        variant += std::string("_hsv") + std::to_string(key.head_dim_v);
        defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
        variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);

        defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
        defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
        defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));

        uint32_t q_tile = context.sg_mat_m;
        uint32_t kv_tile =
            std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k,
                                                          context.wg_mem_limit_bytes, context.max_subgroup_size }),
                     context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
        if (key.kv_direct) {
        uint32_t q_tile  = context.sg_mat_m;
        uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
                                    context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
        if (context.key.use_vec) {
            q_tile  = 1;
            kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context)));
            kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n;
            const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key);
            defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
        }
        if (context.key.kv_direct) {
            GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
            while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
                kv_tile -= context.sg_mat_n;
            }

@@ -1664,19 +1882,281 @@
        defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
        defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));

        uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
        uint32_t wg_size = 0;
        if (context.key.use_vec) {
            wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
        } else {
            wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
        }
        defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));

        auto processed = preprocessor.preprocess(wgsl_flash_attn, defines);
        auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
        decisions->q_tile  = q_tile;
        decisions->kv_tile = kv_tile;
        decisions->wg_size = wg_size;
        const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn;
        webgpu_pipeline pipeline =
            ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
        auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
        decisions->q_tile  = q_tile;
        decisions->kv_tile = kv_tile;
        decisions->wg_size = wg_size;
        pipeline.context = decisions;
        flash_attn_pipelines[context.key] = pipeline;
        return flash_attn_pipelines[context.key];
    }

        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
        pipeline.context = decisions;
        flash_attn_pipelines[key] = pipeline;
        return flash_attn_pipelines[key];
webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
|
||||
auto it = flash_attn_blk_pipelines.find(context.key);
|
||||
if (it != flash_attn_blk_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
ggml_webgpu_processed_shader processed =
|
||||
ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context);
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
|
||||
flash_attn_blk_pipelines[context.key] = pipeline;
|
||||
return flash_attn_blk_pipelines[context.key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_flash_attn_vec_reduce_pipeline(
|
||||
const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
|
||||
auto it = flash_attn_vec_reduce_pipelines.find(context.key);
|
||||
if (it != flash_attn_vec_reduce_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
ggml_webgpu_processed_shader processed =
|
||||
ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(preprocessor, wgsl_flash_attn_vec_reduce, context);
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
|
||||
flash_attn_vec_reduce_pipelines[context.key] = pipeline;
|
||||
return flash_attn_vec_reduce_pipelines[context.key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_cpy_pipeline_key key = {
|
||||
.src_type = context.src0->type,
|
||||
.dst_type = context.dst->type,
|
||||
};
|
||||
|
||||
auto it = cpy_pipelines.find(key);
|
||||
if (it != cpy_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "cpy";
|
||||
|
||||
switch (key.src_type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("SRC_F32");
|
||||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("SRC_F16");
|
||||
variant += "_f16";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported src type for cpy shader");
|
||||
}
|
||||
|
||||
switch (key.dst_type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("DST_F32");
|
||||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("DST_F16");
|
||||
variant += "_f16";
|
||||
break;
|
||||
case GGML_TYPE_I32:
|
||||
defines.push_back("DST_I32");
|
||||
variant += "_i32";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported dst type for cpy shader");
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_cpy, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = context.max_wg_size;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
cpy_pipelines[key] = pipeline;
|
||||
return cpy_pipelines[key];
|
||||
}

    webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) {
        ggml_webgpu_glu_pipeline_key key = {
            .glu_op = ggml_get_glu_op(context.dst),
            .type = context.dst->type,
            .split = (context.src1 != nullptr),
        };

        auto it = glu_pipelines.find(key);
        if (it != glu_pipelines.end()) {
            return it->second;
        }

        std::vector<std::string> defines;
        std::string variant = "glu";

        switch (key.glu_op) {
            case GGML_GLU_OP_REGLU:
                defines.push_back("OP_REGLU");
                variant += "_reglu";
                break;
            case GGML_GLU_OP_GEGLU:
                defines.push_back("OP_GEGLU");
                variant += "_geglu";
                break;
            case GGML_GLU_OP_SWIGLU:
                defines.push_back("OP_SWIGLU");
                variant += "_swiglu";
                break;
            case GGML_GLU_OP_SWIGLU_OAI:
                defines.push_back("OP_SWIGLU_OAI");
                variant += "_swiglu_oai";
                break;
            case GGML_GLU_OP_GEGLU_ERF:
                defines.push_back("OP_GEGLU_ERF");
                variant += "_geglu_erf";
                break;
            case GGML_GLU_OP_GEGLU_QUICK:
                defines.push_back("OP_GEGLU_QUICK");
                variant += "_geglu_quick";
                break;
            default:
                GGML_ABORT("Unsupported GLU op");
        }
        switch (key.type) {
            case GGML_TYPE_F32:
                defines.push_back("TYPE_F32");
                variant += "_f32";
                break;
            case GGML_TYPE_F16:
                defines.push_back("TYPE_F16");
                variant += "_f16";
                break;
            default:
                GGML_ABORT("Unsupported type for GLU shader");
        }

        if (key.split) {
            variant += "_split";
        } else {
            defines.push_back("NO_SPLIT");
        }

        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

        auto processed = preprocessor.preprocess(wgsl_glu, defines);
        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
        decisions->wg_size = context.max_wg_size;
        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
        pipeline.context = decisions;
        glu_pipelines[key] = pipeline;
        return glu_pipelines[key];
    }
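    // Note (editor): "split" distinguishes the two GLU input layouts. With
    // src1 present the gate comes from a separate tensor; without it, both
    // halves live in src0 and the NO_SPLIT shader path selects them via an
    // ne0 offset and a "swapped" flag (see the NO_SPLIT decl in the glu
    // shader diff further down).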

    webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
        ggml_webgpu_rope_pipeline_key key = {
            .type = context.dst->type,
            .inplace = context.inplace,
            .has_ff = (context.src2 != nullptr),
        };

        auto it = rope_pipelines.find(key);
        if (it != rope_pipelines.end()) {
            return it->second;
        }

        std::vector<std::string> defines;
        std::string variant = "rope";

        switch (key.type) {
            case GGML_TYPE_F32:
                defines.push_back("TYPE_F32");
                variant += "_f32";
                break;
            case GGML_TYPE_F16:
                defines.push_back("TYPE_F16");
                variant += "_f16";
                break;
            default:
                GGML_ABORT("Unsupported type for ROPE shader");
        }

        if (key.inplace) {
            defines.push_back("INPLACE");
            variant += "_inplace";
        }

        if (key.has_ff) {
            defines.push_back("FF_FUNC");
            variant += "_ff";
        }

        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

        auto processed = preprocessor.preprocess(wgsl_rope, defines);
        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
        decisions->wg_size = context.max_wg_size;
        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
        pipeline.context = decisions;
        rope_pipelines[key] = pipeline;
        return rope_pipelines[key];
    }
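    // Note (editor): has_ff corresponds to an optional src2 tensor of RoPE
    // frequency factors; when present, the FF_FUNC define compiles in the
    // frequency-factor path and the variant name gains an "_ff" suffix.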

    webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) {
        ggml_webgpu_soft_max_pipeline_key key = {
            .mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32,
            .has_mask = (context.src1 != nullptr),
            .has_sink = (context.src2 != nullptr),
            .inplace = context.inplace,
        };

        auto it = soft_max_pipelines.find(key);
        if (it != soft_max_pipelines.end()) {
            return it->second;
        }

        std::vector<std::string> defines;
        std::string variant = "soft_max";

        if (key.has_mask) {
            defines.push_back("HAS_MASK");
            switch (key.mask_type) {
                case GGML_TYPE_F32:
                    defines.push_back("MASK_F32");
                    variant += "_mask_f32";
                    break;
                case GGML_TYPE_F16:
                    defines.push_back("MASK_F16");
                    variant += "_mask_f16";
                    break;
                default:
                    GGML_ABORT("Unsupported type for SOFT_MAX shader");
            }
        }

        if (key.has_sink) {
            defines.push_back("HAS_SINK");
            variant += "_sink";
        }

        if (key.inplace) {
            defines.push_back("INPLACE");
            variant += "_inplace";
        }

        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

        auto processed = preprocessor.preprocess(wgsl_soft_max, defines);
        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
        decisions->wg_size = context.max_wg_size;
        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
        pipeline.context = decisions;
        soft_max_pipelines[key] = pipeline;
        return soft_max_pipelines[key];
    }
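    // Note (editor): mask_type is only meaningful when has_mask is set; it
    // defaults to GGML_TYPE_F32 for keying purposes when src1 is absent so
    // that maskless variants still hash consistently.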

  private:

File diff suppressed because it is too large
@@ -8,6 +8,30 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
}
#endif

#ifdef U32_DEQUANT_HELPERS
fn load_src0_u16_at(byte_offset: u32) -> u32 {
    let word = src0[byte_offset / 4u];
    let shift = (byte_offset & 2u) * 8u;
    return (word >> shift) & 0xFFFFu;
}

fn load_src0_u32_at(byte_offset: u32) -> u32 {
    let word_idx = byte_offset / 4u;
    let shift = (byte_offset & 3u) * 8u;
    let lo = src0[word_idx];
    if (shift == 0u) {
        return lo;
    }
    let hi = src0[word_idx + 1u];
    return (lo >> shift) | (hi << (32u - shift));
}

fn load_src0_f16_at(byte_offset: u32) -> f16 {
    let packed = unpack2x16float(load_src0_u16_at(byte_offset));
    return f16(packed[0]);
}
#endif
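
// Note (editor): these helpers exist because the quantized source is bound as
// an array<u32>, so 16-bit and unaligned 32-bit reads must be synthesized from
// aligned word loads. For example, byte_offset = 6 reads word 1 with a 16-bit
// shift, while byte_offset = 5 combines words 1 and 2 as (lo >> 8) | (hi << 24).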

#ifdef Q4_0_T
struct q4_0 {
    d: f16,
@@ -1,66 +1,41 @@
#define(VARIANTS)

[
    { "REPLS": { "SRC_TYPE": "f32", "DST_TYPE": "f32" } },
    { "REPLS": { "SRC_TYPE": "f32", "DST_TYPE": "i32" } },
    { "REPLS": { "SRC_TYPE": "f32", "DST_TYPE": "f16" } },
    { "REPLS": { "SRC_TYPE": "f16", "DST_TYPE": "f16" } },
    { "REPLS": { "SRC_TYPE": "f16", "DST_TYPE": "f32" } }
]

#end(VARIANTS)

#define(SHADER)
enable f16;

#ifdef SRC_F32
#define SRC_TYPE f32
#elif defined(SRC_F16)
#define SRC_TYPE f16
#endif

#ifdef DST_F32
#define DST_TYPE f32
#elif defined(DST_F16)
#define DST_TYPE f16
#elif defined(DST_I32)
#define DST_TYPE i32
#endif

@group(0) @binding(0)
var<storage, read_write> src: array<{{SRC_TYPE}}>;
var<storage, read_write> src: array<SRC_TYPE>;

@group(0) @binding(1)
var<storage, read_write> dst: array<{{DST_TYPE}}>;
var<storage, read_write> dst: array<DST_TYPE>;

struct Params {
    ne: u32,         // total number of elements
    offset_src: u32, // in elements
    offset_dst: u32, // in elements
struct Params {
    ne: u32,
    offset_src: u32,
    offset_dst: u32,

    // Strides (in elements) — may be permuted
    stride_src0: u32,
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    stride_dst0: u32,
    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // Logical shapes
    src_ne0: u32,
    src_ne1: u32,
    src_ne2: u32,
@@ -73,8 +48,7 @@ struct Params {
@group(0) @binding(2)
var<uniform> params: Params;

override wg_size: u32;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x >= params.ne) {
        return;
@@ -102,6 +76,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
                  j2 * params.stride_dst2 + j3 * params.stride_dst3;

    dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx]));
    dst[params.offset_dst + dst_idx] = DST_TYPE((src[params.offset_src + src_idx]));
}
#end(SHADER)
@@ -1,41 +1,8 @@
import os
import re
import ast
import argparse


def extract_block(text, name):
    pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)'
    match = re.search(pattern, text, re.DOTALL)
    if not match:
        raise ValueError(f"Missing block: {name}")
    return match.group(1).strip()


def parse_decls(decls_text):
    decls = {}
    for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL):
        decls[name.strip()] = code.strip()
    return decls


def replace_repl_placeholders(variant, template_map):
    for repl, code in variant["REPLS"].items():
        for key, val in template_map.items():
            # Match "key" and avoid matching subsequences by using \b
            code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code)
        variant["REPLS"][repl] = code
    return variant


def replace_placeholders(shader_text, replacements):
    for key, val in replacements.items():
        # Match {{KEY}} literally, where KEY is escaped
        pattern = r'{{\s*' + re.escape(key) + r'\s*}}'
        shader_text = re.sub(pattern, str(val), shader_text)
    return shader_text


def expand_includes(shader, input_dir):
    """
    Replace #include "file" lines in the text with the contents of that file.
@@ -98,84 +65,24 @@ def write_shader(shader_name, shader_code, output_dir, outfile, input_dir):
    outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n')


def generate_variants(fname, input_dir, output_dir, outfile):
    shader_path = os.path.join(input_dir, fname)
    shader_base_name = fname.split(".")[0]

    with open(shader_path, "r", encoding="utf-8") as f:
        text = f.read()

    try:
        variants = ast.literal_eval(extract_block(text, "VARIANTS"))
    except ValueError:
        write_shader(shader_base_name, text, output_dir, outfile, input_dir)
    else:
        try:
            decls_map = parse_decls(extract_block(text, "DECLS"))
        except ValueError:
            decls_map = {}
        try:
            templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES"))
        except ValueError:
            templates_map = {}

        for fname in sorted(os.listdir(input_dir)):
            if fname.endswith(".tmpl"):
                tmpl_path = os.path.join(input_dir, fname)
                with open(tmpl_path, "r", encoding="utf-8") as f_tmpl:
                    decls = f_tmpl.read()
                decls_map.update(parse_decls(decls))

        shader_template = extract_block(text, "SHADER")
        for variant in variants:
            if "DECLS" in variant:
                decls = variant["DECLS"]
            else:
                decls = []
            decls_code = ""
            for key in decls:
                if key not in decls_map:
                    raise ValueError(f"DECLS key '{key}' not found.")
                decls_code += decls_map[key] + "\n\n"
            final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
            if "REPLS" in variant:
                variant = replace_repl_placeholders(variant, templates_map)
                final_shader = replace_placeholders(final_shader, variant["REPLS"])
                # second run to expand placeholders in repl_template
                final_shader = replace_placeholders(final_shader, variant["REPLS"])
            final_shader = expand_includes(final_shader, input_dir)

            if "SHADER_NAME" in variant:
                output_name = variant["SHADER_NAME"]
            elif "SHADER_SUFFIX" in variant:
                output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"]
            elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
                output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
            elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]:
                output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]])
            elif "REPLS" in variant and "TYPE" in variant["REPLS"]:
                output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
            else:
                output_name = shader_base_name
            write_shader(output_name, final_shader, output_dir, outfile, input_dir)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", required=True)
    parser.add_argument("--output_file", required=True)
    parser.add_argument("--output_dir")
    args = parser.parse_args()

    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    with open(args.output_file, "w", encoding="utf-8") as out:
        out.write("// Auto-generated shader embedding\n")
        out.write("#include <string>\n\n")
        for fname in sorted(os.listdir(args.input_dir)):
            if fname.endswith(".wgsl"):
                generate_variants(fname, args.input_dir, args.output_dir, out)
                shader_path = os.path.join(args.input_dir, fname)
                shader_name = fname.replace(".wgsl", "")

                with open(shader_path, "r", encoding="utf-8") as f:
                    shader_code = f.read()

                write_shader(shader_name, shader_code, None, out, args.input_dir)


if __name__ == "__main__":
    main()
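
# Example invocation (script name and paths hypothetical, for illustration only):
#   python3 embed_wgsl.py --input_dir ggml/src/ggml-webgpu/wgsl-shaders \
#                         --output_file ggml-wgsl-shaders.cpp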
@@ -6,6 +6,8 @@ enable chromium_experimental_subgroup_matrix;

#ifdef KV_F32
#define KV_TYPE f32
#elif defined(KV_Q4_0) || defined(KV_Q8_0)
#define KV_TYPE u32
#else
#define KV_TYPE f16
#endif
@@ -37,11 +39,13 @@ enable chromium_experimental_subgroup_matrix;
#define NQ 16
// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights
#define F16_PER_BLOCK 9
#define BLOCK_SIZE_BYTES 18u
#define WEIGHTS_PER_F16 4
#elif defined(KV_Q8_0)
#define NQ 8
// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights
#define F16_PER_BLOCK 17
#define BLOCK_SIZE_BYTES 34u
#define WEIGHTS_PER_F16 2
#endif
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
@@ -55,6 +59,47 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
    return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
}

#if defined(KV_Q4_0) || defined(KV_Q8_0)
fn load_k_u16_at(byte_offset: u32) -> u32 {
    let word = K[byte_offset / 4u];
    let shift = (byte_offset & 2u) * 8u;
    return (word >> shift) & 0xFFFFu;
}

fn load_k_u32_at(byte_offset: u32) -> u32 {
    let word_idx = byte_offset / 4u;
    let shift = (byte_offset & 3u) * 8u;
    let lo = K[word_idx];
    if (shift == 0u) {
        return lo;
    }
    let hi = K[word_idx + 1u];
    return (lo >> shift) | (hi << (32u - shift));
}

fn load_v_u16_at(byte_offset: u32) -> u32 {
    let word = V[byte_offset / 4u];
    let shift = (byte_offset & 2u) * 8u;
    return (word >> shift) & 0xFFFFu;
}

fn load_v_u32_at(byte_offset: u32) -> u32 {
    let word_idx = byte_offset / 4u;
    let shift = (byte_offset & 3u) * 8u;
    let lo = V[word_idx];
    if (shift == 0u) {
        return lo;
    }
    let hi = V[word_idx + 1u];
    return (lo >> shift) | (hi << (32u - shift));
}

fn f16_from_u16(bits: u32) -> f16 {
    let packed = unpack2x16float(bits);
    return f16(packed[0]);
}
#endif

struct Params {
    offset_q: u32,
    offset_k: u32,
@@ -254,12 +299,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,

        if (global_k_row < params.seq_len_kv) {
            let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
            let base_idx = global_block_idx * F16_PER_BLOCK;
            let d = K[base_idx]; // scale
            let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
            let d = f16_from_u16(load_k_u16_at(block_byte_base));
            for (var j = 0u; j < F16_PER_THREAD; j += 2) {
                let q_0 = K[base_idx + 1u + block_offset + j];
                let q_1 = K[base_idx + 1u + block_offset + j + 1];
                let q_packed = bitcast<u32>(vec2(q_0, q_1));
                let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
                let q_packed = load_k_u32_at(q_byte_offset);
                for (var k = 0u; k < 4u; k++) {
                    let q_byte = get_byte(q_packed, k);
                    let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -282,12 +326,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,

        if (global_k_row < params.seq_len_kv) {
            let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
            let base_idx = global_block_idx * F16_PER_BLOCK;
            let d = K[base_idx]; // scale
            let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
            let d = f16_from_u16(load_k_u16_at(block_byte_base));
            for (var j = 0u; j < F16_PER_THREAD; j += 2) {
                let q_0 = K[base_idx + 1u + block_offset + j];
                let q_1 = K[base_idx + 1u + block_offset + j + 1];
                let q_packed = bitcast<u32>(vec2(q_0, q_1));
                let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
                let q_packed = load_k_u32_at(q_byte_offset);
                for (var k = 0u; k < 4u; k++) {
                    let q_byte = get_byte_i32(q_packed, k);
                    let q_val = f16(q_byte) * d;
@@ -459,12 +502,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,

        if (global_v_row < params.seq_len_kv) {
            let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
            let base_idx = global_block_idx * F16_PER_BLOCK;
            let d = V[base_idx]; // scale
            let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
            let d = f16_from_u16(load_v_u16_at(block_byte_base));
            for (var j = 0u; j < F16_PER_THREAD; j += 2) {
                let q_0 = V[base_idx + 1u + block_offset + j];
                let q_1 = V[base_idx + 1u + block_offset + j + 1];
                let q_packed = bitcast<u32>(vec2(q_0, q_1));
                let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
                let q_packed = load_v_u32_at(q_byte_offset);
                for (var k = 0u; k < 4u; k++) {
                    let q_byte = get_byte(q_packed, k);
                    let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -487,12 +529,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,

        if (global_v_row < params.seq_len_kv) {
            let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
            let base_idx = global_block_idx * F16_PER_BLOCK;
            let d = V[base_idx]; // scale
            let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
            let d = f16_from_u16(load_v_u16_at(block_byte_base));
            for (var j = 0u; j < F16_PER_THREAD; j += 2) {
                let q_0 = V[base_idx + 1u + block_offset + j];
                let q_1 = V[base_idx + 1u + block_offset + j + 1];
                let q_packed = bitcast<u32>(vec2(q_0, q_1));
                let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
                let q_packed = load_v_u32_at(q_byte_offset);
                for (var k = 0u; k < 4u; k++) {
                    let q_byte = get_byte_i32(q_packed, k);
                    let q_val = f16(q_byte) * d;
105  ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl  Normal file
@@ -0,0 +1,105 @@
diagnostic(off, subgroup_uniformity);
enable f16;

#define Q_TILE 1
#define KV_TILE 32
#define WG_SIZE 32

struct Params {
    offset_mask: u32,
    seq_len_q: u32,
    seq_len_kv: u32,
    stride_mask3: u32,
    // Number of KV blocks and Q blocks per batch.
    // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = ceil(seq_len_q / Q_TILE).
    nblk0: u32,
    nblk1: u32,
};

@group(0) @binding(0) var<storage, read> mask: array<f16>;
@group(0) @binding(1) var<storage, read_write> blk: array<u32>;
@group(0) @binding(2) var<uniform> params: Params;

const MASK_MIN: f32 = -65504.0;
const MASK_MAX: f32 = 65504.0;
var<workgroup> wg_min: array<f32, WG_SIZE>;
var<workgroup> wg_max: array<f32, WG_SIZE>;
var<workgroup> wg_any: array<u32, WG_SIZE>;

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(local_invocation_id) local_id: vec3<u32>) {
    // Dispatch mapping:
    // - x indexes KV blocks
    // - y flattens (batch_idx, q_blk) as y = batch_idx * nblk1 + q_blk
    let kv_blk = wg_id.x;
    let y = wg_id.y;
    let q_blk = y % params.nblk1;
    let batch_idx = y / params.nblk1;
    if (kv_blk >= params.nblk0) {
        return;
    }

    let q_start = q_blk * Q_TILE;
    let k_start = kv_blk * KV_TILE;

    let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
    let mask_batch_base = params.offset_mask + mask_batch * params.stride_mask3;

    // We keep min/max to classify:
    // - fully masked (max <= MASK_MIN)
    // - all-zero mask (min == 0 && max == 0)
    // - mixed/general mask
    var local_min = MASK_MAX;
    var local_max = -MASK_MAX;
    var local_any = 0u;

    for (var q_rel = 0u; q_rel < Q_TILE; q_rel += 1u) {
        let q_row = q_start + q_rel;
        if (q_row >= params.seq_len_q) {
            continue;
        }
        let row_base = mask_batch_base + q_row * params.seq_len_kv;
        for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) {
            let k_col = k_start + k_rel;
            if (k_col >= params.seq_len_kv) {
                continue;
            }
            let mv = f32(mask[row_base + k_col]);
            local_min = min(local_min, mv);
            local_max = max(local_max, mv);
            local_any = 1u;
        }
    }

    wg_min[local_id.x] = local_min;
    wg_max[local_id.x] = local_max;
    wg_any[local_id.x] = local_any;
    workgroupBarrier();

    // Thread 0 writes one state per block.
    if (local_id.x == 0u) {
        var mmin = wg_min[0];
        var mmax = wg_max[0];
        var many = wg_any[0];
        for (var i = 1u; i < WG_SIZE; i += 1u) {
            mmin = min(mmin, wg_min[i]);
            mmax = max(mmax, wg_max[i]);
            many = max(many, wg_any[i]);
        }

        var state = 0u;
        if (many != 0u) {
            if (mmax <= MASK_MIN) {
                state = 0u;
            } else if (mmin == 0.0 && mmax == 0.0) {
                state = 2u;
            } else {
                state = 1u;
            }
        }

        let blk_idx = (batch_idx * params.nblk1 + q_blk) * params.nblk0 + kv_blk;
        blk[blk_idx] = state;
    }
}
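
// Note (editor): the state written above is consumed by the split kernel:
// 0 = tile fully masked out (the whole KV tile can be skipped),
// 1 = general mask (mask values must be loaded and added),
// 2 = mask is all zeros (the tile is computed but the mask add is skipped).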
78  ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl  Normal file
@@ -0,0 +1,78 @@
diagnostic(off, subgroup_uniformity);
enable f16;
enable subgroups;

// Default values
#define HEAD_DIM_V 64
#define WG_SIZE 128

struct Params {
    nrows: u32,
    seq_len_q: u32,
    n_heads: u32,
    offset_dst: u32,
    nwg: u32,
    tmp_data_base: u32,
    tmp_stats_base: u32,
};

@group(0) @binding(0) var<storage, read_write> tmp: array<f32>;
@group(0) @binding(1) var<storage, read_write> dst: array<vec4<f32>>;
@group(0) @binding(2) var<uniform> params: Params;

const FLOAT_MIN: f32 = -1.0e9;

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(subgroup_id) subgroup_id: u32,
        @builtin(num_subgroups) num_subgroups: u32,
        @builtin(subgroup_size) subgroup_size: u32,
        @builtin(subgroup_invocation_id) sg_inv_id: u32) {
    let rid = wg_id.x;
    if (rid >= params.nrows) {
        return;
    }

    let rows_per_batch = params.n_heads * params.seq_len_q;
    let batch_idx = rid / rows_per_batch;
    let rem = rid % rows_per_batch;
    let head_idx = rem / params.seq_len_q;
    let q_row = rem % params.seq_len_q;

    let dst2_stride = HEAD_DIM_V * params.n_heads;
    let dst3_stride = dst2_stride * params.seq_len_q;
    let row_base = params.offset_dst + batch_idx * dst3_stride + q_row * dst2_stride + head_idx * HEAD_DIM_V;

    let thread = sg_inv_id;
    if (params.nwg > subgroup_size) {
        return;
    }

    let stats_base = params.tmp_stats_base + rid * (2u * params.nwg);
    let active_thread = thread < params.nwg;
    let si = select(0.0, tmp[stats_base + 2u * thread + 0u], active_thread);
    let mi = select(FLOAT_MIN, tmp[stats_base + 2u * thread + 1u], active_thread);
    let m = subgroupMax(mi);
    let ms = select(0.0, exp(mi - m), active_thread);
    let s = subgroupAdd(si * ms);
    let inv_s = select(0.0, 1.0 / s, s != 0.0);

    let row_tmp_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg);
    for (var elem_base = subgroup_id * 4u; elem_base < HEAD_DIM_V; elem_base += num_subgroups * 4u) {
        var weighted = vec4<f32>(0.0, 0.0, 0.0, 0.0);
        if (active_thread) {
            let src = row_tmp_base + thread * HEAD_DIM_V + elem_base;
            weighted = vec4<f32>(tmp[src + 0u], tmp[src + 1u], tmp[src + 2u], tmp[src + 3u]) * ms;
        }

        let sum_x = subgroupAdd(weighted.x);
        let sum_y = subgroupAdd(weighted.y);
        let sum_z = subgroupAdd(weighted.z);
        let sum_w = subgroupAdd(weighted.w);

        if (thread == 0u) {
            let dst_vec_index = (row_base + elem_base) >> 2u;
            dst[dst_vec_index] = vec4<f32>(sum_x, sum_y, sum_z, sum_w) * inv_s;
        }
    }
}
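
// Note (editor): the merge above is the standard streaming log-sum-exp
// combination. Given per-workgroup stats (s_i, m_i), it computes
// m = max_i(m_i) and s = sum_i(s_i * exp(m_i - m)), weights each partial
// output row by exp(m_i - m), sums across workgroups, and divides by s.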
729  ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl  Normal file
@@ -0,0 +1,729 @@
diagnostic(off, chromium.subgroup_matrix_uniformity);
diagnostic(off, subgroup_uniformity);
enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;

#ifdef KV_F32
#define KV_TYPE f32
#else
#define KV_TYPE f16
#endif

#define HEAD_DIM_QK 64
#define HEAD_DIM_V 64


#define SG_MAT_M 8
#define SG_MAT_N 8
#define SG_MAT_K 8

#define Q_TILE SG_MAT_M
#define KV_TILE 16
#define WG_SIZE 64
#ifndef VEC_NE
#define VEC_NE 4u
#endif

#define KV_BLOCKS (KV_TILE / SG_MAT_N)

#define BLOCK_SIZE 32
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
#if defined(KV_Q4_0)
#define NQ 16
#define F16_PER_BLOCK 9
#define WEIGHTS_PER_F16 4
#elif defined(KV_Q8_0)
#define NQ 8
#define F16_PER_BLOCK 17
#define WEIGHTS_PER_F16 2
#endif
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)

fn get_byte(value: u32, index: u32) -> u32 {
    return (value >> (index * 8)) & 0xFF;
}

fn get_byte_i32(value: u32, index: u32) -> i32 {
    return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
}

struct Params {
    offset_q: u32,
    offset_k: u32,
    offset_v: u32,
    offset_mask: u32,
    offset_sinks: u32,
    offset_dst: u32,

    // shapes of Q/K/V
    n_heads: u32,
    seq_len_q: u32,
    seq_len_kv: u32,

    // strides (in elements)
    stride_q1: u32,
    stride_q2: u32,
    stride_q3: u32,
    stride_k1: u32,
    stride_k2: u32,
    stride_k3: u32,
    stride_v1: u32,
    stride_v2: u32,
    stride_v3: u32,
    stride_mask3: u32,

    // repeat factors for K/V, e.g., MHA vs. MQA vs. GQA
    q_per_kv: u32,

    // softmax params
    scale: f32,
    max_bias: f32,
    logit_softcap: f32,
    n_head_log2: f32,
    m0: f32,
    m1: f32,

#ifdef BLK
    blk_base: u32,
    blk_nblk0: u32,
    blk_nblk1: u32,
#endif

    tmp_data_base: u32,
    tmp_stats_base: u32,
    nwg: u32,
};

@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
#if defined(KV_Q4_0) || defined(KV_Q8_0)
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
#else
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
#endif
#if defined(KV_Q4_0) || defined(KV_Q8_0)
@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
#else
@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
#endif
#if defined(MASK) && defined(SINKS)
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
#ifdef BLK
#define BLK_BINDING 5
#define TMP_BINDING 6
#define DST_BINDING 7
#define PARAMS_BINDING 8
#else
#define TMP_BINDING 5
#define DST_BINDING 6
#define PARAMS_BINDING 7
#endif
#elif defined(MASK)
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
#ifdef BLK
#define BLK_BINDING 4
#define TMP_BINDING 5
#define DST_BINDING 6
#define PARAMS_BINDING 7
#else
#define TMP_BINDING 4
#define DST_BINDING 5
#define PARAMS_BINDING 6
#endif
#elif defined(SINKS)
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
#define TMP_BINDING 4
#define DST_BINDING 5
#define PARAMS_BINDING 6
#else
#define TMP_BINDING 3
#define DST_BINDING 4
#define PARAMS_BINDING 5
#endif

#ifdef BLK
@group(0) @binding(BLK_BINDING) var<storage, read_write> blk: array<u32>;
#endif
@group(0) @binding(TMP_BINDING) var<storage, read_write> tmp: array<f32>;
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;

// Just a very small float value.
const FLOAT_MIN: f32 = -1.0e9;

var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;

#ifndef KV_DIRECT
const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
// we can reuse the same shmem for K and V since we only need one at a time
var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
#endif

var<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>;

#ifdef MASK
// storage for mask values
var<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;
#endif

// note that we reuse the same storage for both since we only need one at a time
var<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;

// Storage for row max and exp sum during online softmax
var<workgroup> row_max_shmem: array<f32, Q_TILE>;
var<workgroup> exp_sum_shmem: array<f32, Q_TILE>;
var<workgroup> blk_state_wg: u32;

fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 {
    var v = select(FLOAT_MIN,
                   f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale,
                   kv_idx < KV_TILE);
#ifdef LOGIT_SOFTCAP
    v = params.logit_softcap * tanh(v);
#endif
#ifdef MASK
    if (apply_mask) {
        var mask_val = select(0.0, f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE);
        v += select(mask_val, slope * mask_val, has_bias);
    }
#endif
    return v;
}

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(local_invocation_id) local_id: vec3<u32>,
        @builtin(subgroup_id) subgroup_id: u32,
        @builtin(subgroup_size) subgroup_size: u32,
        @builtin(num_subgroups) num_subgroups: u32,
        @builtin(subgroup_invocation_id) sg_inv_id: u32) {

    // initialize row max for online softmax
    for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
        row_max_shmem[i] = FLOAT_MIN;
        exp_sum_shmem[i] = 0.0;
    }

    for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) {
        o_shmem[i] = 0.0;
    }

    // workgroups per head/batch
    let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
    let wg_per_batch = wg_per_head * params.n_heads;

    let dst2_stride = HEAD_DIM_V * params.n_heads;
    let dst3_stride = dst2_stride * params.seq_len_q;

    let iwg = wg_id.x % params.nwg;
    let base_wg_id = wg_id.x / params.nwg;

    // batch index
    let batch_idx = base_wg_id / wg_per_batch;
    let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
    let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
    let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
    let wg_in_batch = base_wg_id % wg_per_batch;

    // head index
    let head_idx = wg_in_batch / wg_per_head;
    let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
    let k_head_idx = head_idx / params.q_per_kv;
    let v_head_idx = k_head_idx;
    let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
    let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;

    // starting Q row for this workgroup
    let wg_in_head = wg_in_batch % wg_per_head;
    let q_row_start = wg_in_head * Q_TILE;

#ifdef MASK
    // mask offset
    let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
#endif

    let head = f32(head_idx);
    let has_bias = params.max_bias > 0.0;
    let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias);

    // load q tile into shared memory
    for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
        let q_row = elem_idx / HEAD_DIM_QK;
        let q_col = elem_idx % HEAD_DIM_QK;
        let head_q_row = q_row_start + q_row;
        let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
        q_shmem[elem_idx] = f16(select(
            0.0,
            Q[global_q_row_offset + q_col],
            head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));
    }

    for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) {
#ifdef BLK
        let q_blk = q_row_start / Q_TILE;
        let kv_blk = kv_tile / KV_TILE;
        let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
        let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk;
        let blk_state_local = blk[blk_idx];
#else
        let blk_state_local = 1u;
#endif
        if (local_id.x == 0u) {
            blk_state_wg = blk_state_local;
        }
        workgroupBarrier();
        let blk_state = blk_state_wg;
        let skip_tile = blk_state == 0u;
        for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
            inter_shmem[elem_idx] = f16(0.0);
        }

        // load k tile into shared memory
#if defined(KV_Q4_0)
        for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
            let blck_idx = elem_idx / BLOCK_SIZE;
            let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
            let k_row = blck_idx / BLOCKS_K;
            let global_k_row = kv_tile + k_row;
            let block_k = blck_idx % BLOCKS_K;
            let row_offset = k_row * HEAD_DIM_QK;

            if (global_k_row < params.seq_len_kv) {
                let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
                let base_idx = global_block_idx * F16_PER_BLOCK;
                let d = K[base_idx];
                for (var j = 0u; j < F16_PER_THREAD; j += 2) {
                    let q_0 = K[base_idx + 1u + block_offset + j];
                    let q_1 = K[base_idx + 1u + block_offset + j + 1];
                    let q_packed = bitcast<u32>(vec2(q_0, q_1));
                    for (var k = 0u; k < 4u; k++) {
                        let q_byte = get_byte(q_packed, k);
                        let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
                        let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
                        let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
                        kv_shmem[row_offset + idx] = q_lo;
                        kv_shmem[row_offset + idx + 16u] = q_hi;
                    }
                }
            }
        }
#elif defined(KV_Q8_0)
        for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
            let blck_idx = elem_idx / BLOCK_SIZE;
            let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
            let k_row = blck_idx / BLOCKS_K;
            let global_k_row = kv_tile + k_row;
            let block_k = blck_idx % BLOCKS_K;
            let row_offset = k_row * HEAD_DIM_QK;

            if (global_k_row < params.seq_len_kv) {
                let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
                let base_idx = global_block_idx * F16_PER_BLOCK;
                let d = K[base_idx];
                for (var j = 0u; j < F16_PER_THREAD; j += 2) {
                    let q_0 = K[base_idx + 1u + block_offset + j];
                    let q_1 = K[base_idx + 1u + block_offset + j + 1];
                    let q_packed = bitcast<u32>(vec2(q_0, q_1));
                    for (var k = 0u; k < 4u; k++) {
                        let q_byte = get_byte_i32(q_packed, k);
                        let q_val = f16(q_byte) * d;
                        let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
                        kv_shmem[row_offset + idx] = q_val;
                    }
                }
            }
        }
#elif defined(KV_DIRECT)
        // Direct global loads for KV
#else
        for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) {
            let k_row = elem_idx / HEAD_DIM_QK;
            let k_col = elem_idx % HEAD_DIM_QK;
            let global_k_row = kv_tile + k_row;
            let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
            let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK;
            let vec_idx = (global_k_row_offset + k_col) >> 2u;
            let k4 = select(vec4<KV_TYPE>(0.0), K[vec_idx], in_bounds);
            kv_shmem[elem_idx + 0u] = f16(k4.x);
            kv_shmem[elem_idx + 1u] = f16(k4.y);
            kv_shmem[elem_idx + 2u] = f16(k4.z);
            kv_shmem[elem_idx + 3u] = f16(k4.w);
        }
#endif

        workgroupBarrier();

        // accumulate q block * k block into registers across the entire KV tile
        if (!skip_tile) {
            let num_of_threads = subgroup_size / VEC_NE;
            let tx = sg_inv_id % num_of_threads;
            let ty = sg_inv_id / num_of_threads;
            for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
                let global_q_row = q_row_start + q_tile_row;
                if (global_q_row >= params.seq_len_q) {
                    continue;
                }
                let local_q_row_offset = q_tile_row * HEAD_DIM_QK;

                for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) {
                    let kv_idx = kv_base + ty;
                    var partial_sum: f32 = 0.0;
                    let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv;
                    if (kv_valid) {
                        for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) {
                            let q_off = local_q_row_offset + i * 4u;

                            let qv = vec4<f32>(
                                f32(q_shmem[q_off + 0u]),
                                f32(q_shmem[q_off + 1u]),
                                f32(q_shmem[q_off + 2u]),
                                f32(q_shmem[q_off + 3u]));
#ifdef KV_DIRECT
                            let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u);
                            let kv = vec4<f32>(K[idx >> 2u]);
#else
                            let idx = kv_idx * HEAD_DIM_QK + (i * 4u);
                            let kv = vec4<f32>(
                                f32(kv_shmem[idx + 0u]),
                                f32(kv_shmem[idx + 1u]),
                                f32(kv_shmem[idx + 2u]),
                                f32(kv_shmem[idx + 3u]));
#endif
                            partial_sum += dot(qv, kv);
                        }
                    }
                    var sum = partial_sum;
                    // Reduce over tx threads (NL) for this ty stripe.
                    var tx_delta = num_of_threads >> 1u;
                    loop {
                        if (tx_delta == 0u) {
                            break;
                        }
                        let sh = subgroupShuffleDown(sum, tx_delta);
                        if (tx < tx_delta) {
                            sum += sh;
                        }
                        tx_delta >>= 1u;
                    }

                    let sum_bcast = subgroupShuffle(sum, num_of_threads * ty);
                    if (tx == 0u && kv_valid) {
                        let dst_idx = q_tile_row * KV_TILE + kv_idx;
                        inter_shmem[dst_idx] = f16(sum_bcast);
                    }
                }
            }
        }


#ifdef MASK
        let apply_mask = !skip_tile && (blk_state != 2u);
        if (apply_mask) {
            // load mask tile into shared memory for this KV block
            for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
                let mask_row = elem_idx / KV_TILE;
                let mask_col = elem_idx % KV_TILE;
                let global_q_row = q_row_start + mask_row;
                let global_k_col = kv_tile + mask_col;
                let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
                let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;
                mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
            }
        }
#else
        let apply_mask = false;
#endif

        workgroupBarrier();

        // online softmax
        if (!skip_tile) {
            for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
                let global_q_row = q_row_start + q_tile_row;
                if (global_q_row >= params.seq_len_q) {
                    break;
                }

                var prev_max = row_max_shmem[q_tile_row];
                var final_max = prev_max;
                // pass 1: compute final max across the full KV tile in chunks
                for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
                    let kv_idx = kv_offset + sg_inv_id;
                    let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE;
                    let softmax_term = select(FLOAT_MIN,
                                              calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask),
                                              kv_valid);
                    final_max = subgroupMax(max(final_max, softmax_term));
                }

                var total_exp_term: f32 = 0.0;
                // pass 2: compute exp sum and write P using final_max
                for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
                    let kv_idx = kv_offset + sg_inv_id;
                    let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask);
                    let cur_p = select(0.0,
                                       exp(softmax_term - final_max),
                                       kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
                    total_exp_term += subgroupAdd(cur_p);
                    if (kv_idx < KV_TILE) {
                        inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p);
                    }
                }

                let cur_exp = exp(prev_max - final_max);

                if (sg_inv_id == 0) {
                    row_max_shmem[q_tile_row] = final_max;
                    exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term;
                }

                for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
                    let idx = q_tile_row * HEAD_DIM_V + elem_idx;
                    o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp);
                }
            }
        }

        // load v tile into shared memory
#if defined(KV_Q4_0)
        for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
            let blck_idx = elem_idx / BLOCK_SIZE;
            let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
            let v_row = blck_idx / BLOCKS_V;
            let global_v_row = kv_tile + v_row;
            let block_k = blck_idx % BLOCKS_V;
            let row_offset = v_row * HEAD_DIM_V;

            if (global_v_row < params.seq_len_kv) {
                let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
                let base_idx = global_block_idx * F16_PER_BLOCK;
                let d = V[base_idx];
                for (var j = 0u; j < F16_PER_THREAD; j += 2) {
                    let q_0 = V[base_idx + 1u + block_offset + j];
                    let q_1 = V[base_idx + 1u + block_offset + j + 1];
                    let q_packed = bitcast<u32>(vec2(q_0, q_1));
                    for (var k = 0u; k < 4u; k++) {
                        let q_byte = get_byte(q_packed, k);
                        let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
                        let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
                        let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
                        kv_shmem[row_offset + idx] = q_lo;
                        kv_shmem[row_offset + idx + 16u] = q_hi;
                    }
                }
            }
        }
#elif defined(KV_Q8_0)
        for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
            let blck_idx = elem_idx / BLOCK_SIZE;
            let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
            let v_row = blck_idx / BLOCKS_V;
            let global_v_row = kv_tile + v_row;
            let block_k = blck_idx % BLOCKS_V;
            let row_offset = v_row * HEAD_DIM_V;

            if (global_v_row < params.seq_len_kv) {
                let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
                let base_idx = global_block_idx * F16_PER_BLOCK;
                let d = V[base_idx];
                for (var j = 0u; j < F16_PER_THREAD; j += 2) {
                    let q_0 = V[base_idx + 1u + block_offset + j];
                    let q_1 = V[base_idx + 1u + block_offset + j + 1];
                    let q_packed = bitcast<u32>(vec2(q_0, q_1));
                    for (var k = 0u; k < 4u; k++) {
                        let q_byte = get_byte_i32(q_packed, k);
                        let q_val = f16(q_byte) * d;
                        let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
                        kv_shmem[row_offset + idx] = q_val;
                    }
                }
            }
        }
#elif defined(KV_DIRECT)
        // Direct global loads for KV
#else
        for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) {
            let v_row = elem_idx / HEAD_DIM_V;
            let v_col = elem_idx % HEAD_DIM_V;
            let global_v_row = kv_tile + v_row;
            let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
            let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V;
            let vec_idx = (global_v_row_offset + v_col) >> 2u;
            let v4 = select(vec4<KV_TYPE>(0.0), V[vec_idx], in_bounds);
            kv_shmem[elem_idx + 0u] = f16(v4.x);
            kv_shmem[elem_idx + 1u] = f16(v4.y);
            kv_shmem[elem_idx + 2u] = f16(v4.z);
            kv_shmem[elem_idx + 3u] = f16(v4.w);
        }
#endif

        workgroupBarrier();

        if (!skip_tile) {
            // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
            // we want to compute O += P * V across the full KV tile
            let ne_threads : u32 = VEC_NE;
            let nl_threads = max(1u, subgroup_size / ne_threads);
            let tx_pv = sg_inv_id % nl_threads;
            let ty_pv = sg_inv_id / nl_threads;
            for (var q_tile_row = subgroup_id;
                 q_tile_row < Q_TILE;
                 q_tile_row += num_subgroups) {
                for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) {
                    var lo = vec4<f32>(0.0, 0.0, 0.0, 0.0);
                    for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) {
                        let kv_idx = cc * ne_threads + ty_pv;
                        let v_row = kv_tile + kv_idx;
                        if (v_row >= params.seq_len_kv) {
                            continue;
                        }

                        let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]);
#ifdef KV_DIRECT
                        let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u;
                        let v4 = vec4<f32>(V[v_idx >> 2u]);
#else
                        let v_idx = kv_idx * HEAD_DIM_V + vec_col * 4u;
                        let v4 = vec4<f32>(
                            f32(kv_shmem[v_idx + 0u]),
                            f32(kv_shmem[v_idx + 1u]),
                            f32(kv_shmem[v_idx + 2u]),
                            f32(kv_shmem[v_idx + 3u]));
#endif
                        lo += p * v4;
                    }

                    var lo_x = lo.x;
                    var lo_y = lo.y;
                    var lo_z = lo.z;
                    var lo_w = lo.w;
                    // Reduce over ty threads (NE) for this tx thread.
                    var ty_delta = ne_threads >> 1u;
                    loop {
                        if (ty_delta == 0u) {
                            break;
                        }
                        let thread_delta = ty_delta * nl_threads;
                        let shx = subgroupShuffleDown(lo_x, thread_delta);
                        let shy = subgroupShuffleDown(lo_y, thread_delta);
                        let shz = subgroupShuffleDown(lo_z, thread_delta);
                        let shw = subgroupShuffleDown(lo_w, thread_delta);
                        if (ty_pv < ty_delta) {
                            lo_x += shx;
                            lo_y += shy;
                            lo_z += shz;
                            lo_w += shw;
                        }
                        ty_delta >>= 1u;
                    }

                    if (ty_pv == 0u) {
                        let elem_base = vec_col * 4u;
                        let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base;
                        o_shmem[o_base_idx + 0u] = f16(f32(o_shmem[o_base_idx + 0u]) + lo_x);
                        o_shmem[o_base_idx + 1u] = f16(f32(o_shmem[o_base_idx + 1u]) + lo_y);
                        o_shmem[o_base_idx + 2u] = f16(f32(o_shmem[o_base_idx + 2u]) + lo_z);
                        o_shmem[o_base_idx + 3u] = f16(f32(o_shmem[o_base_idx + 3u]) + lo_w);
                    }
                }
            }
        }

        workgroupBarrier();
    }


#ifdef SINKS
    // Sinks are global terms and must be applied exactly once across split workgroups.
    if (iwg == 0u) {
        for (var q_tile_row = subgroup_id;
             q_tile_row < Q_TILE;
             q_tile_row += num_subgroups) {
            let global_q_row = q_row_start + q_tile_row;
            if (global_q_row >= params.seq_len_q) {
                break;
            }

            var prev_max = row_max_shmem[q_tile_row];

            // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
            let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);
            let new_max = subgroupMax(max(prev_max, sink_val));
            let max_exp = exp(prev_max - new_max);
            let sink_exp = exp(sink_val - new_max);

            let sink_exp_sum = subgroupAdd(sink_exp);

            if (sg_inv_id == 0) {
                row_max_shmem[q_tile_row] = new_max;
                exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;
            }

            for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
                let idx = q_tile_row * HEAD_DIM_V + elem_idx;
                o_shmem[idx] = f16(f32(o_shmem[idx]) * max_exp);
            }
        }
        workgroupBarrier();
    }
#endif
    let rows_per_batch = params.n_heads * params.seq_len_q;
    for (var q_tile_row = subgroup_id;
         q_tile_row < Q_TILE;
         q_tile_row += num_subgroups) {

        let global_q_row = q_row_start + q_tile_row;
        if (global_q_row >= params.seq_len_q) { break; }

        if (params.nwg == 1u) {
            let exp_sum = exp_sum_shmem[q_tile_row];
            let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
            let row_base: u32 =
                params.offset_dst + batch_idx * dst3_stride + global_q_row * dst2_stride + head_idx * HEAD_DIM_V;

            for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) {
                let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
                let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
                let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
                let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);

                let v = vec4<f32>(
                    f32(o_shmem[i0]) * scale,
                    f32(o_shmem[i1]) * scale,
                    f32(o_shmem[i2]) * scale,
                    f32(o_shmem[i3]) * scale
                );

                let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
                dst[dst_vec_index] = v;
            }
        } else {
            let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row;
            let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V;
            let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg;

            for (var elem_base = sg_inv_id * 4u;
                 elem_base < HEAD_DIM_V;
                 elem_base += subgroup_size * 4u) {

                let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
                let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
                let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
                let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);

                let tbase = tmp_row_data_base + elem_base;
                tmp[tbase + 0u] = f32(o_shmem[i0]);
                tmp[tbase + 1u] = f32(o_shmem[i1]);
                tmp[tbase + 2u] = f32(o_shmem[i2]);
                tmp[tbase + 3u] = f32(o_shmem[i3]);
            }

            if (sg_inv_id == 0u) {
                tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row];
                tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row];
            }
        }
    }
}
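
// Note (editor): when params.nwg > 1 this kernel runs as a split-KV pass: each
// of the nwg workgroups covers an interleaved subset of KV tiles and writes its
// unnormalized output row plus (exp_sum, row_max) stats into tmp; the
// flash_attn_vec_reduce shader above then merges the partials. With nwg == 1
// the normalized result is written directly to dst.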
@@ -1,323 +0,0 @@
#define(VARIANTS)

[
  {
    "SHADER_NAME": "reglu_f32",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["NO_SPLIT", "REGLU"]
  },
  {
    "SHADER_NAME": "reglu_f32_split",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["SPLIT", "REGLU"]
  },
  {
    "SHADER_NAME": "reglu_f16",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["NO_SPLIT", "REGLU"]
  },
  {
    "SHADER_NAME": "reglu_f16_split",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["SPLIT", "REGLU"]
  },
  {
    "SHADER_NAME": "geglu_f32",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["NO_SPLIT", "GEGLU"]
  },
  {
    "SHADER_NAME": "geglu_f32_split",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["SPLIT", "GEGLU"]
  },
  {
    "SHADER_NAME": "geglu_f16",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["NO_SPLIT", "GEGLU"]
  },
  {
    "SHADER_NAME": "geglu_f16_split",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["SPLIT", "GEGLU"]
  },
  {
    "SHADER_NAME": "swiglu_f32",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["NO_SPLIT", "SWIGLU"]
  },
  {
    "SHADER_NAME": "swiglu_f32_split",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["SPLIT", "SWIGLU"]
  },
  {
    "SHADER_NAME": "swiglu_f16",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["NO_SPLIT", "SWIGLU"]
  },
  {
    "SHADER_NAME": "swiglu_f16_split",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["SPLIT", "SWIGLU"]
  },
  {
    "SHADER_NAME": "swiglu_oai_f32",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["NO_SPLIT", "SWIGLU_OAI"]
  },
  {
    "SHADER_NAME": "swiglu_oai_f32_split",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["SPLIT", "SWIGLU_OAI"]
  },
  {
    "SHADER_NAME": "geglu_erf_f32",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["NO_SPLIT", "GEGLU_ERF"]
  },
  {
    "SHADER_NAME": "geglu_erf_f32_split",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["SPLIT", "GEGLU_ERF"]
  },
  {
    "SHADER_NAME": "geglu_erf_f16",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["NO_SPLIT", "GEGLU_ERF"]
  },
  {
    "SHADER_NAME": "geglu_erf_f16_split",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["SPLIT", "GEGLU_ERF"]
  },
  {
    "SHADER_NAME": "geglu_quick_f32",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
  },
  {
    "SHADER_NAME": "geglu_quick_f32_split",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["SPLIT", "GEGLU_QUICK"]
  },
  {
    "SHADER_NAME": "geglu_quick_f16",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
  },
  {
    "SHADER_NAME": "geglu_quick_f16_split",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["SPLIT", "GEGLU_QUICK"]
  },
]

#end(VARIANTS)

#define(DECLS)

#decl(REGLU)
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
    return max(a, 0) * b;
}
#enddecl(REGLU)

#decl(GEGLU)
const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876;
const GELU_COEF_A: {{TYPE}} = 0.044715;

fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
    let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
    return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b;
}
#enddecl(GEGLU)
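The GEGLU body above is the usual tanh-based GELU approximation with the tanh expanded via tanh(v) = 1 - 2/(e^{2v} + 1), which avoids a direct tanh call:

0.5\,a\left(2 - \frac{2}{e^{2v} + 1}\right) b = 0.5\,a\,(1 + \tanh v)\,b,
\qquad v = \sqrt{2/\pi}\;a\,(1 + 0.044715\,a^2)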
#decl(SWIGLU)
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
    return a / (1.0 + exp(-a)) * b;
}
#enddecl(SWIGLU)

#decl(SWIGLU_OAI)
fn op(a: f32, b: f32) -> f32 {
    let xi = min(a, params.limit);
    let gi = max(min(b, params.limit), -params.limit);
    var out_glu = xi / (1.0 + exp(-xi * params.alpha));
    out_glu = out_glu * (1.0 + gi);
    return out_glu;
}
#enddecl(SWIGLU_OAI)
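In math form, the swiglu_oai op clamps both halves to params.limit and gates a scaled sigmoid (L = limit, alpha = sigmoid scale):

x' = \min(x, L),\qquad g' = \operatorname{clamp}(g, -L, L),\qquad
\text{out} = \frac{x'}{1 + e^{-\alpha x'}}\,(1 + g')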
#decl(GEGLU_ERF)
const p_erf: {{TYPE}} = 0.3275911;
const a1_erf: {{TYPE}} = 0.254829592;
const a2_erf: {{TYPE}} = -0.284496736;
const a3_erf: {{TYPE}} = 1.421413741;
const a4_erf: {{TYPE}} = -1.453152027;
const a5_erf: {{TYPE}} = 1.061405429;
const SQRT_2_INV: {{TYPE}} = 0.7071067811865476;

fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
    let a_div_sqr2 = a * SQRT_2_INV;
    let sign_x = sign(a_div_sqr2);
    let x = abs(a_div_sqr2);
    let t = 1.0 / (1.0 + p_erf * x);
    let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
    let erf_approx = sign_x * y;
    return 0.5 * a * (1.0 + erf_approx) * b;
}
#enddecl(GEGLU_ERF)
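The constants above are the Abramowitz–Stegun 7.1.26 polynomial approximation of erf, so the op computes GELU(a) · b with GELU(a) = 0.5 a (1 + erf(a/√2)):

\operatorname{erf}(x) \approx 1 - \left(a_1 t + a_2 t^2 + a_3 t^3 + a_4 t^4 + a_5 t^5\right) e^{-x^2},
\qquad t = \frac{1}{1 + p\,x},\ x \ge 0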
#decl(GEGLU_QUICK)
const GELU_QUICK_COEF: {{TYPE}} = -1.702;

fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
    return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;
}
#enddecl(GEGLU_QUICK)

#decl(NO_SPLIT)
@group(0) @binding(1)
var<storage, read_write> dst: array<{{TYPE}}>;

@group(0) @binding(2)
var<uniform> params: Params;

fn a_value(base: u32) -> {{TYPE}} {
    let offset: u32 = select(0, params.ne0, params.swapped != 0);
    return src0[base + offset];
}

fn b_value(base: u32) -> {{TYPE}} {
    let offset: u32 = select(params.ne0, 0, params.swapped != 0);
    return src0[base + offset];
}
#enddecl(NO_SPLIT)

#decl(SPLIT)
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;

@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;

@group(0) @binding(3)
var<uniform> params: Params;

fn a_value(base: u32) -> {{TYPE}} {
    return src0[base];
}

fn b_value(base: u32) -> {{TYPE}} {
    return src1[base];
}
#enddecl(SPLIT)

#end(DECLS)

#define(SHADER)

enable f16;

struct Params {
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,

    // Strides (in elements)
    stride_src01: u32,
    stride_src02: u32,
    stride_src03: u32,

    stride_src11: u32,
    stride_src12: u32,
    stride_src13: u32,

    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // shape of dst
    ne: u32,
    ne0: u32,
    ne1: u32,
    ne2: u32,

    swapped: u32,
    alpha: f32,
    limit: f32,
}

@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;

DECLS

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x >= params.ne) {
        return;
    }

    var i = gid.x;
    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
    i = i % (params.ne2 * params.ne1 * params.ne0);
    let i2 = i / (params.ne1 * params.ne0);
    i = i % (params.ne1 * params.ne0);
    let i1 = i / params.ne0;
    let i0 = i % params.ne0;

    let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
    let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
    let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;

    dst[i_dst] = op(a_value(i_a), b_value(i_b));
}

#end(SHADER)
155	ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl	Normal file
@@ -0,0 +1,155 @@
enable f16;

#ifdef TYPE_F32
#define DataType f32
#endif
#ifdef TYPE_F16
#define DataType f16
#endif

#ifdef OP_REGLU
fn op(a: DataType, b: DataType) -> DataType {
    return max(a, 0) * b;
}
#endif

#ifdef OP_GEGLU
const SQRT_2_OVER_PI: DataType = 0.79788456080286535587989211986876;
const GELU_COEF_A: DataType = 0.044715;

fn op(a: DataType, b: DataType) -> DataType {
    let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
    return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b;
}
#endif

#ifdef OP_SWIGLU
fn op(a: DataType, b: DataType) -> DataType {
    return a / (1.0 + exp(-a)) * b;
}
#endif
#ifdef OP_SWIGLU_OAI
fn op(a: f32, b: f32) -> f32 {
    let xi = min(a, params.limit);
    let gi = max(min(b, params.limit), -params.limit);
    var out_glu = xi / (1.0 + exp(-xi * params.alpha));
    out_glu = out_glu * (1.0 + gi);
    return out_glu;
}
#endif
#ifdef OP_GEGLU_ERF
const p_erf: DataType = 0.3275911;
const a1_erf: DataType = 0.254829592;
const a2_erf: DataType = -0.284496736;
const a3_erf: DataType = 1.421413741;
const a4_erf: DataType = -1.453152027;
const a5_erf: DataType = 1.061405429;
const SQRT_2_INV: DataType = 0.7071067811865476;

fn op(a: DataType, b: DataType) -> DataType {
    let a_div_sqr2 = a * SQRT_2_INV;
    let sign_x = sign(a_div_sqr2);
    let x = abs(a_div_sqr2);
    let t = 1.0 / (1.0 + p_erf * x);
    let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
    let erf_approx = sign_x * y;
    return 0.5 * a * (1.0 + erf_approx) * b;
}
#endif
#ifdef OP_GEGLU_QUICK
const GELU_QUICK_COEF: DataType = -1.702;

fn op(a: DataType, b: DataType) -> DataType {
    return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;
}
#endif

struct Params {
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,

    // Strides (in elements)
    stride_src01: u32,
    stride_src02: u32,
    stride_src03: u32,

    stride_src11: u32,
    stride_src12: u32,
    stride_src13: u32,

    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // shape of dst
    ne: u32,
    ne0: u32,
    ne1: u32,
    ne2: u32,

    swapped: u32,
    alpha: f32,
    limit: f32,
}

@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;

#ifdef NO_SPLIT
@group(0) @binding(1)
var<storage, read_write> dst: array<DataType>;

@group(0) @binding(2)
var<uniform> params: Params;

fn a_value(base: u32) -> DataType {
    let offset: u32 = select(0, params.ne0, params.swapped != 0);
    return src0[base + offset];
}

fn b_value(base: u32) -> DataType {
    let offset: u32 = select(params.ne0, 0, params.swapped != 0);
    return src0[base + offset];
}

#else
@group(0) @binding(1)
var<storage, read_write> src1: array<DataType>;

@group(0) @binding(2)
var<storage, read_write> dst: array<DataType>;

@group(0) @binding(3)
var<uniform> params: Params;

fn a_value(base: u32) -> DataType {
    return src0[base];
}

fn b_value(base: u32) -> DataType {
    return src1[base];
}

#endif

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x >= params.ne) {
        return;
    }

    var i = gid.x;
    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
    i = i % (params.ne2 * params.ne1 * params.ne0);
    let i2 = i / (params.ne1 * params.ne0);
    i = i % (params.ne1 * params.ne0);
    let i1 = i / params.ne0;
    let i0 = i % params.ne0;

    let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
    let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
    let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;

    dst[i_dst] = op(a_value(i_a), b_value(i_b));
}
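The rewritten file relies on a preprocessor pass instead of the old VARIANTS/DECLS templating, so presumably the shader build script prepends a set of defines per variant. A sketch of the preamble that would reproduce the former swiglu_f32 pipeline; the exact define-injection mechanism and the WG_SIZE value are assumptions, only the tested names are taken from the file above:

// Hypothetical per-variant preamble prepended at build time:
#define TYPE_F32      // selects DataType = f32
#define OP_SWIGLU     // selects the SwiGLU op()
#define NO_SPLIT      // gate and value interleaved in src0
#define WG_SIZE 256   // workgroup size, previously a pipeline override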
@@ -61,10 +61,10 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3

 #ifdef INIT_SRC0_SHMEM_Q4_0
 const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 18u;
 // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
 override BLOCKS_K = TILE_K/BLOCK_SIZE;
 const NQ = 16u;
-const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
 const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
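For reference, the new BLOCK_SIZE_BYTES values match the ggml block layouts, scale field(s) first, then packed weights:

// block_q4_0: 18 = 2 (f16 d)              + 16 (32 x 4-bit weights)
// block_q4_1: 20 = 2 (d) + 2 (f16 m)      + 16
// block_q5_0: 22 = 2 (d) + 4 (qh)         + 16
// block_q5_1: 24 = 2 (d) + 2 (m) + 4 (qh) + 16
// block_q8_0: 34 = 2 (d)                  + 32 (32 x 8-bit weights)
// block_q8_1: 36 = 2 (d) + 2 (s)          + 32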
@@ -81,14 +81,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3

     if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
         let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
-        let scale_idx = src0_idx * F16_PER_BLOCK;
-        let d = src0[scale_idx];
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+        let d = load_src0_f16_at(block_byte_base);

         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-            let q_0 = src0[scale_idx + 1u + block_offset + j];
-            let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
-
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+            let q_packed = load_src0_u32_at(q_byte_offset);
             for (var k = 0u; k < 4u; k++) {
                 let q_byte = get_byte(q_packed, k);
                 let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
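`load_src0_f16_at` / `load_src0_u32_at` are defined outside the visible hunks. A plausible sketch consistent with the call sites (including the unaligned byte offsets used by the Q6_K path further down), assuming src0 is bound as array<u32> for these kernels; both the binding type and the bodies are assumptions:

// Assumed helpers, not part of this commit's visible diff.
fn load_src0_u32_at(byte_offset: u32) -> u32 {
    let word = byte_offset / 4u;
    let shift = (byte_offset % 4u) * 8u;
    if (shift == 0u) {
        return src0[word];
    }
    // stitch together the two words straddling an unaligned offset
    return (src0[word] >> shift) | (src0[word + 1u] << (32u - shift));
}

fn load_src0_f16_at(byte_offset: u32) -> f16 {
    // all f16 loads in these kernels use 2-byte-aligned offsets
    let halves = bitcast<vec2<f16>>(src0[byte_offset / 4u]);
    return halves[(byte_offset % 4u) / 2u];
}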
@@ -104,10 +102,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3

 #ifdef INIT_SRC0_SHMEM_Q4_1
 const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 20u;
 // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
 override BLOCKS_K = TILE_K/BLOCK_SIZE;
 const NQ = 16u;
-const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean
 const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@@ -124,15 +122,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3

     if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
         let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
-        let scale_idx = src0_idx * F16_PER_BLOCK;
-        let d = src0[scale_idx];
-        let m = src0[scale_idx + 1u];
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+        let d = load_src0_f16_at(block_byte_base);
+        let m = load_src0_f16_at(block_byte_base + 2u);

         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-            let q_0 = src0[scale_idx + 2u + block_offset + j];
-            let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
-
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
+            let q_packed = load_src0_u32_at(q_byte_offset);
             for (var k = 0u; k < 4u; k++) {
                 let q_byte = get_byte(q_packed, k);
                 let q_lo = f16(q_byte & 0xF) * d + m;
@@ -149,11 +145,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 #ifdef INIT_SRC0_SHMEM_Q5_0
 // 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
 const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 22u;
 // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
 // tile_k is defined as 32u, so blocks_k ends up being 1 always
 override BLOCKS_K = TILE_K / BLOCK_SIZE;
 const NQ = 16u;
-const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights
 const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
@@ -171,18 +167,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3

     if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
         let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
-        let scale_idx = src0_idx * F16_PER_BLOCK;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;

-        let d = src0[scale_idx];
-        let qh0 = src0[scale_idx + 1u];
-        let qh1 = src0[scale_idx + 2u];
-        let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+        let d = load_src0_f16_at(block_byte_base);
+        let qh_packed = load_src0_u32_at(block_byte_base + 2u);

         for (var j = 0u; j < 2; j++) {
-            let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
-            let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
-
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
+            let q_packed = load_src0_u32_at(q_byte_offset);

             let j_adjusted = j + (block_offset / 2u);
@@ -207,11 +199,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 #ifdef INIT_SRC0_SHMEM_Q5_1
 // 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
 const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 24u;
 // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
 // tile_k is defined as 32u, so blocks_k ends up being 1 always
 override BLOCKS_K = TILE_K / BLOCK_SIZE;
 const NQ = 16u;
-const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean
 const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
@@ -229,20 +221,16 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3

     if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
         let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
-        let scale_idx = src0_idx * F16_PER_BLOCK;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;

-        let d = src0[scale_idx];
-        let m = src0[scale_idx + 1u];
-        let qh0 = src0[scale_idx + 2u];
-        let qh1 = src0[scale_idx + 3u];
-        let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+        let d = load_src0_f16_at(block_byte_base);
+        let m = load_src0_f16_at(block_byte_base + 2u);
+        let qh_packed = load_src0_u32_at(block_byte_base + 4u);

         for (var j = 0u; j < 2; j++) {

-            let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
-            let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
-
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
+            let q_packed = load_src0_u32_at(q_byte_offset);

             let j_adjusted = j + (block_offset / 2u);
@@ -266,10 +254,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3

 #ifdef INIT_SRC0_SHMEM_Q8_0
 const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 34u;
 // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
 override BLOCKS_K = TILE_K/BLOCK_SIZE;
 const NQ = 16u;
-const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights
 const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
@@ -286,14 +274,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3

     if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
         let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
-        let scale_idx = src0_idx * F16_PER_BLOCK;
-        let d = src0[scale_idx];
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+        let d = load_src0_f16_at(block_byte_base);

         for (var j = 0u; j < F16_PER_THREAD; j+=2) {
-            let q_0 = src0[scale_idx + 1u + block_offset + j];
-            let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
-
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+            let q_packed = load_src0_u32_at(q_byte_offset);
             for (var k = 0u; k < 4u; k++) {
                 let q_byte = get_byte_i32(q_packed, k);
@@ -308,10 +294,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3

 #ifdef INIT_SRC0_SHMEM_Q8_1
 const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 36u;
 // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
 override BLOCKS_K = TILE_K/BLOCK_SIZE;
 const NQ = 16u;
-const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights
 const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
@@ -328,15 +314,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3

     if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
         let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
-        let scale_idx = src0_idx * F16_PER_BLOCK;
-        let d = src0[scale_idx];
-        let m = src0[scale_idx + 1u];
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+        let d = load_src0_f16_at(block_byte_base);
+        let m = load_src0_f16_at(block_byte_base + 2u);

         for (var j = 0u; j < F16_PER_THREAD; j+=2) {
-            let q_0 = src0[scale_idx + 2u + block_offset + j];
-            let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
-
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
+            let q_packed = load_src0_u32_at(q_byte_offset);
             for (var k = 0u; k < 4u; k++) {
                 let q_byte = get_byte_i32(q_packed, k);
@@ -351,7 +335,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3

 #ifdef INIT_SRC0_SHMEM_Q2_K
 const BLOCK_SIZE = 256u;
-const F16_PER_BLOCK = 42u;
+const BLOCK_SIZE_BYTES = 84u;

 fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
     // Use standard thread layout instead of lane/row_group
@@ -371,10 +355,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let k_in_block = global_k % BLOCK_SIZE;

         let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
-        let scale_idx = src0_idx * F16_PER_BLOCK;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;

-        let d = src0[scale_idx + 40u];
-        let dmin = src0[scale_idx + 41u];
+        let d = load_src0_f16_at(block_byte_base + 80u);
+        let dmin = load_src0_f16_at(block_byte_base + 82u);

         // Decode the element at position k_in_block
         let block_of_32 = k_in_block / 32u;
@@ -387,18 +371,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3

         let is = k_in_block / 16u;

-        let sc_0 = src0[scale_idx + 2u * (is / 4u)];
-        let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];
-        let sc_packed = bitcast<u32>(vec2(sc_0, sc_1));
+        let sc_packed = load_src0_u32_at(block_byte_base + 4u * (is / 4u));
         let sc = get_byte(sc_packed, is % 4u);

         let dl = d * f16(sc & 0xFu);
         let ml = dmin * f16(sc >> 4u);

         let q_idx = q_b_idx + k + l;
-        let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
-        let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
-        let q_packed = bitcast<u32>(vec2(q_0, q_1));
+        let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
         let q_byte = get_byte(q_packed, q_idx % 4u);
         let qs_val = (q_byte >> shift) & 3u;
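`get_byte` / `get_byte_i32`, used throughout these hunks, are likewise defined outside the diff; their expected behavior, with `get_byte_i32` sign-extending for the signed 8-bit quantizations, is presumably:

// Assumed byte-extraction helpers (mirroring byte_of/sbyte_of further below).
fn get_byte(v: u32, b: u32) -> u32 {
    return (v >> (b * 8u)) & 0xFFu;
}

fn get_byte_i32(v: u32, b: u32) -> i32 {
    // shift the selected byte to the top, then arithmetic-shift to sign-extend
    return (i32(v) << ((3u - b) * 8u)) >> 24u;
}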
@@ -410,7 +390,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3

 #ifdef INIT_SRC0_SHMEM_Q3_K
 const BLOCK_SIZE = 256u;
-const F16_PER_BLOCK = 55u;
+const BLOCK_SIZE_BYTES = 110u;

 fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
     for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@@ -429,9 +409,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let k_in_block = global_k % BLOCK_SIZE;

         let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
-        let scale_idx = src0_idx * F16_PER_BLOCK;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;

-        let d = src0[scale_idx + 54u];
+        let d = load_src0_f16_at(block_byte_base + 108u);

         // Load and unpack scales
         let kmask1: u32 = 0x03030303u;
@@ -439,9 +419,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3

         var scale_vals: array<u32, 4>;
         for (var i: u32 = 0u; i < 4u; i++) {
-            let scale_0 = src0[scale_idx + 48u + (2u*i)];
-            let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];
-            scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
+            scale_vals[i] = load_src0_u32_at(block_byte_base + 96u + 4u * i);
         }

         var tmp: u32 = scale_vals[2];
@@ -453,16 +431,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         // Load hmask and qs arrays
         var hmask_vals: array<u32, 8>;
         for (var i: u32 = 0u; i < 8u; i++) {
-            let hmask_0 = src0[scale_idx + (2u*i)];
-            let hmask_1 = src0[scale_idx + (2u*i) + 1u];
-            hmask_vals[i] = bitcast<u32>(vec2(hmask_0, hmask_1));
+            hmask_vals[i] = load_src0_u32_at(block_byte_base + 4u * i);
         }

         var qs_vals: array<u32, 16>;
         for (var i: u32 = 0u; i < 16u; i++) {
-            let qs_0 = src0[scale_idx + 16u + (2u*i)];
-            let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];
-            qs_vals[i] = bitcast<u32>(vec2(qs_0, qs_1));
+            qs_vals[i] = load_src0_u32_at(block_byte_base + 32u + 4u * i);
         }

         let half = k_in_block / 128u; // 0 or 1
@@ -502,7 +476,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3

 #ifdef INIT_SRC0_SHMEM_Q4_K
 const BLOCK_SIZE = 256u;
-const F16_PER_BLOCK = 72u;
+const BLOCK_SIZE_BYTES = 144u;

 fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
     for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@@ -521,17 +495,15 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let k_in_block = global_k % BLOCK_SIZE;

         let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
-        let scale_idx = src0_idx * F16_PER_BLOCK;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;

-        let d = src0[scale_idx];
-        let dmin = src0[scale_idx + 1u];
+        let d = load_src0_f16_at(block_byte_base);
+        let dmin = load_src0_f16_at(block_byte_base + 2u);

         // Load packed scales
         var scale_vals: array<u32, 3>;
         for (var i: u32 = 0u; i < 3u; i++) {
-            let scale_0 = src0[scale_idx + 2u + (2u*i)];
-            let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
-            scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
+            scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
         }

         // Map k_in_block to loop structure:
@@ -567,9 +539,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
             let ml = dmin * f16(mn);

             let q_idx = q_b_idx + l;
-            let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
-            let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));

             let q_byte = get_byte(q_packed, q_idx % 4u);
             let qs_val = (q_byte >> shift) & 0xFu;
@@ -582,7 +552,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3

 #ifdef INIT_SRC0_SHMEM_Q5_K
 const BLOCK_SIZE = 256u;
-const F16_PER_BLOCK = 88u;
+const BLOCK_SIZE_BYTES = 176u;

 fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
     for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@@ -601,17 +571,15 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let k_in_block = global_k % BLOCK_SIZE;

         let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
-        let scale_idx = src0_idx * F16_PER_BLOCK;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;

-        let d = src0[scale_idx];
-        let dmin = src0[scale_idx + 1u];
+        let d = load_src0_f16_at(block_byte_base);
+        let dmin = load_src0_f16_at(block_byte_base + 2u);

         // Load packed scales
         var scale_vals: array<u32, 3>;
         for (var i: u32 = 0u; i < 3u; i++) {
-            let scale_0 = src0[scale_idx + 2u + (2u*i)];
-            let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
-            scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
+            scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
         }

         // The original loop processes elements in groups of 64
@@ -651,15 +619,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
             let ml = dmin * f16(mn);

             let q_idx = q_b_idx + l;
-            let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)];
-            let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_packed = load_src0_u32_at(block_byte_base + 48u + 4u * (q_idx / 4u));

             let q_byte = get_byte(q_packed, q_idx % 4u);

-            let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)];
-            let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];
-            let qh_packed = bitcast<u32>(vec2(qh_0, qh_1));
+            let qh_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (l / 4u));

             let qh_byte = get_byte(qh_packed, l % 4u);
@@ -675,7 +639,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3

 #ifdef INIT_SRC0_SHMEM_Q6_K
 const BLOCK_SIZE = 256u;
-const F16_PER_BLOCK = 105u;
+const BLOCK_SIZE_BYTES = 210u;

 fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
     for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@@ -694,7 +658,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let k_in_block = global_k % BLOCK_SIZE;

         let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
-        let scale_idx = src0_idx * F16_PER_BLOCK;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;

         let half = k_in_block / 128u;
         let pos_in_half = k_in_block % 128u;
@@ -707,30 +671,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3

         // Load only ql13 word needed
         let ql13_flat = ql_b_idx + l;
-        let ql13_word = ql13_flat / 4u;
-        let ql13 = bitcast<u32>(vec2(
-            src0[scale_idx + 2u * ql13_word],
-            src0[scale_idx + 2u * ql13_word + 1u]
-        ));
-        let ql13_b = get_byte(ql13, ql13_flat % 4u);
+        let ql13 = load_src0_u32_at(block_byte_base + ql13_flat);
+        let ql13_b = get_byte(ql13, 0u);

         // Load only ql24 word needed
         let ql24_flat = ql_b_idx + l + 32u;
-        let ql24_word = ql24_flat / 4u;
-        let ql24 = bitcast<u32>(vec2(
-            src0[scale_idx + 2u * ql24_word],
-            src0[scale_idx + 2u * ql24_word + 1u]
-        ));
-        let ql24_b = get_byte(ql24, ql24_flat % 4u);
+        let ql24 = load_src0_u32_at(block_byte_base + ql24_flat);
+        let ql24_b = get_byte(ql24, 0u);

         // Load only qh word needed
         let qh_flat = qh_b_idx + l;
-        let qh_word = qh_flat / 4u;
-        let qh = bitcast<u32>(vec2(
-            src0[scale_idx + 64u + 2u * qh_word],
-            src0[scale_idx + 64u + 2u * qh_word + 1u]
-        ));
-        let qh_b = get_byte(qh, qh_flat % 4u);
+        let qh = load_src0_u32_at(block_byte_base + 128u + qh_flat);
+        let qh_b = get_byte(qh, 0u);

         let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
         let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
@@ -740,14 +692,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
            // Load only the scale word needed
            let is = l / 16u;
            let sc_idx = sc_b_idx + is + quarter * 2u;
-           let sc_word = sc_idx / 4u;
-           let sc = bitcast<u32>(vec2(
-               src0[scale_idx + 96u + 2u * sc_word],
-               src0[scale_idx + 96u + 2u * sc_word + 1u]
-           ));
-           let sc_val = get_byte_i32(sc, sc_idx % 4u);
+           let sc = load_src0_u32_at(block_byte_base + 192u + sc_idx);
+           let sc_val = get_byte_i32(sc, 0u);

-           let d = src0[scale_idx + 104u];
+           let d = load_src0_f16_at(block_byte_base + 208u);

            var q_val: f16;
            if (quarter == 0u) {
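The Q6_K byte offsets used above (128u, 192u, 208u) follow directly from the block layout:

// block_q6_K: 210 bytes per 256 weights
//   bytes   0..127 : ql     - low 4 bits of each weight
//   bytes 128..191 : qh     - high 2 bits of each weight
//   bytes 192..207 : scales - 16 signed 8-bit sub-block scales
//   bytes 208..209 : d      - f16 super-block scale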
@@ -52,8 +52,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
 #ifdef MUL_ACC_Q4_0

 const BLOCK_SIZE = 32;
+const BLOCK_SIZE_BYTES = 18u;
 const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
 const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@@ -62,14 +62,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
     for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
         let blck_idx = i / BLOCK_SIZE;
         let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(src0[scale_idx]);
+        let d = f32(load_src0_f16_at(block_byte_base));
         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-            let q_0 = src0[scale_idx + 1 + block_offset + j];
-            let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+            let q_packed = load_src0_u32_at(q_byte_offset);
             for (var k: u32 = 0; k < 4; k++) {
                 let q_byte = get_byte(q_packed, k);
                 let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -86,8 +85,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
 #ifdef MUL_ACC_Q4_1

 const BLOCK_SIZE = 32;
+const BLOCK_SIZE_BYTES = 20u;
 const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 10u;
 const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@@ -96,15 +95,14 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
     for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
         let blck_idx = i / BLOCK_SIZE;
         let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(src0[scale_idx]);
-        let m = f32(src0[scale_idx + 1u]);
+        let d = f32(load_src0_f16_at(block_byte_base));
+        let m = f32(load_src0_f16_at(block_byte_base + 2u));
         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-            let q_0 = src0[scale_idx + 2u + block_offset + j];
-            let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
+            let q_packed = load_src0_u32_at(q_byte_offset);
             for (var k: u32 = 0; k < 4; k++) {
                 let q_byte = get_byte(q_packed, k);
                 let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
@@ -121,8 +119,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
 #ifdef MUL_ACC_Q5_0

 const BLOCK_SIZE = 32;
+const BLOCK_SIZE_BYTES = 22u;
 const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 11u;
 const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@@ -131,18 +129,15 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
     for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
         let blck_idx = i / BLOCK_SIZE;
         let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(src0[scale_idx]);
-        let qh0 = src0[scale_idx + 1u];
-        let qh1 = src0[scale_idx + 2u];
-        let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+        let d = f32(load_src0_f16_at(block_byte_base));
+        let qh_packed = load_src0_u32_at(block_byte_base + 2u);

         for (var j = 0u; j < 2; j++) {
-            let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
-            let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
+            let q_packed = load_src0_u32_at(q_byte_offset);

             let j_adjusted = j + (block_offset / 2u);
@@ -168,8 +163,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
 #ifdef MUL_ACC_Q5_1

 const BLOCK_SIZE = 32;
+const BLOCK_SIZE_BYTES = 24u;
 const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 12u;
 const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@@ -178,19 +173,16 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
     for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
         let blck_idx = i / BLOCK_SIZE;
         let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(src0[scale_idx]);
-        let m = src0[scale_idx + 1u];
-        let qh0 = src0[scale_idx + 2u];
-        let qh1 = src0[scale_idx + 3u];
-        let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+        let d = f32(load_src0_f16_at(block_byte_base));
+        let m = load_src0_f16_at(block_byte_base + 2u);
+        let qh_packed = load_src0_u32_at(block_byte_base + 4u);

         for (var j = 0u; j < 2; j++) {
-            let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
-            let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
+            let q_packed = load_src0_u32_at(q_byte_offset);

             let j_adjusted = j + (block_offset / 2u);
@@ -216,8 +208,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
 #ifdef MUL_ACC_Q8_0

 const BLOCK_SIZE = 32;
+const BLOCK_SIZE_BYTES = 34u;
 const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 17u;
 const WEIGHTS_PER_F16 = 2u;
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@@ -226,15 +218,14 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
     for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
         let blck_idx = i / BLOCK_SIZE;
         let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(src0[scale_idx]);
+        let d = f32(load_src0_f16_at(block_byte_base));

         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-            let q_0 = src0[scale_idx + 1 + block_offset + j];
-            let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+            let q_packed = load_src0_u32_at(q_byte_offset);
             for (var k: u32 = 0; k < 4; k++) {
                 let q_byte = get_byte_i32(q_packed, k);
                 let q_val = f32(q_byte) * d;
@@ -250,8 +241,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
 #ifdef MUL_ACC_Q8_1

 const BLOCK_SIZE = 32;
+const BLOCK_SIZE_BYTES = 36u;
 const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 18u;
 const WEIGHTS_PER_F16 = 2u;
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@@ -260,16 +251,15 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
     for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
         let blck_idx = i / BLOCK_SIZE;
         let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(src0[scale_idx]);
-        let m = src0[scale_idx + 1u];
+        let d = f32(load_src0_f16_at(block_byte_base));
+        let m = load_src0_f16_at(block_byte_base + 2u);

         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-            let q_0 = src0[scale_idx + 2u + block_offset + j];
-            let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
+            let q_packed = load_src0_u32_at(q_byte_offset);
             for (var k: u32 = 0; k < 4; k++) {
                 let q_byte = get_byte_i32(q_packed, k);
                 let q_val = f32(q_byte) * d + f32(m);
@@ -284,13 +274,7 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
 #ifdef MUL_ACC_Q6_K

 const BLOCK_SIZE = 256u;
-const F16_PER_BLOCK = 105u;
-
-fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {
-    let aligned = byte_offset & ~3u;
-    let idx = bbase + aligned / 2u;
-    return bitcast<u32>(vec2(src0[idx], src0[idx + 1u]));
-}
+const BLOCK_SIZE_BYTES = 210u;

 fn byte_of(v: u32, b: u32) -> u32 {
     return (v >> (b * 8u)) & 0xFFu;
@@ -323,16 +307,15 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
     var local_sum = 0.0;

     for (var i = ix; i < nb; i += 2u) {
-        let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK;
+        let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES;

-        let d_raw = load_u32_at(bbase, 208u);
-        let d = f32(bitcast<vec2<f16>>(d_raw)[0]);
+        let d = f32(load_src0_f16_at(bbase + 208u));

-        let ql1_u32 = load_u32_at(bbase, q_offset_l);
-        let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u);
-        let qh_u32 = load_u32_at(bbase, 128u + q_offset_h);
-        let sc_u32_0 = load_u32_at(bbase, sc_base_byte);
-        let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);
+        let ql1_u32 = load_src0_u32_at(bbase + q_offset_l);
+        let ql2_u32 = load_src0_u32_at(bbase + q_offset_l + 32u);
+        let qh_u32 = load_src0_u32_at(bbase + 128u + q_offset_h);
+        let sc_u32_0 = load_src0_u32_at(bbase + sc_base_byte);
+        let sc_u32_1 = load_src0_u32_at(bbase + sc_base_byte + 4u);

         let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
         let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
@@ -1,138 +1,12 @@
-#define(VARIANTS)
-
-[
-  {
-    "REPLS": {
-      "TYPE" : "f32",
-    },
-    "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
-  },
-  {
-    "SHADER_SUFFIX": "f32_inplace",
-    "REPLS": {
-      "TYPE" : "f32",
-    },
-    "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "f16",
-    },
-    "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_inplace",
-    "REPLS": {
-      "TYPE" : "f16",
-    },
-    "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
-  },
-  {
-    "SHADER_SUFFIX": "f32_ff",
-    "REPLS": {
-      "TYPE" : "f32",
-    },
-    "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
-  },
-  {
-    "SHADER_SUFFIX": "f32_ff_inplace",
-    "REPLS": {
-      "TYPE" : "f32",
-    },
-    "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_ff",
-    "REPLS": {
-      "TYPE" : "f16",
-    },
-    "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_ff_inplace",
-    "REPLS": {
-      "TYPE" : "f16",
-    },
-    "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
-  }
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(ROTATE)
-fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
-    dst[i_dst0] = {{TYPE}}(out0);
-    dst[i_dst1] = {{TYPE}}(out1);
-}
-#enddecl(ROTATE)
-
-#decl(ROTATE_INPLACE)
-fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
-    src0[i_dst0] = {{TYPE}}(out0);
-    src0[i_dst1] = {{TYPE}}(out1);
-}
-#enddecl(ROTATE_INPLACE)
-
-#decl(NO_FF_FUNC)
-fn freq_factor(i: u32) -> f32 {
-    return 1.0f;
-}
-#enddecl(NO_FF_FUNC)
-
-#decl(FF_FUNC)
-fn freq_factor(i: u32) -> f32 {
-    return src2[params.offset_src2 + i/2];
-}
-#enddecl(FF_FUNC)
-
-#decl(NO_FF_BINDINGS)
-
-@group(0) @binding(2)
-var<storage, read_write> dst: array<{{TYPE}}>;
-
-@group(0) @binding(3)
-var<uniform> params: Params;
-
-#enddecl(NO_FF_BINDINGS)
-
-#decl(NO_FF_BINDINGS_INPLACE)
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-
-#enddecl(NO_FF_BINDINGS_INPLACE)
-
-#decl(FF_BINDINGS)
-
-@group(0) @binding(2)
-var<storage, read_write> src2: array<f32>;
-
-@group(0) @binding(3)
-var<storage, read_write> dst: array<{{TYPE}}>;
-
-@group(0) @binding(4)
-var<uniform> params: Params;
-
-#enddecl(FF_BINDINGS)
-
-#decl(FF_BINDINGS_INPLACE)
-
-@group(0) @binding(2)
-var<storage, read_write> src2: array<f32>;
-
-@group(0) @binding(3)
-var<uniform> params: Params;
-
-#enddecl(FF_BINDINGS_INPLACE)
-
-#end(DECLS)
-
-#define(SHADER)
-
 enable f16;

+#ifdef TYPE_F32
+#define DataType f32
+#endif
+#ifdef TYPE_F16
+#define DataType f16
+#endif
+
 struct Params {
     offset_src0: u32,
     offset_src1: u32,
@@ -168,12 +42,69 @@ struct Params {
 };

 @group(0) @binding(0)
-var<storage, read_write> src0: array<{{TYPE}}>;
-
+var<storage, read_write> src0: array<DataType>;
 @group(0) @binding(1)
 var<storage, read_write> src1: array<i32>;

-DECLS
+#ifdef INPLACE
+
+#ifdef FF_FUNC
+
+@group(0) @binding(2)
+var<storage, read_write> src2: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+#else
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#endif
+
+#else
+
+#ifdef FF_FUNC
+@group(0) @binding(2)
+var<storage, read_write> src2: array<f32>;
+
+@group(0) @binding(3)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(4)
+var<uniform> params: Params;
+
+#else
+@group(0) @binding(2)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#endif
+#endif
+
+#ifdef FF_FUNC
+fn freq_factor(i: u32) -> f32 {
+    return src2[params.offset_src2 + i/2];
+}
+
+#else
+fn freq_factor(i: u32) -> f32 {
+    return 1.0f;
+}
+#endif
+#ifdef INPLACE
+fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
+    src0[i_dst0] = DataType(out0);
+    src0[i_dst1] = DataType(out1);
+}
+#else
+fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
+    dst[i_dst0] = DataType(out0);
+    dst[i_dst1] = DataType(out1);
+}
+#endif

 fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
     let y = (f32(i / 2) - low) / max(0.001f, high - low);
@@ -184,7 +115,7 @@ fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
 // TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row
 fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
     var mscale = params.attn_factor;
-    var theta = params.freq_scale * theta_extrap;
+    var theta = params.freq_scale * theta_extrap;
     if (params.ext_factor != 0.0f) {
         let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;
         theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix;
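rope_yarn above blends the position-interpolated angle with the extrapolated one along a per-dimension ramp, in the style of YaRN:

\theta_i = (1 - r_i)\,\theta_i^{\text{interp}} + r_i\,\theta_i^{\text{extrap}},\qquad
\theta_i^{\text{interp}} = s_{\text{freq}}\,\theta_i^{\text{extrap}},\qquad
r_i = \epsilon_{\text{ext}}\cdot\operatorname{ramp}(d_0, d_1, i)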
@@ -211,10 +142,9 @@ fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
     }
 }

-override wg_size: u32;
-@compute @workgroup_size(wg_size)
+@compute @workgroup_size(WG_SIZE)
 fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    // two elements per thread
+    // two elements per n_threads
     if (gid.x >= params.n_threads) {
         return;
     }
@@ -290,6 +220,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     let x0 = f32(src0[i_src]);
     let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]);
     rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
 }
-
-#end(SHADER)
+}
@@ -1,215 +1,12 @@
|
||||
#define(VARIANTS)
|
||||
[
|
||||
{
|
||||
"SHADER_NAME": "soft_max_f32",
|
||||
"DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "soft_max_f32_inplace",
|
||||
"DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "soft_max_f32_sink",
|
||||
"DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "soft_max_f32_sink_inplace",
|
||||
"DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "soft_max_f32_mask_f32",
|
||||
"REPLS": {
|
||||
"MASK_TYPE" : "f32",
|
||||
},
|
||||
"DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "soft_max_f32_mask_f32_inplace",
|
||||
"REPLS": {
|
||||
"MASK_TYPE" : "f32",
|
||||
},
|
||||
"DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "soft_max_f32_mask_f16",
|
||||
"REPLS": {
|
||||
"MASK_TYPE" : "f16",
|
||||
},
|
||||
"DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "soft_max_f32_mask_f16_inplace",
|
||||
"REPLS": {
|
||||
"MASK_TYPE" : "f16",
|
||||
},
|
||||
"DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "soft_max_f32_mask_f32_sink",
|
||||
"REPLS": {
|
||||
"MASK_TYPE" : "f32",
|
||||
},
|
||||
"DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace",
|
||||
"REPLS": {
|
||||
"MASK_TYPE" : "f32",
|
||||
},
|
||||
"DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "soft_max_f32_mask_f16_sink",
|
||||
"REPLS": {
|
||||
"MASK_TYPE" : "f16",
|
||||
},
|
||||
"DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace",
|
||||
"REPLS": {
|
||||
"MASK_TYPE" : "f16",
|
||||
},
|
||||
"DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
|
||||
}
|
||||
]
|
||||
#end(VARIANTS)
|
||||
|
||||
#define(DECLS)

#decl(BASE_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;

@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(BASE_BINDINGS)

#decl(BASE_BINDINGS_INPLACE)
@group(0) @binding(1)
var<uniform> params: Params;
#enddecl(BASE_BINDINGS_INPLACE)

#decl(SINK_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> sinks: array<f32>;

@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;

@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(SINK_BINDINGS)

#decl(SINK_BINDINGS_INPLACE)
@group(0) @binding(1)
var<storage, read_write> sinks: array<f32>;

@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(SINK_BINDINGS_INPLACE)

#decl(MASK_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;

@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;

@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(MASK_BINDINGS)

#decl(MASK_BINDINGS_INPLACE)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;

@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(MASK_BINDINGS_INPLACE)

#decl(MASK_SINK_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;

@group(0) @binding(2)
var<storage, read_write> sinks: array<f32>;

@group(0) @binding(3)
var<storage, read_write> dst: array<f32>;

@group(0) @binding(4)
var<uniform> params: Params;
#enddecl(MASK_SINK_BINDINGS)

#decl(MASK_SINK_BINDINGS_INPLACE)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;

@group(0) @binding(2)
var<storage, read_write> sinks: array<f32>;

@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(MASK_SINK_BINDINGS_INPLACE)

#decl(NOT_INPLACE)
fn inter_value(i: u32) -> f32 {
    return dst[i];
}

fn update(i: u32, val: f32) {
    dst[i] = val;
}
#enddecl(NOT_INPLACE)

#decl(INPLACE)
fn inter_value(i: u32) -> f32 {
    return src[i];
}

fn update(i: u32, val: f32) {
    src[i] = val;
}
#enddecl(INPLACE)

#decl(NO_MASK)
fn mask_val(i: u32) -> f32 {
    return 0.0;
}
#enddecl(NO_MASK)

#decl(MASK)
fn mask_val(i: u32) -> f32 {
    return f32(mask[i]);
}
#enddecl(MASK)

#decl(NO_SINK)
fn lower_max_bound(i2: u32) -> f32 {
    return -1e30;
}

fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
    return val;
}
#enddecl(NO_SINK)

#decl(SINK)
fn lower_max_bound(i2: u32) -> f32 {
    return sinks[params.offset_sinks + i2];
}

fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
    return val + exp(sinks[params.offset_sinks + i2] - max_val);
}
#enddecl(SINK)

#end(DECLS)

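Context for the blocks above and below: ggml-webgpu's shader templates pair a VARIANTS table (each entry naming a SHADER_NAME, textual REPLS substitutions, and a list of DECLS) with #decl(...) blocks that a generator splices into the shader wherever the token DECLS appears. A minimal Python sketch of that expansion, assuming the generator simply concatenates #decl bodies verbatim; expand_variant and its signature are illustrative, not the actual build script:

    import re

    def expand_variant(template: str, decls: list[str], repls: dict[str, str]) -> str:
        # Collect every "#decl(NAME) ... #enddecl(NAME)" body in the template.
        bodies = {
            m.group(1): m.group(2)
            for m in re.finditer(r"#decl\((\w+)\)(.*?)#enddecl\(\1\)", template, re.S)
        }
        # Splice the requested decls, in order, where the shader says DECLS.
        shader = re.search(r"#define\(SHADER\)(.*?)#end\(SHADER\)", template, re.S).group(1)
        shader = shader.replace("DECLS", "\n".join(bodies[d] for d in decls))
        # Apply the {{KEY}} replacements, e.g. {{MASK_TYPE}} -> f16.
        for key, val in repls.items():
            shader = shader.replace("{{" + key + "}}", val)
        return shader

For the soft_max_f32_mask_f16_sink_inplace variant above, decls would be ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"] and repls {"MASK_TYPE": "f16"}.
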
#define(SHADER)
enable f16;

#ifdef MASK_F32
#define MaskType f32
#endif
#ifdef MASK_F16
#define MaskType f16
#endif

struct Params {
    offset_src0: u32,
    offset_src1: u32,
@@ -249,14 +46,117 @@ struct Params {
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;

DECLS
#ifdef HAS_MASK
#ifdef HAS_SINK
@group(0) @binding(1)
var<storage, read_write> mask: array<MaskType>;
@group(0) @binding(2)
var<storage, read_write> sinks: array<f32>;

#ifdef INPLACE
@group(0) @binding(3)
var<uniform> params: Params;

#else
@group(0) @binding(3)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(4)
var<uniform> params: Params;
#endif

#else
@group(0) @binding(1)
var<storage, read_write> mask: array<MaskType>;

#ifdef INPLACE
@group(0) @binding(2)
var<uniform> params: Params;

#else
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#endif
#endif

#else
#ifdef HAS_SINK
@group(0) @binding(1)
var<storage, read_write> sinks: array<f32>;

#ifdef INPLACE
@group(0) @binding(2)
var<uniform> params: Params;

#else
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#endif

#else
#ifdef INPLACE
@group(0) @binding(1)
var<uniform> params: Params;
#else
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
#endif
#endif
#endif

#ifdef INPLACE
fn inter_value(i: u32) -> f32 {
    return src[i];
}
fn update(i: u32, val: f32) {
    src[i] = val;
}

#else
fn inter_value(i: u32) -> f32 {
    return dst[i];
}
fn update(i: u32, val: f32) {
    dst[i] = val;
}
#endif

#ifdef HAS_MASK
fn mask_val(i: u32) -> f32 {
    return f32(mask[i]);
}

#else
fn mask_val(i: u32) -> f32 {
    return 0.0;
}
#endif

#ifdef HAS_SINK
fn lower_max_bound(i2: u32) -> f32 {
    return sinks[params.offset_sinks + i2];
}
fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
    return val + exp(sinks[params.offset_sinks + i2] - max_val);
}
#else
fn lower_max_bound(i2: u32) -> f32 {
    return -1e30;
}
fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
    return val;
}
#endif

const CACHE_SIZE: u32 = 16;
var<workgroup> scratch: array<f32, WG_SIZE>;

override wg_size: u32;
var<workgroup> scratch: array<f32, wg_size>;

@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {

@@ -268,7 +168,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
    let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
    let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11;
    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
    let elems = (params.ne0 + wg_size - 1) / wg_size;
    let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;

    let head = f32(i2);
    let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0);
@@ -286,12 +186,12 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
        if (col < CACHE_SIZE) {
            cache[col] = val;
        }
        col += wg_size;
        col += WG_SIZE;
    }

    scratch[lid.x] = max_val;
    workgroupBarrier();
    var offset = wg_size / 2;
    var offset: u32 = WG_SIZE / 2;
    while (offset > 0) {
        if (lid.x < offset) {
            scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]);
@@ -317,12 +217,12 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
        } else {
            update(i_dst_row + col, ex);
        }
        col += wg_size;
        col += WG_SIZE;
    }

    scratch[lid.x] = sum;
    workgroupBarrier();
    offset = wg_size / 2;
    offset = WG_SIZE / 2;
    while (offset > 0) {
        if (lid.x < offset) {
            scratch[lid.x] += scratch[lid.x + offset];
@@ -339,7 +239,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
            break;
        }
        update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip);
        col += wg_size;
        col += WG_SIZE;
    }
}
#end(SHADER)

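A note on reading the hunk above: this view lost its +/- gutters, so the two sides of each changed line appear back to back (e.g. the wg_size and WG_SIZE forms of the same statement). The shader itself computes a numerically stable row softmax in which an optional per-head "sink" logit takes part in the max and in the denominator but never receives an output slot, and an optional ALiBi-style slope (the nested select(...) expression) scales the mask. A NumPy sketch of the same per-row math; soft_max_row is illustrative and mirrors lower_max_bound/add_sinks above rather than any actual ggml host code:

    import numpy as np

    def soft_max_row(x, mask=None, slope=1.0, sink=None):
        # Row logits: the shader computes src * scale + slope * mask; the
        # scale factor is assumed to be folded into x here.
        logits = x + slope * mask if mask is not None else x.copy()
        # lower_max_bound: the sink participates as a lower bound on the max.
        max_val = max(logits.max(), sink if sink is not None else -1e30)
        ex = np.exp(logits - max_val)
        # add_sinks: the sink adds exp(sink - max) to the denominator only.
        denom = ex.sum() + (np.exp(sink - max_val) if sink is not None else 0.0)
        return ex / denom

The max and the denominator are each computed with the standard workgroup tree reduction (offset starts at WG_SIZE / 2 and halves each iteration), which is what the scratch array and the workgroupBarrier() calls implement; the CACHE_SIZE array lets the first 16 columns of a row skip a re-read during the final normalization pass.
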
@@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
    ExternalProject_Add(
        zendnn
        GIT_REPOSITORY https://github.com/amd/ZenDNN.git
        GIT_TAG a18adf8c605fb5f5e52cefd7eda08a7b18febbaf # ZenDNN-2026-WW08
        GIT_TAG f79f7321a1add65ced6397a6bfab7edba6e3e14e # ZenDNN-2026-WW13
        PREFIX ${ZENDNN_PREFIX}
        SOURCE_DIR ${ZENDNN_SOURCE_DIR}
        BINARY_DIR ${ZENDNN_BUILD_DIR}

@@ -190,6 +190,170 @@ static void ggml_zendnn_compute_forward_mul_mat(
    }
}

struct mmid_row_mapping {
    int32_t i1;
    int32_t i2;
};

static void ggml_zendnn_compute_forward_mul_mat_id(
        ggml_backend_zendnn_context * ctx,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0]; // expert weights
    const ggml_tensor * src1 = dst->src[1]; // inputs
    const ggml_tensor * ids = dst->src[2];  // expert ids

    GGML_TENSOR_BINARY_OP_LOCALS

    // exit for no tokens to process
    if (ne2 == 0 || ne11 == 0) {
        return;
    }

    ggml_type const vec_dot_type = src0->type;
    ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float_ref;

    // we don't support permuted src0 or src1
    GGML_ASSERT(nb00 == ggml_type_size(src0->type));
    GGML_ASSERT(nb10 == ggml_type_size(src1->type));

    // dst cannot be transposed or permuted
    GGML_ASSERT(nb0 == sizeof(float));
    GGML_ASSERT(nb0 <= nb1);
    GGML_ASSERT(nb1 <= nb2);
    GGML_ASSERT(nb2 <= nb3);

    GGML_ASSERT(ne03 == 1);
    GGML_ASSERT(ne13 == 1);
    GGML_ASSERT(ne3 == 1);

    // row groups
    const int n_ids = ids->ne[0]; // n_expert_used
    const int n_as = ne02;        // n_experts

    std::vector<int64_t> matrix_row_counts(n_as, 0);
    std::vector<std::vector<mmid_row_mapping>> matrix_rows(n_as);

    int64_t max_rows = 0;
    // group rows by expert (preprocessing step)
    for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
        for (int id = 0; id < n_ids; ++id) {
            const int32_t i02 = *(const int32_t *)((const char *)ids->data + iid1*ids->nb[1] + id*ids->nb[0]);

            GGML_ASSERT(i02 >= 0 && i02 < n_as);

            matrix_rows[i02].push_back({id, iid1});
            matrix_row_counts[i02]++;
            if (matrix_row_counts[i02] > max_rows) {
                max_rows = matrix_row_counts[i02];
            }
        }
    }

    if (max_rows == 0) {
        return; // no rows to process
    }

    const size_t row_size = ggml_row_size(vec_dot_type, ne10);

    // size for converting src1 rows to vec_dot_type if needed
    const size_t nbw1 = row_size;
    const size_t nbw2 = nbw1 * ne11;
    const size_t nbw3 = nbw2 * ne12;
    const size_t src1_conv_size = (src1->type != vec_dot_type) ? ne13 * nbw3 : 0;

    // size for MoE gather/scatter buffers
    const size_t wdata_cur_size = max_rows * row_size;
    const size_t dst_cur_size = max_rows * ggml_row_size(dst->type, ne01);

    // allocate single buffer for all needs
    const size_t total_size = src1_conv_size + wdata_cur_size + dst_cur_size;
    if (ctx->work_size < total_size) {
        ctx->work_data.reset(new char[total_size]);
        ctx->work_size = total_size;
    }

    // partition the buffer
    char * work_data = ctx->work_data.get();
    char * wdata_cur = work_data + src1_conv_size;
    char * dst_cur = wdata_cur + wdata_cur_size;

    if (src1->type != vec_dot_type) {
        GGML_ASSERT(src1->type == GGML_TYPE_F32);

#pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static)
        for (int64_t i13 = 0; i13 < ne13; ++i13) {
            for (int64_t i12 = 0; i12 < ne12; ++i12) {
                for (int64_t i11 = 0; i11 < ne11; ++i11) {
                    const float * src1_f32 = (float *)((char *)src1->data + i11*nb11 + i12*nb12 + i13*nb13);
                    void * src1_conv = (char *)work_data + i11*nbw1 + i12*nbw2 + i13*nbw3;
                    from_float(src1_f32, src1_conv, ne10);
                }
            }
        }
    }

    const void * wdata = src1->type == vec_dot_type ? src1->data : work_data;

    // process each expert with gather -> gemm -> scatter pattern
    for (int64_t cur_a = 0; cur_a < n_as; ++cur_a) {
        const int64_t cne1 = matrix_row_counts[cur_a];

        if (cne1 == 0) {
            continue;
        }

        const char * src0_cur = (const char *) src0->data + cur_a*nb02;

        // gather input rows for this expert
#pragma omp parallel for num_threads(ctx->n_threads) schedule(static)
        for (int64_t ir1 = 0; ir1 < cne1; ++ir1) {
            const mmid_row_mapping & row_mapping = matrix_rows[cur_a][ir1];
            const int64_t id = row_mapping.i1;
            const int64_t i11 = id % ne11;
            const int64_t i12 = row_mapping.i2;

            std::memcpy(
                wdata_cur + ir1 * row_size,
                (const char *) wdata + (i11 + i12*ne11) * row_size,
                row_size
            );
        }

        // batched gemm for all tokens in this expert
        if (!ggml_zendnn_sgemm(ctx,
                ne01,      // m
                cne1,      // n
                ne10,      // k
                src0_cur,
                ne00,      // lda
                wdata_cur,
                ne10,      // ldb
                dst_cur,
                ne01,      // ldc
                src0->type,
                vec_dot_type,
                dst->type)) {
            GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__);
        }

        // scatter output rows to destination
#pragma omp parallel for num_threads(ctx->n_threads) schedule(static)
        for (int64_t ir1 = 0; ir1 < cne1; ++ir1) {
            const mmid_row_mapping & row_mapping = matrix_rows[cur_a][ir1];
            const int64_t id = row_mapping.i1;
            const int64_t i1 = id;
            const int64_t i2 = row_mapping.i2;

            std::memcpy(
                (char *) dst->data + i1*nb1 + i2*nb2,
                dst_cur + ir1 * ggml_row_size(dst->type, ne01),
                ggml_row_size(dst->type, ne01)
            );
        }
    }
}

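ggml_zendnn_compute_forward_mul_mat_id above routes each token row to its selected experts by grouping rows per expert, running one GEMM per expert, and scattering results back. A NumPy sketch of the same gather -> GEMM -> scatter pattern, simplified to 2-D shapes with no type conversion or threading; the names are illustrative:

    import numpy as np

    def mul_mat_id(weights, tokens, ids):
        # weights: (n_experts, n_out, k); tokens: (n_tokens, k)
        # ids: (n_tokens, n_used) expert id per (token, slot), like dst->src[2]
        n_tokens, n_used = ids.shape
        out = np.zeros((n_tokens, n_used, weights.shape[1]), dtype=tokens.dtype)
        for e in range(weights.shape[0]):
            # gather: which (token, slot) pairs route to expert e
            rows = [(t, s) for t in range(n_tokens)
                           for s in range(n_used) if ids[t, s] == e]
            if not rows:
                continue  # mirrors the cne1 == 0 early-continue
            gathered = np.stack([tokens[t] for t, s in rows])  # (cne1, k)
            result = gathered @ weights[e].T                   # one GEMM per expert
            # scatter: copy each result row to its (token, slot) destination
            for r, (t, s) in enumerate(rows):
                out[t, s] = result[r]
        return out

matrix_rows/matrix_row_counts in the diff are the per-expert gather lists, ggml_zendnn_sgemm is the per-expert GEMM, and the final memcpy loop is the scatter. The single work buffer is partitioned into the optional src1 conversion region, the gathered-input region, and the per-expert output region, each sized for the busiest expert (max_rows).
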
// backend interface

static const char * ggml_backend_zendnn_get_name(ggml_backend_t backend) {
@@ -218,6 +382,9 @@ static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggm
        case GGML_OP_MUL_MAT:
            ggml_zendnn_compute_forward_mul_mat(ctx, node);
            break;
        case GGML_OP_MUL_MAT_ID:
            ggml_zendnn_compute_forward_mul_mat_id(ctx, node);
            break;
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
@@ -361,6 +528,7 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
            return true;

        case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID:
            {
                const ggml_tensor * weights = op->src[0];
                const ggml_tensor * inputs = op->src[1];
@@ -374,6 +542,17 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
                    ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {
                    return false;
                }
                // MUL_MAT_ID performs best with a moderate number of experts due to its
                // gather + batched matmul + scatter approach. Future versions will leverage
                // ZenDNN's grouped_gemm for better scalability with larger expert counts:
                // https://github.com/amd/ZenDNN/blob/main/docs/operator/lowoha_group_gemm_operator.md
                if (op->op == GGML_OP_MUL_MAT_ID) {
                    const int64_t n_experts = weights->ne[2];
                    const int64_t max_experts = 32;
                    if (n_experts > max_experts) {
                        return false;
                    }
                }
                switch (weights->type) {
                    case GGML_TYPE_F32:
                    case GGML_TYPE_BF16:

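The supports_op hunk gates MUL_MAT_ID on the same minimum-batch heuristic already applied to MUL_MAT, plus a cap on the expert count, since the gather/scatter overhead grows with the number of experts. A sketch of that decision logic in Python; the constants come from the diff, while min_batch stands in for whatever threshold the surrounding code computes:

    def zendnn_supports_mul_mat_id(n_experts, ne0, ne1, ne10, min_batch, max_experts=32):
        # small problems are not worth handing to ZenDNN at all
        if ne0 < min_batch or ne1 < min_batch or ne10 < min_batch:
            return False
        # gather + batched matmul + scatter degrades with many experts;
        # the diff's comment points to ZenDNN grouped_gemm as the future fix
        return n_experts <= max_experts
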
@@ -419,6 +419,7 @@ class MODEL_ARCH(IntEnum):
    GEMMA2 = auto()
    GEMMA3 = auto()
    GEMMA3N = auto()
    GEMMA4 = auto()
    GEMMA_EMBEDDING = auto()
    STARCODER2 = auto()
    RWKV6 = auto()
@@ -535,8 +536,11 @@ class MODEL_TENSOR(IntEnum):
    FFN_GATE_INP = auto()
    FFN_GATE_INP_SHEXP = auto()
    FFN_NORM = auto()
    FFN_PRE_NORM = auto()
    FFN_PRE_NORM = auto() # alias of FFN_NORM
    FFN_PRE_NORM_2 = auto() # gemma4
    FFN_POST_NORM = auto()
    FFN_POST_NORM_1 = auto() # gemma4
    FFN_POST_NORM_2 = auto() # gemma4
    FFN_GATE = auto()
    FFN_DOWN = auto()
    FFN_UP = auto()
@@ -558,6 +562,7 @@ class MODEL_TENSOR(IntEnum):
    ATTN_Q_NORM = auto()
    ATTN_K_NORM = auto()
    LAYER_OUT_NORM = auto()
    LAYER_OUT_SCALE = auto()
    PER_LAYER_TOKEN_EMBD = auto() # gemma3n
    PER_LAYER_MODEL_PROJ = auto() # gemma3n
    PER_LAYER_INP_GATE = auto() # gemma3n
@@ -722,10 +727,14 @@ class MODEL_TENSOR(IntEnum):
    V_ENC_FFN_UP = auto()
    V_ENC_FFN_GATE = auto()
    V_ENC_FFN_DOWN = auto()
    V_ENC_ATTN_POST_NORM = auto() # gemma4
    V_ENC_FFN_POST_NORM = auto()
    V_LAYER_SCALE_1 = auto()
    V_LAYER_SCALE_2 = auto()
    V_LAYER_OUT_SCALE = auto()
    V_PRE_NORM = auto()
    V_POST_NORM = auto()
    V_MM_PRE_NORM = auto() # hunyuanocr
    V_MM_POST_NORM = auto()
    V_MM_INP_NORM = auto()
    V_MM_INP_PROJ = auto() # gemma3
@@ -761,6 +770,10 @@ class MODEL_TENSOR(IntEnum):
    V_MM_GATE = auto() # cogvlm
    V_TOK_BOI = auto() # cogvlm
    V_TOK_EOI = auto() # cogvlm
    V_TOK_IMG_BEGIN = auto() # hunyuanocr
    V_TOK_IMG_END = auto() # hunyuanocr
    V_STD_BIAS = auto() # gemma4
    V_STD_SCALE = auto() # gemma4
    V_SAM_POS_EMBD = auto() # Deepseek-OCR
    V_SAM_PATCH_EMBD = auto() # Deepseek-OCR
    V_SAM_PRE_NORM = auto() # Deepseek-OCR
@@ -781,6 +794,7 @@ class MODEL_TENSOR(IntEnum):
    A_ENC_EMBD_POS = auto()
    A_ENC_EMBD_NORM = auto()
    A_ENC_EMBD_TO_LOGITS = auto() # lfm2
    A_ENC_INP_PROJ = auto() # gemma4
    A_ENC_CONV1D = auto()
    A_ENC_CONV1D_NORM = auto() # gemma3n
    A_PRE_NORM = auto()
@@ -789,10 +803,13 @@ class MODEL_TENSOR(IntEnum):
    A_ENC_ATTN_Q = auto()
    A_ENC_ATTN_K = auto()
    A_ENC_ATTN_V = auto()
    A_ENC_ATTN_POST_NORM = auto()
    A_ENC_ATTN_PRE_NORM = auto()
    A_ENC_ATTN_K_REL = auto() # gemma4
    A_ENC_PER_DIM_SCALE = auto() # gemma3n
    A_ENC_INPUT_NORM = auto()
    A_ENC_OUTPUT = auto()
    A_ENC_OUTPUT_NORM = auto()
    A_ENC_OUTPUT = auto() # TODO @ngxson: rename to ATTN_OUT
    A_ENC_OUTPUT_NORM = auto() # TODO @ngxson: rename to ATTN_OUT
    A_ENC_FFN_UP = auto()
    A_ENC_FFN_NORM = auto()
    A_ENC_FFN_POST_NORM = auto() # gemma3n
@@ -813,6 +830,8 @@ class MODEL_TENSOR(IntEnum):
    A_MM_HARD_EMB_NORM = auto() # gemma3n
    A_MM_SOFT_EMB_NORM = auto() # gemma3n
    A_MM_INP_PROJ = auto() # gemma3n
    A_PER_DIM_K_SCALE = auto() # gemma4
    A_PER_DIM_SCALE = auto() # gemma4
    # nextn/mtp
    NEXTN_EH_PROJ = auto()
    NEXTN_EMBED_TOKENS = auto()
@@ -882,6 +901,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.GEMMA2: "gemma2",
    MODEL_ARCH.GEMMA3: "gemma3",
    MODEL_ARCH.GEMMA3N: "gemma3n",
    MODEL_ARCH.GEMMA4: "gemma4",
    MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding",
    MODEL_ARCH.STARCODER2: "starcoder2",
    MODEL_ARCH.RWKV6: "rwkv6",
@@ -1000,6 +1020,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
    MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm",
    MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm",
    MODEL_TENSOR.FFN_PRE_NORM_2: "blk.{bid}.pre_ffw_norm_2", # gemma4
    MODEL_TENSOR.FFN_POST_NORM_1: "blk.{bid}.post_ffw_norm_1", # gemma4
    MODEL_TENSOR.FFN_POST_NORM_2: "blk.{bid}.post_ffw_norm_2", # gemma4
    MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
    MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
    MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
@@ -1019,6 +1042,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.MOE_LATENT_DOWN: "blk.{bid}.ffn_latent_down", # nemotron 3 super
    MODEL_TENSOR.MOE_LATENT_UP: "blk.{bid}.ffn_latent_up", # nemotron 3 super
    MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
    MODEL_TENSOR.LAYER_OUT_SCALE: "blk.{bid}.layer_output_scale",
    MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n
    MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj", # gemma3n
    MODEL_TENSOR.PER_LAYER_PROJ_NORM: "per_layer_proj_norm", # gemma3n
@@ -1183,8 +1207,11 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up",
    MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate",
    MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down",
    MODEL_TENSOR.V_ENC_ATTN_POST_NORM: "v.blk.{bid}.attn_post_norm",
    MODEL_TENSOR.V_ENC_FFN_POST_NORM: "v.blk.{bid}.ffn_post_norm",
    MODEL_TENSOR.V_LAYER_SCALE_1: "v.blk.{bid}.ls1",
    MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2",
    MODEL_TENSOR.V_LAYER_OUT_SCALE: "v.blk.{bid}.out_scale",
    MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
    MODEL_TENSOR.V_POST_NORM: "v.post_ln",
    MODEL_TENSOR.V_MM_POST_NORM: "mm.post_norm",
@@ -1222,6 +1249,11 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.V_MM_GATE: "mm.gate",
    MODEL_TENSOR.V_TOK_BOI: "v.boi",
    MODEL_TENSOR.V_TOK_EOI: "v.eoi",
    MODEL_TENSOR.V_MM_PRE_NORM: "mm.pre_norm",
    MODEL_TENSOR.V_TOK_IMG_BEGIN: "mm.image_begin",
    MODEL_TENSOR.V_TOK_IMG_END: "mm.image_end",
    MODEL_TENSOR.V_STD_BIAS: "v.std_bias", # gemma4
    MODEL_TENSOR.V_STD_SCALE: "v.std_scale", # gemma4
    # DeepSeek-OCR SAM
    MODEL_TENSOR.V_SAM_POS_EMBD: "v.sam.pos_embd",
    MODEL_TENSOR.V_SAM_PATCH_EMBD: "v.sam.patch_embd",
@@ -1243,6 +1275,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
    MODEL_TENSOR.A_ENC_EMBD_NORM: "a.position_embd_norm",
    MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS: "a.embd_to_logits",
    MODEL_TENSOR.A_ENC_INP_PROJ: "a.input_projection",
    MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
    MODEL_TENSOR.A_ENC_CONV1D_NORM: "a.conv1d.{bid}.norm",
    MODEL_TENSOR.A_PRE_NORM: "a.pre_ln",
@@ -1251,6 +1284,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.A_ENC_ATTN_Q: "a.blk.{bid}.attn_q",
    MODEL_TENSOR.A_ENC_ATTN_K: "a.blk.{bid}.attn_k",
    MODEL_TENSOR.A_ENC_ATTN_V: "a.blk.{bid}.attn_v",
    MODEL_TENSOR.A_ENC_ATTN_POST_NORM: "a.blk.{bid}.attn_post_norm",
    MODEL_TENSOR.A_ENC_ATTN_PRE_NORM: "a.blk.{bid}.attn_pre_norm",
    MODEL_TENSOR.A_ENC_ATTN_K_REL: "a.blk.{bid}.attn_k_rel",
    MODEL_TENSOR.A_ENC_PER_DIM_SCALE: "a.blk.{bid}.per_dim_scale",
    MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1",
    MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out",
@@ -1275,6 +1311,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.A_MM_SOFT_EMB_NORM: "mm.a.soft_emb_norm", # gemma3n
    MODEL_TENSOR.A_MM_EMBEDDING: "mm.a.embedding", # gemma3n
    MODEL_TENSOR.A_MM_HARD_EMB_NORM: "mm.a.hard_emb_norm", # gemma3n
    MODEL_TENSOR.A_PER_DIM_K_SCALE: "a.blk.{bid}.per_dim_k_scale", # gemma4
    MODEL_TENSOR.A_PER_DIM_SCALE: "a.blk.{bid}.per_dim_scale", # gemma4
    # lfm2 audio
    MODEL_TENSOR.A_ENC_NORM_CONV: "a.blk.{bid}.norm_conv",
    MODEL_TENSOR.A_ENC_LINEAR_POS: "a.blk.{bid}.linear_pos",
@@ -1319,8 +1357,11 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.V_ENC_FFN_UP,
        MODEL_TENSOR.V_ENC_FFN_GATE,
        MODEL_TENSOR.V_ENC_FFN_DOWN,
        MODEL_TENSOR.V_ENC_ATTN_POST_NORM,
        MODEL_TENSOR.V_ENC_FFN_POST_NORM,
        MODEL_TENSOR.V_LAYER_SCALE_1,
        MODEL_TENSOR.V_LAYER_SCALE_2,
        MODEL_TENSOR.V_LAYER_OUT_SCALE,
        MODEL_TENSOR.V_PRE_NORM,
        MODEL_TENSOR.V_POST_NORM,
        MODEL_TENSOR.V_MM_POST_NORM,
@@ -1358,6 +1399,11 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.V_MM_GATE,
        MODEL_TENSOR.V_TOK_BOI,
        MODEL_TENSOR.V_TOK_EOI,
        MODEL_TENSOR.V_MM_PRE_NORM,
        MODEL_TENSOR.V_TOK_IMG_BEGIN,
        MODEL_TENSOR.V_TOK_IMG_END,
        MODEL_TENSOR.V_STD_BIAS,
        MODEL_TENSOR.V_STD_SCALE,
        MODEL_TENSOR.V_SAM_POS_EMBD,
        MODEL_TENSOR.V_SAM_PATCH_EMBD,
        MODEL_TENSOR.V_SAM_PRE_NORM,
@@ -1375,6 +1421,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.A_ENC_EMBD_POS,
        MODEL_TENSOR.A_ENC_EMBD_NORM,
        MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS,
        MODEL_TENSOR.A_ENC_INP_PROJ,
        MODEL_TENSOR.A_ENC_CONV1D,
        MODEL_TENSOR.A_ENC_CONV1D_NORM,
        MODEL_TENSOR.A_PRE_NORM,
@@ -1383,6 +1430,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.A_ENC_ATTN_Q,
        MODEL_TENSOR.A_ENC_ATTN_K,
        MODEL_TENSOR.A_ENC_ATTN_V,
        MODEL_TENSOR.A_ENC_ATTN_POST_NORM,
        MODEL_TENSOR.A_ENC_ATTN_PRE_NORM,
        MODEL_TENSOR.A_ENC_ATTN_K_REL,
        MODEL_TENSOR.A_ENC_PER_DIM_SCALE,
        MODEL_TENSOR.A_ENC_INPUT_NORM,
        MODEL_TENSOR.A_ENC_OUTPUT,
@@ -1416,6 +1466,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.A_MM_SOFT_EMB_NORM,
        MODEL_TENSOR.A_MM_EMBEDDING,
        MODEL_TENSOR.A_MM_HARD_EMB_NORM,
        MODEL_TENSOR.A_PER_DIM_K_SCALE,
        MODEL_TENSOR.A_PER_DIM_SCALE,
    ],
    MODEL_ARCH.LLAMA: [
        MODEL_TENSOR.TOKEN_EMBD,
@@ -2273,6 +2325,38 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.LAUREL_R,
        MODEL_TENSOR.LAUREL_POST_NORM,
    ],
    MODEL_ARCH.GEMMA4: [
        MODEL_TENSOR.ROPE_FREQS,
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_Q_NORM,
        MODEL_TENSOR.ATTN_K,
        MODEL_TENSOR.ATTN_K_NORM,
        MODEL_TENSOR.ATTN_V,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.FFN_GATE,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
        MODEL_TENSOR.FFN_GATE_UP_EXP,
        MODEL_TENSOR.FFN_DOWN_EXP,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_POST_NORM,
        MODEL_TENSOR.FFN_GATE_INP,
        MODEL_TENSOR.FFN_PRE_NORM,
        MODEL_TENSOR.FFN_PRE_NORM_2,
        MODEL_TENSOR.FFN_POST_NORM,
        MODEL_TENSOR.FFN_POST_NORM_1,
        MODEL_TENSOR.FFN_POST_NORM_2,
        MODEL_TENSOR.LAYER_OUT_SCALE,
        MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
        MODEL_TENSOR.PER_LAYER_MODEL_PROJ,
        MODEL_TENSOR.PER_LAYER_INP_GATE,
        MODEL_TENSOR.PER_LAYER_PROJ,
        MODEL_TENSOR.PER_LAYER_PROJ_NORM,
        MODEL_TENSOR.PER_LAYER_POST_NORM,
    ],
    MODEL_ARCH.GEMMA_EMBEDDING: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT,
@@ -4010,6 +4094,8 @@ class VisionProjectorType:
    GEMMA3 = "gemma3"
    GEMMA3NV = "gemma3nv"
    GEMMA3NA = "gemma3na"
    GEMMA4V = "gemma4v"
    GEMMA4A = "gemma4a"
    PHI4 = "phi4"
    IDEFICS3 = "idefics3"
    PIXTRAL = "pixtral"
@@ -4036,6 +4122,7 @@ class VisionProjectorType:
    GLM4V = "glm4v"
    YOUTUVL = "youtuvl"
    NEMOTRON_V2_VL = "nemotron_v2_vl"
    HUNYUANOCR = "hunyuanocr"


# Items here are (block size, type size)

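The TENSOR_NAMES entries above are per-block templates: gguf-py substitutes the block index for {bid} when naming tensors. A quick illustration with the new gemma4 entries (plain str.format; the bid value is arbitrary):

    templates = {
        "FFN_PRE_NORM_2":  "blk.{bid}.pre_ffw_norm_2",   # gemma4
        "FFN_POST_NORM_1": "blk.{bid}.post_ffw_norm_1",  # gemma4
        "FFN_POST_NORM_2": "blk.{bid}.post_ffw_norm_2",  # gemma4
    }
    for tensor, template in templates.items():
        print(tensor, "->", template.format(bid=7))
    # FFN_PRE_NORM_2 -> blk.7.pre_ffw_norm_2
    # FFN_POST_NORM_1 -> blk.7.post_ffw_norm_1
    # FFN_POST_NORM_2 -> blk.7.post_ffw_norm_2
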
@@ -799,6 +799,7 @@ class GGUFWriter:
    def add_shared_kv_layers(self, value: int) -> None:
        self.add_uint32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)

    # if input is array, true means SWA and false means full_attention for each layer
    def add_sliding_window_pattern(self, value: int | Sequence[bool]) -> None:
        key = Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch)
        if isinstance(value, int):
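Per the comment in the hunk above, add_sliding_window_pattern now also accepts a per-layer boolean array, where True marks a sliding-window-attention layer and False a full-attention layer (the int form keeps its existing meaning). A hypothetical call, assuming a GGUFWriter instance named writer and a 12-layer model in which every sixth layer uses full attention; the values are illustrative, not from the diff:

    n_layers = 12
    # True = SWA layer, False = full-attention layer, per the diff's comment
    pattern = [(i + 1) % 6 != 0 for i in range(n_layers)]
    writer.add_sliding_window_pattern(pattern)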