mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-23 16:37:33 +03:00
Compare commits
16 Commits
b7144
...
compilade/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ef41855cf | ||
|
|
f88a4b9398 | ||
|
|
4be1a5d44b | ||
|
|
6ffa46d8f4 | ||
|
|
3126b5ee4e | ||
|
|
e097d98a22 | ||
|
|
5712aa895f | ||
|
|
d3fcb0e90e | ||
|
|
614b95a88d | ||
|
|
c3738cfcef | ||
|
|
791bd97b3c | ||
|
|
d921057027 | ||
|
|
562aa42c12 | ||
|
|
e996f3aef8 | ||
|
|
e7b7ed8ab1 | ||
|
|
c4b630f25d |
@@ -3,8 +3,7 @@
|
||||
# ==============================================================================
|
||||
|
||||
# Define the CANN base image for easier version updates later
|
||||
ARG CHIP_TYPE=910b
|
||||
ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.3.rc1.alpha001-${CHIP_TYPE}-openeuler22.03-py3.11
|
||||
ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.1.rc1-910b-openeuler22.03-py3.10
|
||||
|
||||
# ==============================================================================
|
||||
# BUILD STAGE
|
||||
@@ -12,6 +11,9 @@ ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.3.rc1.alpha001-${CHIP_TYPE}-openeuler2
|
||||
# ==============================================================================
|
||||
FROM ${CANN_BASE_IMAGE} AS build
|
||||
|
||||
# Define the Ascend chip model for compilation. Default is Ascend910B3
|
||||
ARG ASCEND_SOC_TYPE=Ascend910B3
|
||||
|
||||
# -- Install build dependencies --
|
||||
RUN yum install -y gcc g++ cmake make git libcurl-devel python3 python3-pip && \
|
||||
yum clean all && \
|
||||
@@ -34,21 +36,20 @@ ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH
|
||||
# For brevity, only core variables are listed here. You can paste the original ENV list here.
|
||||
|
||||
# -- Build llama.cpp --
|
||||
# Use the passed CHIP_TYPE argument and add general build options
|
||||
ARG CHIP_TYPE
|
||||
# Use the passed ASCEND_SOC_TYPE argument and add general build options
|
||||
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh --force \
|
||||
&& \
|
||||
cmake -B build \
|
||||
-DGGML_CANN=ON \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DSOC_TYPE=ascend${CHIP_TYPE} \
|
||||
-DSOC_TYPE=${ASCEND_SOC_TYPE} \
|
||||
. && \
|
||||
cmake --build build --config Release -j$(nproc)
|
||||
|
||||
# -- Organize build artifacts for copying in later stages --
|
||||
# Create a lib directory to store all .so files
|
||||
RUN mkdir -p /app/lib && \
|
||||
find build -name "*.so*" -exec cp -P {} /app/lib \;
|
||||
find build -name "*.so" -exec cp {} /app/lib \;
|
||||
|
||||
# Create a full directory to store all executables and Python scripts
|
||||
RUN mkdir -p /app/full && \
|
||||
|
||||
@@ -20,7 +20,7 @@ RUN if [ "$TARGETARCH" = "amd64" ] || [ "$TARGETARCH" = "arm64" ]; then \
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
RUN mkdir -p /app/lib && \
|
||||
find build -name "*.so*" -exec cp -P {} /app/lib \;
|
||||
find build -name "*.so" -exec cp {} /app/lib \;
|
||||
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
|
||||
@@ -25,7 +25,7 @@ RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \
|
||||
cmake --build build --config Release -j$(nproc)
|
||||
|
||||
RUN mkdir -p /app/lib && \
|
||||
find build -name "*.so*" -exec cp -P {} /app/lib \;
|
||||
find build -name "*.so" -exec cp {} /app/lib \;
|
||||
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
|
||||
@@ -21,7 +21,7 @@ RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \
|
||||
cmake --build build --config Release -j$(nproc)
|
||||
|
||||
RUN mkdir -p /app/lib && \
|
||||
find build -name "*.so*" -exec cp -P {} /app/lib \;
|
||||
find build -name "*.so" -exec cp {} /app/lib \;
|
||||
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
|
||||
@@ -32,7 +32,7 @@ RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \
|
||||
cmake --build build --config Release -j$(nproc)
|
||||
|
||||
RUN mkdir -p /app/lib && \
|
||||
find build -name "*.so*" -exec cp -P {} /app/lib \;
|
||||
find build -name "*.so" -exec cp {} /app/lib \;
|
||||
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
|
||||
@@ -34,7 +34,6 @@
|
||||
rocmGpuTargets ? builtins.concatStringsSep ";" rocmPackages.clr.gpuTargets,
|
||||
enableCurl ? true,
|
||||
useVulkan ? false,
|
||||
useRpc ? false,
|
||||
llamaVersion ? "0.0.0", # Arbitrary version, substituted by the flake
|
||||
|
||||
# It's necessary to consistently use backendStdenv when building with CUDA support,
|
||||
@@ -176,7 +175,6 @@ effectiveStdenv.mkDerivation (finalAttrs: {
|
||||
(cmakeBool "GGML_METAL" useMetalKit)
|
||||
(cmakeBool "GGML_VULKAN" useVulkan)
|
||||
(cmakeBool "GGML_STATIC" enableStatic)
|
||||
(cmakeBool "GGML_RPC" useRpc)
|
||||
]
|
||||
++ optionals useCuda [
|
||||
(
|
||||
|
||||
@@ -45,7 +45,7 @@ RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
|
||||
&& cmake --build build --config Release -j$(nproc)
|
||||
|
||||
RUN mkdir -p /app/lib \
|
||||
&& find build -name "*.so*" -exec cp -P {} /app/lib \;
|
||||
&& find build -name "*.so" -exec cp {} /app/lib \;
|
||||
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
|
||||
@@ -1,24 +1,42 @@
|
||||
ARG UBUNTU_VERSION=26.04
|
||||
ARG UBUNTU_VERSION=24.04
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
# Ref: https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html
|
||||
|
||||
# Install build tools
|
||||
RUN apt update && apt install -y git build-essential cmake wget xz-utils
|
||||
|
||||
# Install Vulkan SDK
|
||||
ARG VULKAN_VERSION=1.4.321.1
|
||||
RUN ARCH=$(uname -m) && \
|
||||
wget -qO /tmp/vulkan-sdk.tar.xz https://sdk.lunarg.com/sdk/download/${VULKAN_VERSION}/linux/vulkan-sdk-linux-${ARCH}-${VULKAN_VERSION}.tar.xz && \
|
||||
mkdir -p /opt/vulkan && \
|
||||
tar -xf /tmp/vulkan-sdk.tar.xz -C /tmp --strip-components=1 && \
|
||||
mv /tmp/${ARCH}/* /opt/vulkan/ && \
|
||||
rm -rf /tmp/*
|
||||
|
||||
# Install cURL and Vulkan SDK dependencies
|
||||
RUN apt install -y libcurl4-openssl-dev curl \
|
||||
libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libvulkan-dev glslc
|
||||
libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev
|
||||
|
||||
# Set environment variables
|
||||
ENV VULKAN_SDK=/opt/vulkan
|
||||
ENV PATH=$VULKAN_SDK/bin:$PATH
|
||||
ENV LD_LIBRARY_PATH=$VULKAN_SDK/lib:$LD_LIBRARY_PATH
|
||||
ENV CMAKE_PREFIX_PATH=$VULKAN_SDK:$CMAKE_PREFIX_PATH
|
||||
ENV PKG_CONFIG_PATH=$VULKAN_SDK/lib/pkgconfig:$PKG_CONFIG_PATH
|
||||
|
||||
# Build it
|
||||
WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=ON -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON && \
|
||||
RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=1 -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON && \
|
||||
cmake --build build --config Release -j$(nproc)
|
||||
|
||||
RUN mkdir -p /app/lib && \
|
||||
find build -name "*.so*" -exec cp -P {} /app/lib \;
|
||||
find build -name "*.so" -exec cp {} /app/lib \;
|
||||
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
@@ -32,7 +50,7 @@ RUN mkdir -p /app/full \
|
||||
FROM ubuntu:$UBUNTU_VERSION AS base
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl libvulkan1 mesa-vulkan-drivers \
|
||||
&& apt-get install -y libgomp1 curl libvulkan-dev \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
|
||||
@@ -60,11 +60,3 @@ end_of_line = unset
|
||||
charset = unset
|
||||
trim_trailing_whitespace = unset
|
||||
insert_final_newline = unset
|
||||
|
||||
[benches/**]
|
||||
indent_style = unset
|
||||
indent_size = unset
|
||||
end_of_line = unset
|
||||
charset = unset
|
||||
trim_trailing_whitespace = unset
|
||||
insert_final_newline = unset
|
||||
|
||||
2
.github/copilot-instructions.md
vendored
2
.github/copilot-instructions.md
vendored
@@ -9,7 +9,7 @@ llama.cpp is a large-scale C/C++ project for efficient LLM (Large Language Model
|
||||
- **Size**: ~200k+ lines of code across 1000+ files
|
||||
- **Architecture**: Modular design with main library (`libllama`) and 40+ executable tools/examples
|
||||
- **Core dependency**: ggml tensor library (vendored in `ggml/` directory)
|
||||
- **Backends supported**: CPU (AVX/NEON/RVV optimized), CUDA, Metal, Vulkan, SYCL, ROCm, MUSA
|
||||
- **Backends supported**: CPU (AVX/NEON optimized), CUDA, Metal, Vulkan, SYCL, ROCm, MUSA
|
||||
- **License**: MIT
|
||||
|
||||
## Build Instructions
|
||||
|
||||
52
.github/workflows/build-amd.yml
vendored
Normal file
52
.github/workflows/build-amd.yml
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
name: CI (AMD)
|
||||
|
||||
on:
|
||||
workflow_dispatch: # allows manual triggering
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths: [
|
||||
'.github/workflows/build-amd.yml',
|
||||
'**/CMakeLists.txt',
|
||||
'**/.cmake',
|
||||
'**/*.h',
|
||||
'**/*.hpp',
|
||||
'**/*.c',
|
||||
'**/*.cpp',
|
||||
'**/*.cu',
|
||||
'**/*.cuh',
|
||||
'**/*.comp'
|
||||
]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
ggml-ci-x64-amd-vulkan:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
vulkaninfo --summary
|
||||
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-x64-amd-rocm:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
amd-smi static
|
||||
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
244
.github/workflows/build.yml
vendored
244
.github/workflows/build.yml
vendored
@@ -69,6 +69,13 @@ jobs:
|
||||
key: macOS-latest-cmake-arm64
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
brew install curl
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
@@ -76,8 +83,6 @@ jobs:
|
||||
cmake -B build \
|
||||
-DCMAKE_BUILD_RPATH="@loader_path" \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_BUILD_BORINGSSL=ON \
|
||||
-DGGML_METAL_USE_BF16=ON \
|
||||
-DGGML_METAL_EMBED_LIBRARY=OFF \
|
||||
-DGGML_METAL_SHADER_DEBUG=ON \
|
||||
@@ -105,6 +110,13 @@ jobs:
|
||||
key: macOS-latest-cmake-x64
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
brew install curl
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
@@ -114,8 +126,6 @@ jobs:
|
||||
cmake -B build \
|
||||
-DCMAKE_BUILD_RPATH="@loader_path" \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_BUILD_BORINGSSL=ON \
|
||||
-DGGML_METAL=OFF \
|
||||
-DGGML_RPC=ON \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=13.3
|
||||
@@ -141,19 +151,25 @@ jobs:
|
||||
key: macOS-latest-cmake-arm64-webgpu
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
brew install curl
|
||||
|
||||
- name: Dawn Dependency
|
||||
id: dawn-depends
|
||||
run: |
|
||||
DAWN_VERSION="v2.0.0"
|
||||
DAWN_VERSION="v1.0.0"
|
||||
DAWN_OWNER="reeselevine"
|
||||
DAWN_REPO="dawn"
|
||||
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.zip"
|
||||
DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-macos-latest-Release.tar.gz"
|
||||
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
|
||||
curl -L -o artifact.zip \
|
||||
curl -L -o artifact.tar.gz \
|
||||
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
|
||||
mkdir dawn
|
||||
unzip artifact.zip
|
||||
tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.tar.gz -C dawn --strip-components=1
|
||||
tar -xvf artifact.tar.gz -C dawn --strip-components=1
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
@@ -200,7 +216,7 @@ jobs:
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y --no-install-recommends \
|
||||
python3 python3-pip python3-dev \
|
||||
libjpeg-dev build-essential libssl-dev \
|
||||
libjpeg-dev build-essential libcurl4-openssl-dev \
|
||||
git-lfs
|
||||
|
||||
- name: Python Dependencies
|
||||
@@ -221,8 +237,6 @@ jobs:
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DGGML_RPC=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
@@ -279,15 +293,13 @@ jobs:
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential libssl-dev
|
||||
sudo apt-get install build-essential libcurl4-openssl-dev
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
if: ${{ matrix.sanitizer != 'THREAD' }}
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
|
||||
@@ -298,8 +310,6 @@ jobs:
|
||||
if: ${{ matrix.sanitizer == 'THREAD' }}
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
||||
@@ -324,7 +334,7 @@ jobs:
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential libssl-dev
|
||||
sudo apt-get install build-essential libcurl4-openssl-dev
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
@@ -332,8 +342,6 @@ jobs:
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_LLGUIDANCE=ON
|
||||
cmake --build . --config Release -j $(nproc)
|
||||
@@ -364,14 +372,12 @@ jobs:
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential libssl-dev
|
||||
sudo apt-get install build-essential libcurl4-openssl-dev
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DGGML_RPC=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
@@ -398,14 +404,12 @@ jobs:
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get install -y glslc libvulkan-dev libssl-dev
|
||||
sudo apt-get install -y glslc libvulkan-dev libcurl4-openssl-dev
|
||||
|
||||
- name: Configure
|
||||
id: cmake_configure
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
|
||||
-DGGML_BACKEND_DL=ON \
|
||||
-DGGML_CPU_ALL_VARIANTS=ON \
|
||||
@@ -435,7 +439,7 @@ jobs:
|
||||
run: |
|
||||
sudo add-apt-repository -y ppa:kisak/kisak-mesa
|
||||
sudo apt-get update -y
|
||||
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libssl-dev
|
||||
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libcurl4-openssl-dev
|
||||
|
||||
- name: Get latest Vulkan SDK version
|
||||
id: vulkan_sdk_version
|
||||
@@ -461,8 +465,6 @@ jobs:
|
||||
run: |
|
||||
source ./vulkan_sdk/setup-env.sh
|
||||
cmake -B build \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DGGML_VULKAN=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
@@ -494,7 +496,7 @@ jobs:
|
||||
run: |
|
||||
sudo add-apt-repository -y ppa:kisak/kisak-mesa
|
||||
sudo apt-get update -y
|
||||
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libssl-dev
|
||||
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libcurl4-openssl-dev
|
||||
|
||||
- name: Get latest Vulkan SDK version
|
||||
id: vulkan_sdk_version
|
||||
@@ -519,25 +521,21 @@ jobs:
|
||||
id: dawn-depends
|
||||
run: |
|
||||
sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev
|
||||
DAWN_VERSION="v2.0.0"
|
||||
DAWN_VERSION="v1.0.0"
|
||||
DAWN_OWNER="reeselevine"
|
||||
DAWN_REPO="dawn"
|
||||
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.zip"
|
||||
DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-ubuntu-latest-Release.tar.gz"
|
||||
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
|
||||
curl -L -o artifact.zip \
|
||||
curl -L -o artifact.tar.gz \
|
||||
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
|
||||
mkdir dawn
|
||||
unzip artifact.zip
|
||||
tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.tar.gz -C dawn --strip-components=1
|
||||
tar -xvf artifact.tar.gz -C dawn --strip-components=1
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
export Dawn_DIR=dawn/lib64/cmake/Dawn
|
||||
cmake -B build \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DGGML_WEBGPU=ON
|
||||
cmake -B build -DGGML_WEBGPU=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
- name: Test
|
||||
@@ -560,7 +558,7 @@ jobs:
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libssl-dev rocwmma-dev
|
||||
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libcurl4-openssl-dev rocwmma-dev
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -572,8 +570,6 @@ jobs:
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -S . \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
|
||||
-DGGML_HIP_ROCWMMA_FATTN=ON \
|
||||
-DGGML_HIP=ON
|
||||
@@ -592,7 +588,7 @@ jobs:
|
||||
id: depends
|
||||
run: |
|
||||
apt-get update
|
||||
apt-get install -y build-essential git cmake libssl-dev
|
||||
apt-get install -y build-essential git cmake libcurl4-openssl-dev
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -604,8 +600,6 @@ jobs:
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -S . \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DGGML_MUSA=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
@@ -630,7 +624,7 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo apt install intel-oneapi-compiler-dpcpp-cpp libssl-dev
|
||||
sudo apt install intel-oneapi-compiler-dpcpp-cpp libcurl4-openssl-dev
|
||||
|
||||
- name: install oneAPI MKL library
|
||||
shell: bash
|
||||
@@ -652,8 +646,6 @@ jobs:
|
||||
run: |
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
cmake -B build \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx
|
||||
@@ -680,7 +672,7 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo apt install intel-oneapi-compiler-dpcpp-cpp libssl-dev
|
||||
sudo apt install intel-oneapi-compiler-dpcpp-cpp libcurl4-openssl-dev
|
||||
|
||||
- name: install oneAPI MKL library
|
||||
shell: bash
|
||||
@@ -702,8 +694,6 @@ jobs:
|
||||
run: |
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
cmake -B build \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx \
|
||||
@@ -730,6 +720,12 @@ jobs:
|
||||
key: macOS-latest-cmake-ios
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
@@ -761,6 +757,12 @@ jobs:
|
||||
key: macOS-latest-cmake-tvos
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
@@ -786,6 +788,12 @@ jobs:
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
@@ -828,6 +836,12 @@ jobs:
|
||||
name: llama-xcframework
|
||||
path: build-apple/llama.xcframework/
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
|
||||
- name: Build llama.cpp with CMake
|
||||
id: cmake_build
|
||||
run: |
|
||||
@@ -979,12 +993,21 @@ jobs:
|
||||
-DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release"
|
||||
cmake --build build-arm64-release --target install --config release
|
||||
|
||||
- name: libCURL
|
||||
id: get_libcurl
|
||||
uses: ./.github/actions/windows-setup-curl
|
||||
with:
|
||||
architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }}
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
env:
|
||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
||||
run: |
|
||||
cmake -S . -B build ${{ matrix.defines }} `
|
||||
-DLLAMA_CURL=OFF -DLLAMA_BUILD_BORINGSSL=ON
|
||||
-DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include"
|
||||
cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS}
|
||||
cp $env:CURL_PATH/bin/libcurl-*.dll build/bin/Release
|
||||
|
||||
- name: Add libopenblas.dll
|
||||
id: add_libopenblas_dll
|
||||
@@ -1028,7 +1051,7 @@ jobs:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
run: |
|
||||
apt update
|
||||
apt install -y cmake build-essential ninja-build libgomp1 git libssl-dev
|
||||
apt install -y cmake build-essential ninja-build libgomp1 git libcurl4-openssl-dev
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1039,12 +1062,10 @@ jobs:
|
||||
- name: Build with CMake
|
||||
run: |
|
||||
cmake -S . -B build -G Ninja \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DCMAKE_CUDA_ARCHITECTURES=89-real \
|
||||
-DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DGGML_CUDA=ON
|
||||
cmake --build build
|
||||
@@ -1078,20 +1099,25 @@ jobs:
|
||||
run: |
|
||||
choco install ninja
|
||||
|
||||
- name: libCURL
|
||||
id: get_libcurl
|
||||
uses: ./.github/actions/windows-setup-curl
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
shell: cmd
|
||||
env:
|
||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
||||
run: |
|
||||
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64
|
||||
cmake -S . -B build -G "Ninja Multi-Config" ^
|
||||
-DLLAMA_BUILD_SERVER=ON ^
|
||||
-DLLAMA_CURL=OFF ^
|
||||
-DLLAMA_BUILD_BORINGSSL=ON ^
|
||||
-DGGML_NATIVE=OFF ^
|
||||
-DGGML_BACKEND_DL=ON ^
|
||||
-DGGML_CPU_ALL_VARIANTS=ON ^
|
||||
-DGGML_CUDA=ON ^
|
||||
-DGGML_RPC=ON
|
||||
-DGGML_RPC=ON ^
|
||||
-DCURL_LIBRARY="%CURL_PATH%/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="%CURL_PATH%/include"
|
||||
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
|
||||
cmake --build build --config Release -j %NINJA_JOBS% -t ggml
|
||||
cmake --build build --config Release
|
||||
@@ -1123,7 +1149,7 @@ jobs:
|
||||
run: |
|
||||
scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL
|
||||
|
||||
# TODO: add ssl support ; we will also need to modify win-build-sycl.bat to accept user-specified args
|
||||
# TODO: add libcurl support ; we will also need to modify win-build-sycl.bat to accept user-specified args
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
@@ -1180,8 +1206,14 @@ jobs:
|
||||
key: ${{ github.job }}
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: libCURL
|
||||
id: get_libcurl
|
||||
uses: ./.github/actions/windows-setup-curl
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
env:
|
||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
||||
run: |
|
||||
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
|
||||
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
|
||||
@@ -1190,12 +1222,11 @@ jobs:
|
||||
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
|
||||
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-${{ env.ROCM_VERSION }}/include/" `
|
||||
-DCMAKE_BUILD_TYPE=Release `
|
||||
-DLLAMA_CURL=OFF `
|
||||
-DLLAMA_BUILD_BORINGSSL=ON `
|
||||
-DROCM_DIR="${env:HIP_PATH}" `
|
||||
-DGGML_HIP=ON `
|
||||
-DGGML_HIP_ROCWMMA_FATTN=ON `
|
||||
-DGGML_RPC=ON
|
||||
-DGGML_RPC=ON `
|
||||
-DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include"
|
||||
cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
|
||||
|
||||
ios-xcode-build:
|
||||
@@ -1357,10 +1388,14 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
arch: [x86, aarch64]
|
||||
chip_type: ['910b', '310p']
|
||||
build: ['Release']
|
||||
cann:
|
||||
- '8.1.RC1.alpha001-910b-openeuler22.03-py3.10'
|
||||
device:
|
||||
- 'ascend910b3'
|
||||
build:
|
||||
- 'Release'
|
||||
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||
container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }}
|
||||
container: ascendai/cann:${{ matrix.cann }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -1377,7 +1412,7 @@ jobs:
|
||||
cmake -S . -B build \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }} \
|
||||
-DGGML_CANN=on \
|
||||
-DSOC_TYPE=ascend${{ matrix.chip_type }}
|
||||
-DSOC_TYPE=${{ matrix.device }}
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
# TODO: simplify the following workflows using a matrix
|
||||
@@ -1562,34 +1597,6 @@ jobs:
|
||||
run: |
|
||||
bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-x64-amd-vulkan:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
vulkaninfo --summary
|
||||
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-x64-amd-rocm:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
amd-smi static
|
||||
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-mac-metal:
|
||||
runs-on: [self-hosted, macOS, ARM64]
|
||||
|
||||
@@ -1642,50 +1649,3 @@ jobs:
|
||||
run: |
|
||||
GG_BUILD_KLEIDIAI=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
|
||||
|
||||
ggml-ci-arm64-graviton4-kleidiai:
|
||||
runs-on: ah-ubuntu_22_04-c8g_8x
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
set -euxo pipefail
|
||||
sudo apt-get update
|
||||
sudo DEBIAN_FRONTEND=noninteractive NEEDRESTART_MODE=a \
|
||||
apt-get install -y \
|
||||
build-essential \
|
||||
libcurl4-openssl-dev \
|
||||
python3-venv \
|
||||
gpg \
|
||||
wget \
|
||||
time \
|
||||
git-lfs
|
||||
|
||||
git lfs install
|
||||
|
||||
# install the latest cmake
|
||||
sudo install -d /usr/share/keyrings
|
||||
wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc \
|
||||
| gpg --dearmor \
|
||||
| sudo tee /usr/share/keyrings/kitware-archive-keyring.gpg >/dev/null
|
||||
echo 'deb [signed-by=/usr/share/keyrings/kitware-archive-keyring.gpg] https://apt.kitware.com/ubuntu/ jammy main' \
|
||||
| sudo tee /etc/apt/sources.list.d/kitware.list
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y cmake
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ggml-ci-arm64-graviton4-kleidiai
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
GG_BUILD_KLEIDIAI=1 \
|
||||
GG_BUILD_EXTRA_TESTS_0=1 \
|
||||
bash ./ci/run.sh ./tmp/results ./tmp/mnt
|
||||
|
||||
52
.github/workflows/check-vendor.yml
vendored
52
.github/workflows/check-vendor.yml
vendored
@@ -1,52 +0,0 @@
|
||||
name: Check vendor
|
||||
|
||||
on:
|
||||
workflow_dispatch: # allows manual triggering
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths: [
|
||||
'vendor/**',
|
||||
'scripts/sync_vendor.py'
|
||||
]
|
||||
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
paths: [
|
||||
'vendor/**',
|
||||
'scripts/sync_vendor.py'
|
||||
]
|
||||
|
||||
jobs:
|
||||
check-vendor:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.x'
|
||||
|
||||
- name: Run vendor sync
|
||||
run: |
|
||||
set -euo pipefail
|
||||
python3 scripts/sync_vendor.py
|
||||
|
||||
- name: Check for changes
|
||||
run: |
|
||||
set -euo pipefail
|
||||
# detect modified or untracked files
|
||||
changed=$(git status --porcelain --untracked-files=all || true)
|
||||
if [ -n "$changed" ]; then
|
||||
echo "Vendor sync modified files:"
|
||||
echo "$changed" | awk '{ print $2 }' | sed '/^$/d'
|
||||
echo "Failing because vendor files mismatch. Please update scripts/sync_vendor.py"
|
||||
exit 1
|
||||
else
|
||||
echo "Vendor files are up-to-date."
|
||||
fi
|
||||
46
.github/workflows/release.yml
vendored
46
.github/workflows/release.yml
vendored
@@ -693,51 +693,6 @@ jobs:
|
||||
path: llama-${{ steps.tag.outputs.name }}-xcframework.zip
|
||||
name: llama-${{ steps.tag.outputs.name }}-xcframework
|
||||
|
||||
openEuler-cann:
|
||||
strategy:
|
||||
matrix:
|
||||
arch: [x86, aarch64]
|
||||
chip_type: ['910b', '310p']
|
||||
build: ['Release']
|
||||
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||
container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
yum update -y
|
||||
yum install -y git gcc gcc-c++ make cmake libcurl-devel
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH}
|
||||
|
||||
cmake -S . -B build \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }} \
|
||||
-DGGML_CANN=on \
|
||||
-DSOC_TYPE=ascend${{ matrix.chip_type }}
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
- name: Determine tag name
|
||||
id: tag
|
||||
uses: ./.github/actions/get-tag-name
|
||||
|
||||
- name: Pack artifacts
|
||||
run: |
|
||||
cp LICENSE ./build/bin/
|
||||
zip -r llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip ./build/bin/*
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
|
||||
name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
|
||||
|
||||
release:
|
||||
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
|
||||
|
||||
@@ -759,7 +714,6 @@ jobs:
|
||||
- macOS-arm64
|
||||
- macOS-x64
|
||||
- ios-xcode-build
|
||||
- openEuler-cann
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
||||
27
.github/workflows/server.yml
vendored
27
.github/workflows/server.yml
vendored
@@ -56,7 +56,7 @@ jobs:
|
||||
curl \
|
||||
wget \
|
||||
language-pack-en \
|
||||
libssl-dev
|
||||
libcurl4-openssl-dev
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
@@ -209,7 +209,7 @@ jobs:
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Run UI tests
|
||||
run: npm run test:ui -- --testTimeout=60000
|
||||
run: npm run test:ui
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Run E2E tests
|
||||
@@ -242,7 +242,7 @@ jobs:
|
||||
curl \
|
||||
wget \
|
||||
language-pack-en \
|
||||
libssl-dev
|
||||
libcurl4-openssl-dev
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
@@ -283,8 +283,6 @@ jobs:
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DLLAMA_BUILD_SERVER=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||
@@ -297,8 +295,6 @@ jobs:
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DLLAMA_BUILD_SERVER=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ;
|
||||
@@ -310,8 +306,6 @@ jobs:
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_OPENSSL=ON \
|
||||
-DLLAMA_BUILD_SERVER=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ;
|
||||
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
|
||||
@@ -351,10 +345,16 @@ jobs:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: libCURL
|
||||
id: get_libcurl
|
||||
uses: ./.github/actions/windows-setup-curl
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
env:
|
||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
||||
run: |
|
||||
cmake -B build -DLLAMA_CURL=OFF -DLLAMA_BUILD_BORINGSSL=ON
|
||||
cmake -B build -DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include"
|
||||
cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
|
||||
|
||||
- name: Python setup
|
||||
@@ -368,6 +368,13 @@ jobs:
|
||||
run: |
|
||||
pip install -r tools/server/tests/requirements.txt
|
||||
|
||||
- name: Copy Libcurl
|
||||
id: prepare_libcurl
|
||||
env:
|
||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
||||
run: |
|
||||
cp $env:CURL_PATH/bin/libcurl-x64.dll ./build/bin/Release/libcurl-x64.dll
|
||||
|
||||
- name: Tests
|
||||
id: server_integration_tests
|
||||
if: ${{ !matrix.disabled_on_pr || !github.event.pull_request }}
|
||||
|
||||
108
.gitignore
vendored
108
.gitignore
vendored
@@ -20,40 +20,52 @@
|
||||
*.so
|
||||
*.swp
|
||||
*.tmp
|
||||
*.DS_Store
|
||||
|
||||
# IDE / OS
|
||||
|
||||
/.cache/
|
||||
/.ccls-cache/
|
||||
/.direnv/
|
||||
/.envrc
|
||||
/.idea/
|
||||
/.swiftpm
|
||||
/.vs/
|
||||
/.vscode/
|
||||
/nppBackup
|
||||
.cache/
|
||||
.ccls-cache/
|
||||
.direnv/
|
||||
.DS_Store
|
||||
.envrc
|
||||
.idea/
|
||||
.swiftpm
|
||||
.vs/
|
||||
.vscode/
|
||||
nppBackup
|
||||
|
||||
|
||||
# Coverage
|
||||
|
||||
/gcovr-report/
|
||||
/lcov-report/
|
||||
gcovr-report/
|
||||
lcov-report/
|
||||
|
||||
# Build Artifacts
|
||||
|
||||
/tags
|
||||
/.build/
|
||||
/build*
|
||||
/release
|
||||
/debug
|
||||
tags
|
||||
.build/
|
||||
build*
|
||||
release
|
||||
debug
|
||||
!build-info.cmake
|
||||
!build-info.cpp.in
|
||||
!build-info.sh
|
||||
!build.zig
|
||||
!docs/build.md
|
||||
/libllama.so
|
||||
/llama-*
|
||||
/vulkan-shaders-gen
|
||||
android-ndk-*
|
||||
arm_neon.h
|
||||
cmake-build-*
|
||||
CMakeSettings.json
|
||||
compile_commands.json
|
||||
ggml-metal-embed.metal
|
||||
llama-batched-swift
|
||||
/rpc-server
|
||||
/out/
|
||||
/tmp/
|
||||
/autogen-*.md
|
||||
out/
|
||||
tmp/
|
||||
autogen-*.md
|
||||
|
||||
# Deprecated
|
||||
|
||||
@@ -62,38 +74,44 @@
|
||||
|
||||
# CI
|
||||
|
||||
!/.github/workflows/*.yml
|
||||
!.github/workflows/*.yml
|
||||
|
||||
# Models
|
||||
|
||||
/models/*
|
||||
/models-mnt
|
||||
!/models/.editorconfig
|
||||
!/models/ggml-vocab-*.gguf*
|
||||
!/models/templates
|
||||
models/*
|
||||
models-mnt
|
||||
!models/.editorconfig
|
||||
!models/ggml-vocab-*.gguf*
|
||||
!models/templates
|
||||
|
||||
# Zig
|
||||
/zig-out/
|
||||
/zig-cache/
|
||||
zig-out/
|
||||
zig-cache/
|
||||
|
||||
# Logs
|
||||
|
||||
ppl-*.txt
|
||||
qnt-*.txt
|
||||
perf-*.txt
|
||||
|
||||
# Examples
|
||||
|
||||
/examples/jeopardy/results.txt
|
||||
/tools/server/*.css.hpp
|
||||
/tools/server/*.html.hpp
|
||||
/tools/server/*.js.hpp
|
||||
/tools/server/*.mjs.hpp
|
||||
/tools/server/*.gz.hpp
|
||||
!/build_64.sh
|
||||
!/examples/*.bat
|
||||
!/examples/*/*.kts
|
||||
!/examples/*/*/*.kts
|
||||
!/examples/sycl/*.bat
|
||||
!/examples/sycl/*.sh
|
||||
examples/jeopardy/results.txt
|
||||
tools/server/*.css.hpp
|
||||
tools/server/*.html.hpp
|
||||
tools/server/*.js.hpp
|
||||
tools/server/*.mjs.hpp
|
||||
tools/server/*.gz.hpp
|
||||
!build_64.sh
|
||||
!examples/*.bat
|
||||
!examples/*/*.kts
|
||||
!examples/*/*/*.kts
|
||||
!examples/sycl/*.bat
|
||||
!examples/sycl/*.sh
|
||||
|
||||
# Server Web UI temporary files
|
||||
/tools/server/webui/node_modules
|
||||
/tools/server/webui/dist
|
||||
node_modules
|
||||
tools/server/webui/dist
|
||||
|
||||
# Python
|
||||
|
||||
@@ -129,8 +147,8 @@ poetry.toml
|
||||
# Local scripts
|
||||
/run-vim.sh
|
||||
/run-chat.sh
|
||||
/.ccache/
|
||||
.ccache/
|
||||
|
||||
# IDE
|
||||
/*.code-workspace
|
||||
/.windsurf/
|
||||
*.code-workspace
|
||||
.windsurf/
|
||||
|
||||
@@ -92,7 +92,6 @@ option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_
|
||||
|
||||
# 3rd party libs
|
||||
option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON)
|
||||
option(LLAMA_HTTPLIB "llama: if libcurl is disabled, use httplib to download model from an URL" ON)
|
||||
option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" OFF)
|
||||
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
|
||||
|
||||
@@ -201,9 +200,6 @@ endif()
|
||||
|
||||
if (LLAMA_BUILD_COMMON)
|
||||
add_subdirectory(common)
|
||||
if (LLAMA_HTTPLIB)
|
||||
add_subdirectory(vendor/cpp-httplib)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
|
||||
|
||||
@@ -61,7 +61,6 @@ range of hardware - locally and in the cloud.
|
||||
- Plain C/C++ implementation without any dependencies
|
||||
- Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
|
||||
- AVX, AVX2, AVX512 and AMX support for x86 architectures
|
||||
- RVV, ZVFH, ZFH and ZICBOP support for RISC-V architectures
|
||||
- 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use
|
||||
- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads GPUs via MUSA)
|
||||
- Vulkan and SYCL backend support
|
||||
@@ -242,7 +241,6 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
|
||||
- [gpustack/gguf-parser](https://github.com/gpustack/gguf-parser-go/tree/main/cmd/gguf-parser) - review/check the GGUF file and estimate the memory usage
|
||||
- [Styled Lines](https://marketplace.unity.com/packages/tools/generative-ai/styled-lines-llama-cpp-model-292902) (proprietary licensed, async wrapper of inference part for game development in Unity3d with pre-built Mobile and Web platform wrappers and a model example)
|
||||
- [unslothai/unsloth](https://github.com/unslothai/unsloth) – 🦥 exports/saves fine-tuned and trained models to GGUF (Apache-2.0)
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +0,0 @@
|
||||
{
|
||||
"chars": 2296.1916666666666,
|
||||
"chars:std": 986.051306946325,
|
||||
"score": 0.925,
|
||||
"score:std": 0.26339134382131846
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
@@ -1,264 +0,0 @@
|
||||
## System info
|
||||
|
||||
```bash
|
||||
uname --all
|
||||
Linux spark-17ed 6.11.0-1016-nvidia #16-Ubuntu SMP PREEMPT_DYNAMIC Sun Sep 21 16:52:46 UTC 2025 aarch64 aarch64 aarch64 GNU/Linux
|
||||
|
||||
g++ --version
|
||||
g++ (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
|
||||
|
||||
nvidia-smi
|
||||
Sun Nov 2 10:43:25 2025
|
||||
+-----------------------------------------------------------------------------------------+
|
||||
| NVIDIA-SMI 580.95.05 Driver Version: 580.95.05 CUDA Version: 13.0 |
|
||||
+-----------------------------------------+------------------------+----------------------+
|
||||
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
|
||||
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
|
||||
| | | MIG M. |
|
||||
|=========================================+========================+======================|
|
||||
| 0 NVIDIA GB10 On | 0000000F:01:00.0 Off | N/A |
|
||||
| N/A 35C P8 4W / N/A | Not Supported | 0% Default |
|
||||
| | | N/A |
|
||||
+-----------------------------------------+------------------------+----------------------+
|
||||
```
|
||||
|
||||
## ggml-org/gpt-oss-20b-GGUF
|
||||
|
||||
Model: https://huggingface.co/ggml-org/gpt-oss-20b-GGUF
|
||||
|
||||
- `llama-batched-bench`
|
||||
|
||||
|
||||
main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20
|
||||
|
||||
| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|
||||
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|
||||
| 512 | 32 | 1 | 544 | 0.374 | 1369.01 | 0.383 | 83.64 | 0.757 | 719.01 |
|
||||
| 512 | 32 | 2 | 1088 | 0.274 | 3741.35 | 0.659 | 97.14 | 0.933 | 1166.66 |
|
||||
| 512 | 32 | 4 | 2176 | 0.526 | 3896.47 | 0.817 | 156.73 | 1.342 | 1621.08 |
|
||||
| 512 | 32 | 8 | 4352 | 1.044 | 3925.10 | 0.987 | 259.44 | 2.030 | 2143.56 |
|
||||
| 512 | 32 | 16 | 8704 | 2.076 | 3945.84 | 1.248 | 410.32 | 3.324 | 2618.60 |
|
||||
| 512 | 32 | 32 | 17408 | 4.170 | 3929.28 | 1.630 | 628.40 | 5.799 | 3001.76 |
|
||||
| 4096 | 32 | 1 | 4128 | 1.083 | 3782.66 | 0.394 | 81.21 | 1.477 | 2795.13 |
|
||||
| 4096 | 32 | 2 | 8256 | 2.166 | 3782.72 | 0.725 | 88.28 | 2.891 | 2856.14 |
|
||||
| 4096 | 32 | 4 | 16512 | 4.333 | 3780.88 | 0.896 | 142.82 | 5.230 | 3157.38 |
|
||||
| 4096 | 32 | 8 | 33024 | 8.618 | 3802.14 | 1.155 | 221.69 | 9.773 | 3379.08 |
|
||||
| 4096 | 32 | 16 | 66048 | 17.330 | 3781.73 | 1.598 | 320.34 | 18.928 | 3489.45 |
|
||||
| 4096 | 32 | 32 | 132096 | 34.671 | 3780.48 | 2.336 | 438.35 | 37.007 | 3569.51 |
|
||||
| 8192 | 32 | 1 | 8224 | 2.233 | 3668.56 | 0.438 | 72.98 | 2.671 | 3078.44 |
|
||||
| 8192 | 32 | 2 | 16448 | 4.425 | 3702.95 | 0.756 | 84.66 | 5.181 | 3174.95 |
|
||||
| 8192 | 32 | 4 | 32896 | 8.859 | 3698.64 | 0.967 | 132.38 | 9.826 | 3347.72 |
|
||||
| 8192 | 32 | 8 | 65792 | 17.714 | 3699.57 | 1.277 | 200.52 | 18.991 | 3464.35 |
|
||||
| 8192 | 32 | 16 | 131584 | 35.494 | 3692.84 | 1.841 | 278.12 | 37.335 | 3524.46 |
|
||||
| 8192 | 32 | 32 | 263168 | 70.949 | 3694.82 | 2.798 | 365.99 | 73.747 | 3568.53 |
|
||||
|
||||
|
||||
- `llama-bench`
|
||||
|
||||
| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
|
||||
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
|
||||
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 3714.25 ± 20.36 |
|
||||
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 86.58 ± 0.43 |
|
||||
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 3445.17 ± 17.85 |
|
||||
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 81.72 ± 0.53 |
|
||||
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 3218.78 ± 11.34 |
|
||||
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 74.86 ± 0.64 |
|
||||
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 2732.83 ± 7.17 |
|
||||
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 71.57 ± 0.51 |
|
||||
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 2119.75 ± 12.81 |
|
||||
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 62.33 ± 0.24 |
|
||||
|
||||
build: eeee367de (6989)
|
||||
|
||||
## ggml-org/gpt-oss-120b-GGUF
|
||||
|
||||
Model: https://huggingface.co/ggml-org/gpt-oss-120b-GGUF
|
||||
|
||||
- `llama-batched-bench`
|
||||
|
||||
|
||||
main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20
|
||||
|
||||
| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|
||||
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|
||||
| 512 | 32 | 1 | 544 | 0.571 | 897.18 | 0.543 | 58.96 | 1.113 | 488.60 |
|
||||
| 512 | 32 | 2 | 1088 | 0.593 | 1725.37 | 1.041 | 61.45 | 1.635 | 665.48 |
|
||||
| 512 | 32 | 4 | 2176 | 1.043 | 1963.15 | 1.334 | 95.95 | 2.377 | 915.36 |
|
||||
| 512 | 32 | 8 | 4352 | 2.099 | 1951.63 | 1.717 | 149.07 | 3.816 | 1140.45 |
|
||||
| 512 | 32 | 16 | 8704 | 4.207 | 1947.12 | 2.311 | 221.56 | 6.518 | 1335.35 |
|
||||
| 512 | 32 | 32 | 17408 | 8.422 | 1945.36 | 3.298 | 310.46 | 11.720 | 1485.27 |
|
||||
| 4096 | 32 | 1 | 4128 | 2.138 | 1915.88 | 0.571 | 56.09 | 2.708 | 1524.12 |
|
||||
| 4096 | 32 | 2 | 8256 | 4.266 | 1920.25 | 1.137 | 56.27 | 5.404 | 1527.90 |
|
||||
| 4096 | 32 | 4 | 16512 | 8.564 | 1913.02 | 1.471 | 86.99 | 10.036 | 1645.29 |
|
||||
| 4096 | 32 | 8 | 33024 | 17.092 | 1917.19 | 1.979 | 129.33 | 19.071 | 1731.63 |
|
||||
| 4096 | 32 | 16 | 66048 | 34.211 | 1915.65 | 2.850 | 179.66 | 37.061 | 1782.15 |
|
||||
| 4096 | 32 | 32 | 132096 | 68.394 | 1916.44 | 4.381 | 233.72 | 72.775 | 1815.13 |
|
||||
| 8192 | 32 | 1 | 8224 | 4.349 | 1883.45 | 0.620 | 51.65 | 4.969 | 1655.04 |
|
||||
| 8192 | 32 | 2 | 16448 | 8.674 | 1888.83 | 1.178 | 54.33 | 9.852 | 1669.48 |
|
||||
| 8192 | 32 | 4 | 32896 | 17.351 | 1888.55 | 1.580 | 81.01 | 18.931 | 1737.68 |
|
||||
| 8192 | 32 | 8 | 65792 | 34.743 | 1886.31 | 2.173 | 117.80 | 36.916 | 1782.20 |
|
||||
| 8192 | 32 | 16 | 131584 | 69.413 | 1888.29 | 3.297 | 155.28 | 72.710 | 1809.70 |
|
||||
| 8192 | 32 | 32 | 263168 | 138.903 | 1887.24 | 5.004 | 204.63 | 143.907 | 1828.73 |
|
||||
|
||||
|
||||
- `llama-bench`
|
||||
|
||||
| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
|
||||
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
|
||||
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 1919.36 ± 5.01 |
|
||||
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 60.40 ± 0.30 |
|
||||
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 1825.30 ± 6.37 |
|
||||
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 56.94 ± 0.29 |
|
||||
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 1739.19 ± 6.00 |
|
||||
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 52.51 ± 0.42 |
|
||||
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 1536.75 ± 4.27 |
|
||||
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 49.33 ± 0.27 |
|
||||
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 1255.85 ± 3.26 |
|
||||
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 42.99 ± 0.18 |
|
||||
|
||||
build: eeee367de (6989)
|
||||
|
||||
## ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF
|
||||
|
||||
Model: https://huggingface.co/ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF
|
||||
|
||||
- `llama-batched-bench`
|
||||
|
||||
|
||||
main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20
|
||||
|
||||
| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|
||||
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|
||||
| 512 | 32 | 1 | 544 | 0.398 | 1285.90 | 0.530 | 60.41 | 0.928 | 586.27 |
|
||||
| 512 | 32 | 2 | 1088 | 0.386 | 2651.65 | 0.948 | 67.50 | 1.334 | 815.38 |
|
||||
| 512 | 32 | 4 | 2176 | 0.666 | 3076.37 | 1.209 | 105.87 | 1.875 | 1160.71 |
|
||||
| 512 | 32 | 8 | 4352 | 1.325 | 3091.39 | 1.610 | 158.98 | 2.935 | 1482.65 |
|
||||
| 512 | 32 | 16 | 8704 | 2.664 | 3075.58 | 2.150 | 238.19 | 4.813 | 1808.39 |
|
||||
| 512 | 32 | 32 | 17408 | 5.336 | 3070.31 | 2.904 | 352.59 | 8.240 | 2112.50 |
|
||||
| 4096 | 32 | 1 | 4128 | 1.444 | 2836.81 | 0.581 | 55.09 | 2.025 | 2038.81 |
|
||||
| 4096 | 32 | 2 | 8256 | 2.872 | 2852.14 | 1.084 | 59.06 | 3.956 | 2086.99 |
|
||||
| 4096 | 32 | 4 | 16512 | 5.744 | 2852.32 | 1.440 | 88.90 | 7.184 | 2298.47 |
|
||||
| 4096 | 32 | 8 | 33024 | 11.463 | 2858.68 | 2.068 | 123.78 | 13.531 | 2440.65 |
|
||||
| 4096 | 32 | 16 | 66048 | 22.915 | 2859.95 | 3.018 | 169.67 | 25.933 | 2546.90 |
|
||||
| 4096 | 32 | 32 | 132096 | 45.956 | 2852.10 | 4.609 | 222.18 | 50.565 | 2612.39 |
|
||||
| 8192 | 32 | 1 | 8224 | 3.063 | 2674.72 | 0.693 | 46.20 | 3.755 | 2189.92 |
|
||||
| 8192 | 32 | 2 | 16448 | 6.109 | 2681.87 | 1.214 | 52.71 | 7.323 | 2245.98 |
|
||||
| 8192 | 32 | 4 | 32896 | 12.197 | 2686.63 | 1.682 | 76.11 | 13.878 | 2370.30 |
|
||||
| 8192 | 32 | 8 | 65792 | 24.409 | 2684.94 | 2.556 | 100.17 | 26.965 | 2439.95 |
|
||||
| 8192 | 32 | 16 | 131584 | 48.753 | 2688.50 | 3.994 | 128.20 | 52.747 | 2494.64 |
|
||||
| 8192 | 32 | 32 | 263168 | 97.508 | 2688.42 | 6.528 | 156.86 | 104.037 | 2529.57 |
|
||||
|
||||
|
||||
- `llama-bench`
|
||||
|
||||
| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
|
||||
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
|
||||
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 2925.55 ± 4.25 |
|
||||
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 62.80 ± 0.27 |
|
||||
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 2531.01 ± 6.79 |
|
||||
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 55.86 ± 0.33 |
|
||||
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 2244.39 ± 5.33 |
|
||||
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 45.95 ± 0.33 |
|
||||
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 1783.17 ± 3.68 |
|
||||
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 39.07 ± 0.10 |
|
||||
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 1241.90 ± 3.13 |
|
||||
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 29.92 ± 0.06 |
|
||||
|
||||
build: eeee367de (6989)
|
||||
|
||||
## ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF
|
||||
|
||||
Model: https://huggingface.co/ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF
|
||||
|
||||
- `llama-batched-bench`
|
||||
|
||||
|
||||
main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20
|
||||
|
||||
| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|
||||
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|
||||
| 512 | 32 | 1 | 544 | 0.211 | 2421.57 | 1.055 | 30.33 | 1.266 | 429.57 |
|
||||
| 512 | 32 | 2 | 1088 | 0.419 | 2441.34 | 1.130 | 56.65 | 1.549 | 702.32 |
|
||||
| 512 | 32 | 4 | 2176 | 0.873 | 2345.54 | 1.174 | 108.99 | 2.048 | 1062.74 |
|
||||
| 512 | 32 | 8 | 4352 | 1.727 | 2371.85 | 1.254 | 204.22 | 2.980 | 1460.19 |
|
||||
| 512 | 32 | 16 | 8704 | 3.452 | 2373.22 | 1.492 | 343.16 | 4.944 | 1760.56 |
|
||||
| 512 | 32 | 32 | 17408 | 6.916 | 2368.93 | 1.675 | 611.51 | 8.591 | 2026.36 |
|
||||
| 4096 | 32 | 1 | 4128 | 1.799 | 2277.26 | 1.084 | 29.51 | 2.883 | 1431.91 |
|
||||
| 4096 | 32 | 2 | 8256 | 3.577 | 2290.01 | 1.196 | 53.50 | 4.774 | 1729.51 |
|
||||
| 4096 | 32 | 4 | 16512 | 7.172 | 2284.36 | 1.313 | 97.50 | 8.485 | 1946.00 |
|
||||
| 4096 | 32 | 8 | 33024 | 14.341 | 2284.96 | 1.520 | 168.46 | 15.860 | 2082.18 |
|
||||
| 4096 | 32 | 16 | 66048 | 28.675 | 2285.44 | 1.983 | 258.21 | 30.658 | 2154.33 |
|
||||
| 4096 | 32 | 32 | 132096 | 57.354 | 2285.32 | 2.640 | 387.87 | 59.994 | 2201.82 |
|
||||
| 8192 | 32 | 1 | 8224 | 3.701 | 2213.75 | 1.119 | 28.59 | 4.820 | 1706.34 |
|
||||
| 8192 | 32 | 2 | 16448 | 7.410 | 2211.19 | 1.272 | 50.31 | 8.682 | 1894.56 |
|
||||
| 8192 | 32 | 4 | 32896 | 14.802 | 2213.83 | 1.460 | 87.68 | 16.261 | 2022.96 |
|
||||
| 8192 | 32 | 8 | 65792 | 29.609 | 2213.35 | 1.781 | 143.74 | 31.390 | 2095.93 |
|
||||
| 8192 | 32 | 16 | 131584 | 59.229 | 2212.96 | 2.495 | 205.17 | 61.725 | 2131.79 |
|
||||
| 8192 | 32 | 32 | 263168 | 118.449 | 2213.15 | 3.714 | 275.75 | 122.162 | 2154.25 |
|
||||
|
||||
|
||||
- `llama-bench`
|
||||
|
||||
| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
|
||||
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
|
||||
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 2272.74 ± 4.68 |
|
||||
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 30.66 ± 0.02 |
|
||||
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 2107.80 ± 9.55 |
|
||||
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 29.71 ± 0.05 |
|
||||
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 1937.80 ± 6.75 |
|
||||
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 28.86 ± 0.04 |
|
||||
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 1641.12 ± 1.78 |
|
||||
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 27.24 ± 0.04 |
|
||||
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 1296.02 ± 2.67 |
|
||||
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 23.78 ± 0.03 |
|
||||
|
||||
build: eeee367de (6989)
|
||||
|
||||
## ggml-org/gemma-3-4b-it-qat-GGUF
|
||||
|
||||
Model: https://huggingface.co/ggml-org/gemma-3-4b-it-qat-GGUF
|
||||
|
||||
- `llama-batched-bench`
|
||||
|
||||
|
||||
main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20
|
||||
|
||||
| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|
||||
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|
||||
| 512 | 32 | 1 | 544 | 0.094 | 5434.73 | 0.394 | 81.21 | 0.488 | 1114.15 |
|
||||
| 512 | 32 | 2 | 1088 | 0.168 | 6091.68 | 0.498 | 128.52 | 0.666 | 1633.41 |
|
||||
| 512 | 32 | 4 | 2176 | 0.341 | 6010.68 | 0.542 | 236.37 | 0.882 | 2466.43 |
|
||||
| 512 | 32 | 8 | 4352 | 0.665 | 6161.46 | 0.678 | 377.74 | 1.342 | 3241.72 |
|
||||
| 512 | 32 | 16 | 8704 | 1.323 | 6193.19 | 0.902 | 567.41 | 2.225 | 3911.74 |
|
||||
| 512 | 32 | 32 | 17408 | 2.642 | 6202.03 | 1.231 | 832.03 | 3.872 | 4495.36 |
|
||||
| 4096 | 32 | 1 | 4128 | 0.701 | 5840.49 | 0.439 | 72.95 | 1.140 | 3621.23 |
|
||||
| 4096 | 32 | 2 | 8256 | 1.387 | 5906.82 | 0.574 | 111.48 | 1.961 | 4210.12 |
|
||||
| 4096 | 32 | 4 | 16512 | 2.758 | 5940.33 | 0.651 | 196.58 | 3.409 | 4843.33 |
|
||||
| 4096 | 32 | 8 | 33024 | 5.491 | 5967.56 | 0.876 | 292.40 | 6.367 | 5187.12 |
|
||||
| 4096 | 32 | 16 | 66048 | 10.978 | 5969.58 | 1.275 | 401.69 | 12.253 | 5390.38 |
|
||||
| 4096 | 32 | 32 | 132096 | 21.944 | 5972.93 | 1.992 | 514.16 | 23.936 | 5518.73 |
|
||||
| 8192 | 32 | 1 | 8224 | 1.402 | 5841.91 | 0.452 | 70.73 | 1.855 | 4434.12 |
|
||||
| 8192 | 32 | 2 | 16448 | 2.793 | 5865.34 | 0.637 | 100.55 | 3.430 | 4795.51 |
|
||||
| 8192 | 32 | 4 | 32896 | 5.564 | 5889.64 | 0.770 | 166.26 | 6.334 | 5193.95 |
|
||||
| 8192 | 32 | 8 | 65792 | 11.114 | 5896.44 | 1.122 | 228.07 | 12.237 | 5376.51 |
|
||||
| 8192 | 32 | 16 | 131584 | 22.210 | 5901.38 | 1.789 | 286.15 | 24.000 | 5482.74 |
|
||||
| 8192 | 32 | 32 | 263168 | 44.382 | 5906.56 | 3.044 | 336.38 | 47.426 | 5549.02 |
|
||||
|
||||
|
||||
- `llama-bench`
|
||||
|
||||
| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
|
||||
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
|
||||
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 5810.04 ± 21.71 |
|
||||
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 84.54 ± 0.18 |
|
||||
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 5288.04 ± 3.54 |
|
||||
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 78.82 ± 1.37 |
|
||||
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 4960.43 ± 16.64 |
|
||||
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 74.13 ± 0.30 |
|
||||
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 4495.92 ± 31.11 |
|
||||
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 72.37 ± 0.29 |
|
||||
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 3746.90 ± 40.01 |
|
||||
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 63.02 ± 0.20 |
|
||||
|
||||
build: eeee367de (6989)
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -454,8 +454,6 @@ cmake -B build-visionos -G Xcode \
|
||||
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_HTTPLIB=OFF \
|
||||
-DLLAMA_BUILD_SERVER=OFF \
|
||||
-S .
|
||||
cmake --build build-visionos --config Release -- -quiet
|
||||
|
||||
@@ -470,8 +468,6 @@ cmake -B build-visionos-sim -G Xcode \
|
||||
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DLLAMA_HTTPLIB=OFF \
|
||||
-DLLAMA_BUILD_SERVER=OFF \
|
||||
-S .
|
||||
cmake --build build-visionos-sim --config Release -- -quiet
|
||||
|
||||
|
||||
@@ -121,12 +121,7 @@ fi
|
||||
if [ -n "${GG_BUILD_KLEIDIAI}" ]; then
|
||||
echo ">>===== Enabling KleidiAI support"
|
||||
|
||||
CANDIDATES=(
|
||||
"armv9-a+dotprod+i8mm+sve2"
|
||||
"armv9-a+dotprod+i8mm"
|
||||
"armv8.6-a+dotprod+i8mm"
|
||||
"armv8.2-a+dotprod"
|
||||
)
|
||||
CANDIDATES=("armv9-a+dotprod+i8mm" "armv8.6-a+dotprod+i8mm" "armv8.2-a+dotprod")
|
||||
CPU=""
|
||||
|
||||
for cpu in "${CANDIDATES[@]}"; do
|
||||
|
||||
@@ -50,16 +50,12 @@ add_library(${TARGET} STATIC
|
||||
base64.hpp
|
||||
chat-parser.cpp
|
||||
chat-parser.h
|
||||
chat-parser-xml-toolcall.h
|
||||
chat-parser-xml-toolcall.cpp
|
||||
chat.cpp
|
||||
chat.h
|
||||
common.cpp
|
||||
common.h
|
||||
console.cpp
|
||||
console.h
|
||||
download.cpp
|
||||
download.h
|
||||
http.h
|
||||
json-partial.cpp
|
||||
json-partial.h
|
||||
@@ -81,11 +77,10 @@ if (BUILD_SHARED_LIBS)
|
||||
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
endif()
|
||||
|
||||
# TODO: use list(APPEND LLAMA_COMMON_EXTRA_LIBS ...)
|
||||
set(LLAMA_COMMON_EXTRA_LIBS build_info)
|
||||
|
||||
# Use curl to download model url
|
||||
if (LLAMA_CURL)
|
||||
# Use curl to download model url
|
||||
find_package(CURL)
|
||||
if (NOT CURL_FOUND)
|
||||
message(FATAL_ERROR "Could NOT find CURL. Hint: to disable this feature, set -DLLAMA_CURL=OFF")
|
||||
@@ -93,10 +88,42 @@ if (LLAMA_CURL)
|
||||
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL)
|
||||
include_directories(${CURL_INCLUDE_DIRS})
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARIES})
|
||||
elseif (LLAMA_HTTPLIB)
|
||||
# otherwise, use cpp-httplib
|
||||
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_HTTPLIB)
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib)
|
||||
endif()
|
||||
|
||||
if (LLAMA_OPENSSL)
|
||||
find_package(OpenSSL)
|
||||
if (OpenSSL_FOUND)
|
||||
include(CheckCSourceCompiles)
|
||||
set(SAVED_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES})
|
||||
set(CMAKE_REQUIRED_INCLUDES ${OPENSSL_INCLUDE_DIR})
|
||||
check_c_source_compiles("
|
||||
#include <openssl/opensslv.h>
|
||||
#if defined(OPENSSL_IS_BORINGSSL) || defined(LIBRESSL_VERSION_NUMBER)
|
||||
# if OPENSSL_VERSION_NUMBER < 0x1010107f
|
||||
# error bad version
|
||||
# endif
|
||||
#else
|
||||
# if OPENSSL_VERSION_NUMBER < 0x30000000L
|
||||
# error bad version
|
||||
# endif
|
||||
#endif
|
||||
int main() { return 0; }
|
||||
" OPENSSL_VERSION_SUPPORTED)
|
||||
set(CMAKE_REQUIRED_INCLUDES ${SAVED_CMAKE_REQUIRED_INCLUDES})
|
||||
if (OPENSSL_VERSION_SUPPORTED)
|
||||
message(STATUS "OpenSSL found: ${OPENSSL_VERSION}")
|
||||
target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_OPENSSL_SUPPORT)
|
||||
target_link_libraries(${TARGET} PUBLIC OpenSSL::SSL OpenSSL::Crypto)
|
||||
if (APPLE AND CMAKE_SYSTEM_NAME STREQUAL "Darwin")
|
||||
target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
||||
find_library(CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED)
|
||||
find_library(SECURITY_FRAMEWORK Security REQUIRED)
|
||||
target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK})
|
||||
endif()
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "OpenSSL not found, SSL support disabled")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (LLAMA_LLGUIDANCE)
|
||||
|
||||
1025
common/arg.cpp
1025
common/arg.cpp
File diff suppressed because it is too large
Load Diff
@@ -59,8 +59,8 @@ struct common_arg {
|
||||
common_arg & set_sparam();
|
||||
bool in_example(enum llama_example ex);
|
||||
bool is_exclude(enum llama_example ex);
|
||||
bool get_value_from_env(std::string & output) const;
|
||||
bool has_value_from_env() const;
|
||||
bool get_value_from_env(std::string & output);
|
||||
bool has_value_from_env();
|
||||
std::string to_string();
|
||||
};
|
||||
|
||||
|
||||
@@ -1,861 +0,0 @@
|
||||
#include "chat.h"
|
||||
#include "chat-parser.h"
|
||||
#include "common.h"
|
||||
#include "json-partial.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
class xml_toolcall_syntax_exception : public std::runtime_error {
|
||||
public:
|
||||
xml_toolcall_syntax_exception(const std::string & message) : std::runtime_error(message) {}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
inline void sort_uniq(std::vector<T> &vec) {
|
||||
std::sort(vec.begin(), vec.end());
|
||||
vec.erase(std::unique(vec.begin(), vec.end()), vec.end());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline bool all_space(const T &str) {
|
||||
return std::all_of(str.begin(), str.end(), [](unsigned char ch) { return std::isspace(ch); });
|
||||
}
|
||||
|
||||
static size_t utf8_truncate_safe(const std::string_view s) {
|
||||
size_t len = s.size();
|
||||
if (len == 0) return 0;
|
||||
size_t i = len;
|
||||
for (size_t back = 0; back < 4 && i > 0; ++back) {
|
||||
--i;
|
||||
unsigned char c = s[i];
|
||||
if ((c & 0x80) == 0) {
|
||||
return len;
|
||||
} else if ((c & 0xC0) == 0xC0) {
|
||||
size_t expected_len = 0;
|
||||
if ((c & 0xE0) == 0xC0) expected_len = 2;
|
||||
else if ((c & 0xF0) == 0xE0) expected_len = 3;
|
||||
else if ((c & 0xF8) == 0xF0) expected_len = 4;
|
||||
else return i;
|
||||
if (len - i >= expected_len) {
|
||||
return len;
|
||||
} else {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
return len - std::min(len, size_t(3));
|
||||
}
|
||||
|
||||
inline void utf8_truncate_safe_resize(std::string &s) {
|
||||
s.resize(utf8_truncate_safe(s));
|
||||
}
|
||||
|
||||
inline std::string_view utf8_truncate_safe_view(const std::string_view s) {
|
||||
return s.substr(0, utf8_truncate_safe(s));
|
||||
}
|
||||
|
||||
static std::optional<common_chat_msg_parser::find_regex_result> try_find_2_literal_splited_by_spaces(common_chat_msg_parser & builder, const std::string & literal1, const std::string & literal2) {
|
||||
if (literal1.size() == 0) return builder.try_find_literal(literal2);
|
||||
const auto saved_pos = builder.pos();
|
||||
while (auto res = builder.try_find_literal(literal1)) {
|
||||
builder.consume_spaces();
|
||||
const auto match_len = std::min(literal2.size(), builder.input().size() - builder.pos());
|
||||
if (builder.input().compare(builder.pos(), match_len, literal2, 0, match_len) == 0) {
|
||||
if (res->prelude.size() != res->groups[0].begin - saved_pos) {
|
||||
res->prelude = builder.str({saved_pos, res->groups[0].begin});
|
||||
}
|
||||
builder.move_to(builder.pos() + match_len);
|
||||
res->groups[0].end = builder.pos();
|
||||
GGML_ASSERT(res->groups[0].begin != res->groups[0].end);
|
||||
return res;
|
||||
}
|
||||
builder.move_to(res->groups[0].begin + 1);
|
||||
}
|
||||
builder.move_to(saved_pos);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
/**
|
||||
* make a GBNF that accept any strings except those containing any of the forbidden strings.
|
||||
*/
|
||||
std::string make_gbnf_excluding(std::vector<std::string> forbids) {
|
||||
constexpr auto charclass_escape = [](unsigned char c) -> std::string {
|
||||
if (c == '\\' || c == ']' || c == '^' || c == '-') {
|
||||
std::string s = "\\";
|
||||
s.push_back((char)c);
|
||||
return s;
|
||||
}
|
||||
if (isprint(c)) {
|
||||
return std::string(1, (char)c);
|
||||
}
|
||||
char buf[16];
|
||||
snprintf(buf, 15, "\\x%02X", c);
|
||||
return std::string(buf);
|
||||
};
|
||||
constexpr auto build_expr = [charclass_escape](auto self, const std::vector<std::string>& forbids, int l, int r, int depth) -> std::string {
|
||||
std::vector<std::pair<unsigned char, std::pair<int,int>>> children;
|
||||
int i = l;
|
||||
while (i < r) {
|
||||
const std::string &s = forbids[i];
|
||||
if ((int)s.size() == depth) {
|
||||
++i;
|
||||
continue;
|
||||
}
|
||||
unsigned char c = (unsigned char)s[depth];
|
||||
int j = i;
|
||||
while (j < r && (int)forbids[j].size() > depth &&
|
||||
(unsigned char)forbids[j][depth] == c) {
|
||||
++j;
|
||||
}
|
||||
children.push_back({c, {i, j}});
|
||||
i = j;
|
||||
}
|
||||
std::vector<std::string> alts;
|
||||
if (!children.empty()) {
|
||||
std::string cls;
|
||||
for (auto &ch : children) cls += charclass_escape(ch.first);
|
||||
alts.push_back(std::string("[^") + cls + "]");
|
||||
}
|
||||
for (auto &ch : children) {
|
||||
std::string childExpr = self(self, forbids, ch.second.first, ch.second.second, depth+1);
|
||||
if (!childExpr.empty()) {
|
||||
std::string quoted_ch = "\"";
|
||||
if (ch.first == '\\') quoted_ch += "\\\\";
|
||||
else if (ch.first == '"') quoted_ch += "\\\"";
|
||||
else if (isprint(ch.first)) quoted_ch.push_back(ch.first);
|
||||
else {
|
||||
char buf[16];
|
||||
snprintf(buf, 15, "\\x%02X", ch.first);
|
||||
quoted_ch += buf;
|
||||
}
|
||||
quoted_ch += "\"";
|
||||
std::string branch = quoted_ch + std::string(" ") + childExpr;
|
||||
alts.push_back(branch);
|
||||
}
|
||||
}
|
||||
if (alts.empty()) return "";
|
||||
std::ostringstream oss;
|
||||
oss << "( ";
|
||||
for (size_t k = 0; k < alts.size(); ++k) {
|
||||
if (k) oss << " | ";
|
||||
oss << alts[k];
|
||||
}
|
||||
oss << " )";
|
||||
return oss.str();
|
||||
};
|
||||
if (forbids.empty()) return "( . )*";
|
||||
sort(forbids.begin(), forbids.end());
|
||||
std::string expr = build_expr(build_expr, forbids, 0, forbids.size(), 0);
|
||||
if (expr.empty()) {
|
||||
std::string cls;
|
||||
for (auto &s : forbids) if (!s.empty()) cls += charclass_escape((unsigned char)s[0]);
|
||||
expr = std::string("( [^") + cls + "] )";
|
||||
}
|
||||
if (forbids.size() == 1)
|
||||
return expr + "*";
|
||||
else
|
||||
return std::string("( ") + expr + " )*";
|
||||
}
|
||||
|
||||
/**
|
||||
* Build grammar for xml-style tool call
|
||||
* form.scope_start and form.scope_end can be empty.
|
||||
* Requires data.format for model-specific hacks.
|
||||
*/
|
||||
void build_grammar_xml_tool_call(common_chat_params & data, const json & tools, const struct xml_tool_call_format & form) {
|
||||
GGML_ASSERT(!form.tool_start.empty());
|
||||
GGML_ASSERT(!form.tool_sep.empty());
|
||||
GGML_ASSERT(!form.key_start.empty());
|
||||
GGML_ASSERT(!form.val_end.empty());
|
||||
GGML_ASSERT(!form.tool_end.empty());
|
||||
|
||||
std::string key_val_sep = form.key_val_sep;
|
||||
if (form.key_val_sep2) {
|
||||
key_val_sep += "\n";
|
||||
key_val_sep += *form.key_val_sep2;
|
||||
}
|
||||
GGML_ASSERT(!key_val_sep.empty());
|
||||
|
||||
if (tools.is_array() && !tools.empty()) {
|
||||
data.grammar = build_grammar([&](const common_grammar_builder &builder) {
|
||||
auto string_arg_val = form.last_val_end ?
|
||||
builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end, *form.last_val_end})) :
|
||||
builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end}));
|
||||
|
||||
std::vector<std::string> tool_rules;
|
||||
for (const auto & tool : tools) {
|
||||
if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
|
||||
LOG_WRN("Skipping tool without function: %s", tool.dump(2).c_str());
|
||||
continue;
|
||||
}
|
||||
const auto & function = tool.at("function");
|
||||
if (!function.contains("name") || !function.at("name").is_string()) {
|
||||
LOG_WRN("Skipping invalid function (invalid name): %s", function.dump(2).c_str());
|
||||
continue;
|
||||
}
|
||||
if (!function.contains("parameters") || !function.at("parameters").is_object()) {
|
||||
LOG_WRN("Skipping invalid function (invalid parameters): %s", function.dump(2).c_str());
|
||||
continue;
|
||||
}
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
|
||||
struct parameter_rule {
|
||||
std::string symbol_name;
|
||||
bool is_required;
|
||||
};
|
||||
std::vector<parameter_rule> arg_rules;
|
||||
if (!parameters.contains("properties") || !parameters.at("properties").is_object()) {
|
||||
LOG_WRN("Skipping invalid function (invalid properties): %s", function.dump(2).c_str());
|
||||
continue;
|
||||
} else {
|
||||
std::vector<std::string> requiredParameters;
|
||||
if (parameters.contains("required")) {
|
||||
try { parameters.at("required").get_to(requiredParameters); }
|
||||
catch (const std::runtime_error&) {
|
||||
LOG_WRN("Invalid function required parameters, ignoring: %s", function.at("required").dump(2).c_str());
|
||||
}
|
||||
}
|
||||
sort_uniq(requiredParameters);
|
||||
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||
std::string quoted_key = key;
|
||||
bool required = std::binary_search(requiredParameters.begin(), requiredParameters.end(), key);
|
||||
if (form.key_start.back() == '"' && key_val_sep[0] == '"') {
|
||||
quoted_key = gbnf_format_literal(key);
|
||||
quoted_key = quoted_key.substr(1, quoted_key.size() - 2);
|
||||
}
|
||||
arg_rules.push_back(parameter_rule {builder.add_rule("func-" + name + "-kv-" + key,
|
||||
gbnf_format_literal(form.key_start) + " " +
|
||||
gbnf_format_literal(quoted_key) + " " +
|
||||
gbnf_format_literal(key_val_sep) + " " +
|
||||
((value.contains("type") && value["type"].is_string() && value["type"] == "string" && (!form.raw_argval || *form.raw_argval)) ?
|
||||
(form.raw_argval ?
|
||||
string_arg_val :
|
||||
"( " + string_arg_val + " | " + builder.add_schema(name + "-arg-" + key, value) + " )"
|
||||
) :
|
||||
builder.add_schema(name + "-arg-" + key, value)
|
||||
)
|
||||
), required});
|
||||
}
|
||||
}
|
||||
|
||||
auto next_arg_with_sep = builder.add_rule(name + "-last-arg-end", form.last_val_end ? gbnf_format_literal(*form.last_val_end) : gbnf_format_literal(form.val_end));
|
||||
decltype(next_arg_with_sep) next_arg = "\"\"";
|
||||
for (auto i = arg_rules.size() - 1; /* i >= 0 && */ i < arg_rules.size(); --i) {
|
||||
std::string include_this_arg = arg_rules[i].symbol_name + " " + next_arg_with_sep;
|
||||
next_arg = builder.add_rule(name + "-arg-after-" + std::to_string(i), arg_rules[i].is_required ?
|
||||
include_this_arg : "( " + include_this_arg + " ) | " + next_arg
|
||||
);
|
||||
include_this_arg = gbnf_format_literal(form.val_end) + " " + include_this_arg;
|
||||
next_arg_with_sep = builder.add_rule(name + "-arg-after-" + std::to_string(i) + "-with-sep", arg_rules[i].is_required ?
|
||||
include_this_arg : "( " + include_this_arg + " ) | " + next_arg_with_sep
|
||||
);
|
||||
}
|
||||
|
||||
std::string quoted_name = name;
|
||||
if (form.tool_start.back() == '"' && form.tool_sep[0] == '"') {
|
||||
quoted_name = gbnf_format_literal(name);
|
||||
quoted_name = quoted_name.substr(1, quoted_name.size() - 2);
|
||||
}
|
||||
quoted_name = gbnf_format_literal(quoted_name);
|
||||
// Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
|
||||
if (data.format == COMMON_CHAT_FORMAT_KIMI_K2) {
|
||||
quoted_name = "\"functions.\" " + quoted_name + " \":\" [0-9]+";
|
||||
}
|
||||
tool_rules.push_back(builder.add_rule(name + "-call",
|
||||
gbnf_format_literal(form.tool_start) + " " +
|
||||
quoted_name + " " +
|
||||
gbnf_format_literal(form.tool_sep) + " " +
|
||||
next_arg
|
||||
));
|
||||
}
|
||||
|
||||
auto tool_call_once = builder.add_rule("root-tool-call-once", string_join(tool_rules, " | "));
|
||||
auto tool_call_more = builder.add_rule("root-tool-call-more", gbnf_format_literal(form.tool_end) + " " + tool_call_once);
|
||||
auto call_end = builder.add_rule("root-call-end", form.last_tool_end ? gbnf_format_literal(*form.last_tool_end) : gbnf_format_literal(form.tool_end));
|
||||
auto tool_call_multiple_with_end = builder.add_rule("root-tool-call-multiple-with-end", tool_call_once + " " + tool_call_more + "* " + call_end);
|
||||
builder.add_rule("root",
|
||||
(form.scope_start.empty() ? "" : gbnf_format_literal(form.scope_start) + " ") +
|
||||
tool_call_multiple_with_end + "?" +
|
||||
(form.scope_end.empty() ? "" : " " + gbnf_format_literal(form.scope_end))
|
||||
);
|
||||
});
|
||||
|
||||
// grammar trigger for tool call
|
||||
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, form.scope_start + form.tool_start });
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
||||
* Throws xml_toolcall_syntax_exception if there is invalid syntax and cannot recover the original status for common_chat_msg_parser.
|
||||
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||
*/
|
||||
inline bool parse_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form) {
|
||||
GGML_ASSERT(!form.tool_start.empty());
|
||||
GGML_ASSERT(!form.key_start.empty());
|
||||
GGML_ASSERT(!form.key_val_sep.empty());
|
||||
GGML_ASSERT(!form.val_end.empty());
|
||||
GGML_ASSERT(!form.tool_end.empty());
|
||||
|
||||
// Helper to choose return false or throw error
|
||||
constexpr auto return_error = [](common_chat_msg_parser & builder, auto &start_pos, const bool &recovery) {
|
||||
LOG_DBG("Failed to parse XML-Style tool call at position: %s\n", gbnf_format_literal(builder.consume_rest().substr(0, 20)).c_str());
|
||||
if (recovery) {
|
||||
builder.move_to(start_pos);
|
||||
return false;
|
||||
} else throw xml_toolcall_syntax_exception("Tool call parsing failed with unrecoverable errors. Try using a grammar to constrain the model’s output.");
|
||||
};
|
||||
// Drop substring from needle to end from a JSON
|
||||
constexpr auto partial_json = [](std::string &json_str, std::string_view needle = "XML_TOOL_CALL_PARTIAL_FLAG") {
|
||||
auto pos = json_str.rfind(needle);
|
||||
if (pos == std::string::npos) {
|
||||
return false;
|
||||
}
|
||||
for (auto i = pos + needle.size(); i < json_str.size(); ++i) {
|
||||
unsigned char ch = static_cast<unsigned char>(json_str[i]);
|
||||
if (ch != '\'' && ch != '"' && ch != '}' && ch != ':' && !std::isspace(ch)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (pos != 0 && json_str[pos - 1] == '"') {
|
||||
--pos;
|
||||
}
|
||||
json_str.resize(pos);
|
||||
return true;
|
||||
};
|
||||
// Helper to generate a partial argument JSON
|
||||
constexpr auto gen_partial_json = [partial_json](auto set_partial_arg, auto &arguments, auto &builder, auto &function_name) {
|
||||
auto rest = builder.consume_rest();
|
||||
utf8_truncate_safe_resize(rest);
|
||||
set_partial_arg(rest, "XML_TOOL_CALL_PARTIAL_FLAG");
|
||||
auto tool_str = arguments.dump();
|
||||
if (partial_json(tool_str)) {
|
||||
if (builder.add_tool_call(function_name, "", tool_str)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
LOG_DBG("Failed to parse partial XML-Style tool call, fallback to non-partial: %s\n", tool_str.c_str());
|
||||
};
|
||||
// Helper to find a close (because there may be form.last_val_end or form.last_tool_end)
|
||||
constexpr auto try_find_close = [](
|
||||
common_chat_msg_parser & builder,
|
||||
const std::string & end,
|
||||
const std::optional<std::string> & alt_end,
|
||||
const std::string & end_next,
|
||||
const std::optional<std::string> & alt_end_next
|
||||
) {
|
||||
auto saved_pos = builder.pos();
|
||||
auto tc = builder.try_find_literal(end);
|
||||
auto val_end_size = end.size();
|
||||
if (alt_end) {
|
||||
auto pos_1 = builder.pos();
|
||||
builder.move_to(saved_pos);
|
||||
auto tc2 = try_find_2_literal_splited_by_spaces(builder, *alt_end, end_next);
|
||||
if (alt_end_next) {
|
||||
builder.move_to(saved_pos);
|
||||
auto tc3 = try_find_2_literal_splited_by_spaces(builder, *alt_end, *alt_end_next);
|
||||
if (tc3 && (!tc2 || tc2->prelude.size() > tc3->prelude.size())) {
|
||||
tc2 = tc3;
|
||||
}
|
||||
}
|
||||
if (tc2 && (!tc || tc->prelude.size() > tc2->prelude.size())) {
|
||||
tc = tc2;
|
||||
tc->groups[0].end = std::min(builder.input().size(), tc->groups[0].begin + alt_end->size());
|
||||
builder.move_to(tc->groups[0].end);
|
||||
val_end_size = alt_end->size();
|
||||
} else {
|
||||
builder.move_to(pos_1);
|
||||
}
|
||||
}
|
||||
return std::make_pair(val_end_size, tc);
|
||||
};
|
||||
// Helper to find a val_end or last_val_end, returns matched pattern size
|
||||
const auto try_find_val_end = [try_find_close, &builder, &form]() {
|
||||
return try_find_close(builder, form.val_end, form.last_val_end, form.tool_end, form.last_tool_end);
|
||||
};
|
||||
// Helper to find a tool_end or last_tool_end, returns matched pattern size
|
||||
const auto try_find_tool_end = [try_find_close, &builder, &form]() {
|
||||
return try_find_close(builder, form.tool_end, form.last_tool_end, form.scope_end, std::nullopt);
|
||||
};
|
||||
|
||||
bool recovery = true;
|
||||
const auto start_pos = builder.pos();
|
||||
if (!all_space(form.scope_start)) {
|
||||
if (auto tc = builder.try_find_literal(form.scope_start)) {
|
||||
if (all_space(tc->prelude)) {
|
||||
if (form.scope_start.size() != tc->groups[0].end - tc->groups[0].begin)
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.scope_start));
|
||||
} else {
|
||||
builder.move_to(start_pos);
|
||||
return false;
|
||||
}
|
||||
} else return false;
|
||||
}
|
||||
while (auto tc = builder.try_find_literal(form.tool_start)) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
|
||||
gbnf_format_literal(form.tool_start).c_str(),
|
||||
gbnf_format_literal(tc->prelude).c_str()
|
||||
);
|
||||
builder.move_to(tc->groups[0].begin - tc->prelude.size());
|
||||
break;
|
||||
}
|
||||
|
||||
// Find tool name
|
||||
auto func_name = builder.try_find_literal(all_space(form.tool_sep) ? form.key_start : form.tool_sep);
|
||||
if (!func_name) {
|
||||
auto [sz, tc] = try_find_tool_end();
|
||||
func_name = tc;
|
||||
}
|
||||
if (!func_name) {
|
||||
// Partial tool name not supported
|
||||
throw common_chat_msg_partial_exception("incomplete tool_call");
|
||||
}
|
||||
// If the model generate multiple tool call and the first tool call has no argument
|
||||
if (func_name->prelude.find(form.tool_end) != std::string::npos || (form.last_tool_end ? func_name->prelude.find(*form.last_tool_end) != std::string::npos : false)) {
|
||||
builder.move_to(func_name->groups[0].begin - func_name->prelude.size());
|
||||
auto [sz, tc] = try_find_tool_end();
|
||||
func_name = tc;
|
||||
}
|
||||
|
||||
// Parse tool name
|
||||
builder.move_to(all_space(form.tool_sep) ? func_name->groups[0].begin : func_name->groups[0].end);
|
||||
std::string function_name = string_strip(func_name->prelude);
|
||||
// Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
|
||||
if (builder.syntax().format == COMMON_CHAT_FORMAT_KIMI_K2) {
|
||||
if (string_starts_with(function_name, "functions.")) {
|
||||
static const std::regex re(":\\d+$");
|
||||
if (std::regex_search(function_name, re)) {
|
||||
function_name = function_name.substr(10, function_name.rfind(":") - 10);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Argument JSON
|
||||
json arguments = json::object();
|
||||
|
||||
// Helper to generate a partial argument JSON
|
||||
const auto gen_partial_args = [&](auto set_partial_arg) {
|
||||
gen_partial_json(set_partial_arg, arguments, builder, function_name);
|
||||
};
|
||||
|
||||
// Parse all arg_key/arg_value pairs
|
||||
while (auto tc = builder.try_find_literal(form.key_start)) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
|
||||
gbnf_format_literal(form.key_start).c_str(),
|
||||
gbnf_format_literal(tc->prelude).c_str()
|
||||
);
|
||||
builder.move_to(tc->groups[0].begin - tc->prelude.size());
|
||||
break;
|
||||
}
|
||||
if (tc->groups[0].end - tc->groups[0].begin != form.key_start.size()) {
|
||||
auto tool_call_arg = arguments.dump();
|
||||
if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
|
||||
tool_call_arg.resize(tool_call_arg.size() - 1);
|
||||
}
|
||||
builder.add_tool_call(function_name, "", tool_call_arg);
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_start));
|
||||
}
|
||||
|
||||
// Parse arg_key
|
||||
auto key_res = builder.try_find_literal(form.key_val_sep);
|
||||
if (!key_res) {
|
||||
gen_partial_args([&](auto &rest, auto &needle) {arguments[rest + needle] = "";});
|
||||
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.key_val_sep) + " after " + gbnf_format_literal(form.key_start));
|
||||
}
|
||||
if (key_res->groups[0].end - key_res->groups[0].begin != form.key_val_sep.size()) {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key_res->prelude + needle] = "";});
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_val_sep));
|
||||
}
|
||||
auto &key = key_res->prelude;
|
||||
recovery = false;
|
||||
|
||||
// Parse arg_value
|
||||
if (form.key_val_sep2) {
|
||||
if (auto tc = builder.try_find_literal(*form.key_val_sep2)) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("Failed to parse XML-Style tool call: Unexcepted %s between %s and %s\n",
|
||||
gbnf_format_literal(tc->prelude).c_str(),
|
||||
gbnf_format_literal(form.key_val_sep).c_str(),
|
||||
gbnf_format_literal(*form.key_val_sep2).c_str()
|
||||
);
|
||||
return return_error(builder, start_pos, false);
|
||||
}
|
||||
if (tc->groups[0].end - tc->groups[0].begin != form.key_val_sep2->size()) {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(*form.key_val_sep2));
|
||||
}
|
||||
} else {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(*form.key_val_sep2) + " after " + gbnf_format_literal(form.key_val_sep));
|
||||
}
|
||||
}
|
||||
auto val_start = builder.pos();
|
||||
|
||||
// Test if arg_val is a partial JSON
|
||||
std::optional<common_json> value_json = std::nullopt;
|
||||
if (!form.raw_argval || !*form.raw_argval) {
|
||||
try { value_json = builder.try_consume_json(); }
|
||||
catch (const std::runtime_error&) { builder.move_to(val_start); }
|
||||
// TODO: Delete this when json_partial adds top-level support for null/true/false
|
||||
if (builder.pos() == val_start) {
|
||||
const static std::regex number_regex(R"([0-9-][0-9]*(\.\d*)?([eE][+-]?\d*)?)");
|
||||
builder.consume_spaces();
|
||||
std::string_view sv = utf8_truncate_safe_view(builder.input());
|
||||
sv.remove_prefix(builder.pos());
|
||||
std::string rest = "a";
|
||||
if (sv.size() < 6) rest = sv;
|
||||
if (string_starts_with("null", rest) || string_starts_with("true", rest) || string_starts_with("false", rest) || std::regex_match(sv.begin(), sv.end(), number_regex)) {
|
||||
value_json = {123, {"123", "123"}};
|
||||
builder.consume_rest();
|
||||
} else {
|
||||
builder.move_to(val_start);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If it is a JSON and followed by </arg_value>, parse as json
|
||||
// cannot support streaming because it may be a plain text starting with JSON
|
||||
if (value_json) {
|
||||
auto json_end = builder.pos();
|
||||
builder.consume_spaces();
|
||||
if (builder.pos() == builder.input().size()) {
|
||||
if (form.raw_argval && !*form.raw_argval && (value_json->json.is_string() || value_json->json.is_object() || value_json->json.is_array())) {
|
||||
arguments[key] = value_json->json;
|
||||
auto json_str = arguments.dump();
|
||||
if (!value_json->healing_marker.json_dump_marker.empty()) {
|
||||
GGML_ASSERT(std::string::npos != json_str.rfind(value_json->healing_marker.json_dump_marker));
|
||||
json_str.resize(json_str.rfind(value_json->healing_marker.json_dump_marker));
|
||||
} else {
|
||||
GGML_ASSERT(json_str.back() == '}');
|
||||
json_str.resize(json_str.size() - 1);
|
||||
}
|
||||
builder.add_tool_call(function_name, "", json_str);
|
||||
} else {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||
}
|
||||
LOG_DBG("Possible JSON arg_value: %s\n", value_json->json.dump().c_str());
|
||||
throw common_chat_msg_partial_exception("JSON arg_value detected. Waiting for more tokens for validations.");
|
||||
}
|
||||
builder.move_to(json_end);
|
||||
auto [val_end_size, tc] = try_find_val_end();
|
||||
if (tc && all_space(tc->prelude) && value_json->healing_marker.marker.empty()) {
|
||||
if (tc->groups[0].end - tc->groups[0].begin != val_end_size) {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||
LOG_DBG("Possible terminated JSON arg_value: %s\n", value_json->json.dump().c_str());
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.val_end) + (form.last_val_end ? gbnf_format_literal(*form.last_val_end) : ""));
|
||||
} else arguments[key] = value_json->json;
|
||||
} else builder.move_to(val_start);
|
||||
}
|
||||
|
||||
// If not, parse as plain text
|
||||
if (val_start == builder.pos()) {
|
||||
if (auto [val_end_size, value_plain] = try_find_val_end(); value_plain) {
|
||||
auto &value_str = value_plain->prelude;
|
||||
if (form.trim_raw_argval) value_str = string_strip(value_str);
|
||||
if (value_plain->groups[0].end - value_plain->groups[0].begin != val_end_size) {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = value_str + needle;});
|
||||
throw common_chat_msg_partial_exception(
|
||||
"Expected " + gbnf_format_literal(form.val_end) +
|
||||
" after " + gbnf_format_literal(form.key_val_sep) +
|
||||
(form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
|
||||
);
|
||||
}
|
||||
arguments[key] = value_str;
|
||||
} else {
|
||||
if (form.trim_raw_argval) {
|
||||
gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = string_strip(rest) + needle;});
|
||||
} else {
|
||||
gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = rest + needle;});
|
||||
}
|
||||
throw common_chat_msg_partial_exception(
|
||||
"Expected " + gbnf_format_literal(form.val_end) +
|
||||
" after " + gbnf_format_literal(form.key_val_sep) +
|
||||
(form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Consume closing tag
|
||||
if (auto [tool_end_size, tc] = try_find_tool_end(); tc) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||
gbnf_format_literal(form.tool_end).c_str(),
|
||||
gbnf_format_literal(tc->prelude).c_str()
|
||||
);
|
||||
return return_error(builder, start_pos, recovery);
|
||||
}
|
||||
if (tc->groups[0].end - tc->groups[0].begin == tool_end_size) {
|
||||
// Add the parsed tool call
|
||||
if (!builder.add_tool_call(function_name, "", arguments.dump())) {
|
||||
throw common_chat_msg_partial_exception("Failed to add XML-Style tool call");
|
||||
}
|
||||
recovery = false;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
auto tool_call_arg = arguments.dump();
|
||||
if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
|
||||
tool_call_arg.resize(tool_call_arg.size() - 1);
|
||||
}
|
||||
builder.add_tool_call(function_name, "", tool_call_arg);
|
||||
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.tool_end) + " after " + gbnf_format_literal(form.val_end));
|
||||
}
|
||||
if (auto tc = builder.try_find_literal(form.scope_end)) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||
gbnf_format_literal(form.scope_end).c_str(),
|
||||
gbnf_format_literal(tc->prelude).c_str()
|
||||
);
|
||||
return return_error(builder, start_pos, recovery);
|
||||
}
|
||||
} else {
|
||||
if (all_space(form.scope_end)) return true;
|
||||
builder.consume_spaces();
|
||||
if (builder.pos() == builder.input().size())
|
||||
throw common_chat_msg_partial_exception("incomplete tool calls");
|
||||
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||
gbnf_format_literal(form.scope_end).c_str(),
|
||||
gbnf_format_literal(builder.consume_rest()).c_str()
|
||||
);
|
||||
return return_error(builder, start_pos, recovery);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
||||
* May cause std::runtime_error if there is invalid syntax because partial valid tool call is already sent out to client.
|
||||
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||
*/
|
||||
bool common_chat_msg_parser::try_consume_xml_tool_calls(const struct xml_tool_call_format & form) {
|
||||
auto pos = pos_;
|
||||
auto tsize = result_.tool_calls.size();
|
||||
try { return parse_xml_tool_calls(*this, form); }
|
||||
catch (const xml_toolcall_syntax_exception&) {}
|
||||
move_to(pos);
|
||||
result_.tool_calls.resize(tsize);
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse content uses reasoning and XML-Style tool call
|
||||
* TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
|
||||
*/
|
||||
inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>") {
|
||||
constexpr auto rstrip = [](std::string &s) {
|
||||
s.resize(std::distance(s.begin(), std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base()));
|
||||
};
|
||||
// Erase substring from l to r, along with additional spaces nearby
|
||||
constexpr auto erase_spaces = [](auto &str, size_t l, size_t r) {
|
||||
while (/* l > -1 && */ --l < str.size() && std::isspace(static_cast<unsigned char>(str[l])));
|
||||
++l;
|
||||
while (++r < str.size() && std::isspace(static_cast<unsigned char>(str[r])));
|
||||
if (l < r) str[l] = '\n';
|
||||
if (l + 1 < r) str[l + 1] = '\n';
|
||||
if (l != 0) l += 2;
|
||||
str.erase(l, r - l);
|
||||
return l;
|
||||
};
|
||||
constexpr auto trim_suffix = [](std::string &content, std::initializer_list<std::string_view> list) {
|
||||
auto best_match = content.size();
|
||||
for (auto pattern: list) {
|
||||
if (pattern.size() == 0) continue;
|
||||
for (auto match_idx = content.size() - std::min(pattern.size(), content.size()); content.size() > match_idx; match_idx++) {
|
||||
auto match_len = content.size() - match_idx;
|
||||
if (content.compare(match_idx, match_len, pattern.data(), match_len) == 0 && best_match > match_idx) {
|
||||
best_match = match_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (content.size() > best_match) {
|
||||
content.erase(best_match);
|
||||
}
|
||||
};
|
||||
const auto trim_potential_partial_word = [&start_think, &end_think, &form, trim_suffix](std::string &content) {
|
||||
return trim_suffix(content, {
|
||||
start_think, end_think, form.scope_start, form.tool_start, form.tool_sep, form.key_start,
|
||||
form.key_val_sep, form.key_val_sep2 ? form.key_val_sep2->c_str() : "",
|
||||
form.val_end, form.last_val_end ? form.last_val_end->c_str() : "",
|
||||
form.tool_end, form.last_tool_end ? form.last_tool_end->c_str() : "",
|
||||
form.scope_end
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
// Trim leading spaces without affecting keyword matching
|
||||
static const common_regex spaces_regex("\\s*");
|
||||
{
|
||||
auto tc = builder.consume_regex(spaces_regex);
|
||||
auto spaces = builder.str(tc.groups[0]);
|
||||
auto s1 = spaces.size();
|
||||
trim_potential_partial_word(spaces);
|
||||
auto s2 = spaces.size();
|
||||
builder.move_to(builder.pos() - (s1 - s2));
|
||||
}
|
||||
|
||||
// Parse content
|
||||
bool reasoning_unclosed = builder.syntax().thinking_forced_open;
|
||||
std::string unclosed_reasoning_content("");
|
||||
for (;;) {
|
||||
auto tc = try_find_2_literal_splited_by_spaces(builder, form.scope_start, form.tool_start);
|
||||
std::string content;
|
||||
std::string tool_call_start;
|
||||
|
||||
if (tc) {
|
||||
content = std::move(tc->prelude);
|
||||
tool_call_start = builder.str(tc->groups[0]);
|
||||
LOG_DBG("Matched tool start: %s\n", gbnf_format_literal(tool_call_start).c_str());
|
||||
} else {
|
||||
content = builder.consume_rest();
|
||||
utf8_truncate_safe_resize(content);
|
||||
}
|
||||
|
||||
// Handle unclosed think block
|
||||
if (reasoning_unclosed) {
|
||||
if (auto pos = content.find(end_think); pos == std::string::npos && builder.pos() != builder.input().size()) {
|
||||
unclosed_reasoning_content += content;
|
||||
if (form.allow_toolcall_in_think) {
|
||||
builder.move_to(tc->groups[0].begin);
|
||||
if (!builder.try_consume_xml_tool_calls(form)) {
|
||||
unclosed_reasoning_content += tool_call_start;
|
||||
builder.move_to(tc->groups[0].end);
|
||||
}
|
||||
} else {
|
||||
unclosed_reasoning_content += tool_call_start;
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
reasoning_unclosed = false;
|
||||
std::string reasoning_content;
|
||||
if (pos == std::string::npos) {
|
||||
reasoning_content = std::move(content);
|
||||
} else {
|
||||
reasoning_content = content.substr(0, pos);
|
||||
content.erase(0, pos + end_think.size());
|
||||
}
|
||||
if (builder.pos() == builder.input().size() && all_space(content)) {
|
||||
rstrip(reasoning_content);
|
||||
trim_potential_partial_word(reasoning_content);
|
||||
rstrip(reasoning_content);
|
||||
if (reasoning_content.empty()) {
|
||||
rstrip(unclosed_reasoning_content);
|
||||
trim_potential_partial_word(unclosed_reasoning_content);
|
||||
rstrip(unclosed_reasoning_content);
|
||||
if (unclosed_reasoning_content.empty()) continue;
|
||||
}
|
||||
}
|
||||
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
|
||||
builder.add_content(start_think);
|
||||
builder.add_content(unclosed_reasoning_content);
|
||||
builder.add_content(reasoning_content);
|
||||
if (builder.pos() != builder.input().size() || !all_space(content))
|
||||
builder.add_content(end_think);
|
||||
} else {
|
||||
builder.add_reasoning_content(unclosed_reasoning_content);
|
||||
builder.add_reasoning_content(reasoning_content);
|
||||
}
|
||||
unclosed_reasoning_content.clear();
|
||||
}
|
||||
}
|
||||
|
||||
// Handle multiple think block
|
||||
bool toolcall_in_think = false;
|
||||
for (auto think_start = content.find(start_think); think_start != std::string::npos; think_start = content.find(start_think, think_start)) {
|
||||
if (auto think_end = content.find(end_think, think_start + start_think.size()); think_end != std::string::npos) {
|
||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
||||
auto reasoning_content = content.substr(think_start + start_think.size(), think_end - think_start - start_think.size());
|
||||
builder.add_reasoning_content(reasoning_content);
|
||||
think_start = erase_spaces(content, think_start, think_end + end_think.size() - 1);
|
||||
} else {
|
||||
think_start = think_end + end_think.size() - 1;
|
||||
}
|
||||
} else {
|
||||
// This <tool_call> start is in thinking block, skip this tool call
|
||||
auto pos = think_start + start_think.size();
|
||||
unclosed_reasoning_content = content.substr(pos) + tool_call_start;
|
||||
reasoning_unclosed = true;
|
||||
content.resize(think_start);
|
||||
toolcall_in_think = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
||||
rstrip(content);
|
||||
// Handle unclosed </think> token from content: delete all </think> token
|
||||
if (auto pos = content.rfind(end_think); pos != std::string::npos) {
|
||||
while (pos != std::string::npos) {
|
||||
pos = erase_spaces(content, pos, pos + end_think.size() - 1);
|
||||
pos = content.rfind(end_think, pos);
|
||||
}
|
||||
}
|
||||
// Strip if needed
|
||||
if (content.size() > 0 && std::isspace(static_cast<unsigned char>(content[0]))) {
|
||||
content = string_strip(content);
|
||||
}
|
||||
}
|
||||
|
||||
// remove potential partial suffix
|
||||
if (content.size() > 0 && builder.pos() == builder.input().size() && unclosed_reasoning_content.empty()) {
|
||||
rstrip(content);
|
||||
trim_potential_partial_word(content);
|
||||
rstrip(content);
|
||||
}
|
||||
|
||||
// Add content
|
||||
if (content.size() != 0) {
|
||||
// If there are multiple content blocks
|
||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content && builder.result().content.size() != 0) {
|
||||
builder.add_content("\n\n");
|
||||
}
|
||||
builder.add_content(content);
|
||||
}
|
||||
|
||||
// This <tool_call> start is in thinking block, skip this tool call
|
||||
if (toolcall_in_think && !form.allow_toolcall_in_think) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// There is no tool call and all content is parsed
|
||||
if (!tc) {
|
||||
GGML_ASSERT(builder.pos() == builder.input().size());
|
||||
GGML_ASSERT(unclosed_reasoning_content.empty());
|
||||
GGML_ASSERT(!reasoning_unclosed);
|
||||
break;
|
||||
}
|
||||
|
||||
builder.move_to(tc->groups[0].begin);
|
||||
if (builder.try_consume_xml_tool_calls(form)) {
|
||||
auto end_of_tool = builder.pos();
|
||||
builder.consume_spaces();
|
||||
if (builder.pos() != builder.input().size()) {
|
||||
builder.move_to(end_of_tool);
|
||||
if (!builder.result().content.empty()) {
|
||||
builder.add_content("\n\n");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static const common_regex next_char_regex(".");
|
||||
auto c = builder.str(builder.consume_regex(next_char_regex).groups[0]);
|
||||
rstrip(c);
|
||||
builder.add_content(c);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse content uses reasoning and XML-Style tool call
|
||||
* TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
|
||||
*/
|
||||
void common_chat_msg_parser::consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think, const std::string & end_think) {
|
||||
parse_msg_with_xml_tool_calls(*this, form, start_think, end_think);
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "chat.h"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
||||
// Sample config:
|
||||
// MiniMax-M2 (left): <minimax:tool_call>\n<invoke name="tool-name">\n<parameter name="key">value</parameter>\n...</invoke>\n...</minimax:tool_call>
|
||||
// GLM 4.5 (right): <tool_call>function_name\n<arg_key>key</arg_key>\n<arg_value>value</arg_value>\n</tool_call>
|
||||
struct xml_tool_call_format {
|
||||
std::string scope_start; // <minimax:tool_call>\n // \n // can be empty
|
||||
std::string tool_start; // <invoke name=\" // <tool_call>
|
||||
std::string tool_sep; // \">\n // \n // can be empty only for parse_xml_tool_calls
|
||||
std::string key_start; // <parameter name=\" // <arg_key>
|
||||
std::string key_val_sep; // \"> // </arg_key>\n<arg_value>
|
||||
std::string val_end; // </parameter>\n // </arg_value>\n
|
||||
std::string tool_end; // </invoke>\n // </tool_call>\n
|
||||
std::string scope_end; // </minimax:tool_call> // // can be empty
|
||||
// Set this if there can be dynamic spaces inside key_val_sep.
|
||||
// e.g. key_val_sep=</arg_key> key_val_sep2=<arg_value> for GLM4.5
|
||||
std::optional<std::string> key_val_sep2 = std::nullopt;
|
||||
// Set true if argval should only be raw string. e.g. Hello "world" hi
|
||||
// Set false if argval should only be json string. e.g. "Hello \"world\" hi"
|
||||
// Defaults to std::nullopt, both will be allowed.
|
||||
std::optional<bool> raw_argval = std::nullopt;
|
||||
std::optional<std::string> last_val_end = std::nullopt;
|
||||
std::optional<std::string> last_tool_end = std::nullopt;
|
||||
bool trim_raw_argval = false;
|
||||
bool allow_toolcall_in_think = false; // TODO: UNTESTED!!!
|
||||
};
|
||||
|
||||
// make a GBNF that accept any strings except those containing any of the forbidden strings.
|
||||
std::string make_gbnf_excluding(std::vector<std::string> forbids);
|
||||
|
||||
/**
|
||||
* Build grammar for xml-style tool call
|
||||
* form.scope_start and form.scope_end can be empty.
|
||||
* Requires data.format for model-specific hacks.
|
||||
*/
|
||||
void build_grammar_xml_tool_call(common_chat_params & data, const nlohmann::ordered_json & tools, const struct xml_tool_call_format & form);
|
||||
@@ -1,7 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "chat.h"
|
||||
#include "chat-parser-xml-toolcall.h"
|
||||
#include "json-partial.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
@@ -120,14 +119,5 @@ class common_chat_msg_parser {
|
||||
const std::vector<std::vector<std::string>> & content_paths = {}
|
||||
);
|
||||
|
||||
/**
|
||||
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
||||
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||
*/
|
||||
bool try_consume_xml_tool_calls(const struct xml_tool_call_format & form);
|
||||
|
||||
// Parse content uses reasoning and XML-Style tool call
|
||||
void consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>");
|
||||
|
||||
void clear_tools();
|
||||
};
|
||||
|
||||
548
common/chat.cpp
548
common/chat.cpp
@@ -643,12 +643,6 @@ const char * common_chat_format_name(common_chat_format format) {
|
||||
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
|
||||
case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
|
||||
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools";
|
||||
case COMMON_CHAT_FORMAT_MINIMAX_M2: return "MiniMax-M2";
|
||||
case COMMON_CHAT_FORMAT_GLM_4_5: return "GLM 4.5";
|
||||
case COMMON_CHAT_FORMAT_KIMI_K2: return "Kimi K2";
|
||||
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder";
|
||||
case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
|
||||
case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
|
||||
default:
|
||||
throw std::runtime_error("Unknown chat format");
|
||||
}
|
||||
@@ -1813,278 +1807,6 @@ static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static common_chat_params common_chat_params_init_minimax_m2(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
data.prompt = apply(tmpl, params);
|
||||
data.format = COMMON_CHAT_FORMAT_MINIMAX_M2;
|
||||
|
||||
// Handle thinking tags based on prompt ending
|
||||
if (string_ends_with(data.prompt, "<think>\n")) {
|
||||
if (!params.enable_thinking) {
|
||||
// Close the thinking tag immediately if thinking is disabled
|
||||
data.prompt += "</think>\n\n";
|
||||
} else {
|
||||
// Mark thinking as forced open (template started with <think>)
|
||||
data.thinking_forced_open = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Preserve MiniMax-M2 special tokens
|
||||
data.preserved_tokens = {
|
||||
"<think>",
|
||||
"</think>",
|
||||
"<minimax:tool_call>",
|
||||
"</minimax:tool_call>",
|
||||
};
|
||||
|
||||
// build grammar for tool call
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "<minimax:tool_call>\n",
|
||||
/* form.tool_start = */ "<invoke name=\"",
|
||||
/* form.tool_sep = */ "\">\n",
|
||||
/* form.key_start = */ "<parameter name=\"",
|
||||
/* form.key_val_sep = */ "\">",
|
||||
/* form.val_end = */ "</parameter>\n",
|
||||
/* form.tool_end = */ "</invoke>\n",
|
||||
/* form.scope_end = */ "</minimax:tool_call>",
|
||||
};
|
||||
build_grammar_xml_tool_call(data, params.tools, form);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "<minimax:tool_call>",
|
||||
/* form.tool_start = */ "<invoke name=\"",
|
||||
/* form.tool_sep = */ "\">",
|
||||
/* form.key_start = */ "<parameter name=\"",
|
||||
/* form.key_val_sep = */ "\">",
|
||||
/* form.val_end = */ "</parameter>",
|
||||
/* form.tool_end = */ "</invoke>",
|
||||
/* form.scope_end = */ "</minimax:tool_call>",
|
||||
};
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
data.prompt = apply(tmpl, params);
|
||||
data.format = COMMON_CHAT_FORMAT_QWEN3_CODER_XML;
|
||||
|
||||
data.preserved_tokens = {
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
"<function=",
|
||||
"</function>",
|
||||
"<parameter=",
|
||||
"</parameter>",
|
||||
};
|
||||
|
||||
// build grammar for tool call
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "<tool_call>\n",
|
||||
/* form.tool_start = */ "<function=",
|
||||
/* form.tool_sep = */ ">\n",
|
||||
/* form.key_start = */ "<parameter=",
|
||||
/* form.key_val_sep = */ ">\n",
|
||||
/* form.val_end = */ "\n</parameter>\n",
|
||||
/* form.tool_end = */ "</function>\n",
|
||||
/* form.scope_end = */ "</tool_call>",
|
||||
};
|
||||
build_grammar_xml_tool_call(data, params.tools, form);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<tool_call>";
|
||||
form.tool_start = "<function=";
|
||||
form.tool_sep = ">";
|
||||
form.key_start = "<parameter=";
|
||||
form.key_val_sep = ">";
|
||||
form.val_end = "</parameter>";
|
||||
form.tool_end = "</function>";
|
||||
form.scope_end = "</tool_call>";
|
||||
form.trim_raw_argval = true;
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
data.prompt = apply(tmpl, params);
|
||||
data.format = COMMON_CHAT_FORMAT_KIMI_K2;
|
||||
|
||||
data.preserved_tokens = {
|
||||
"<think>",
|
||||
"</think>",
|
||||
"<|tool_calls_section_begin|>",
|
||||
"<|tool_call_begin|>",
|
||||
"<|tool_call_argument_begin|>",
|
||||
"<|tool_call_end|>",
|
||||
"<|tool_calls_section_end|>",
|
||||
"<|im_end|>",
|
||||
"<|im_system|>",
|
||||
"<|im_middle|>",
|
||||
};
|
||||
|
||||
data.additional_stops.insert(data.additional_stops.end(), {
|
||||
"<|im_end|>",
|
||||
"<|im_middle|>"
|
||||
});
|
||||
// build grammar for tool call
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<|tool_calls_section_begin|>";
|
||||
form.tool_start = "<|tool_call_begin|>";
|
||||
form.tool_sep = "<|tool_call_argument_begin|>{";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}<|tool_call_end|>";
|
||||
form.scope_end = "<|tool_calls_section_end|>";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
return form;
|
||||
})();
|
||||
build_grammar_xml_tool_call(data, params.tools, form);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<|tool_calls_section_begin|>";
|
||||
form.tool_start = "<|tool_call_begin|>";
|
||||
form.tool_sep = "<|tool_call_argument_begin|>{";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}<|tool_call_end|>";
|
||||
form.scope_end = "<|tool_calls_section_end|>";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
data.prompt = apply(tmpl, params);
|
||||
data.format = COMMON_CHAT_FORMAT_APRIEL_1_5;
|
||||
|
||||
data.preserved_tokens = {
|
||||
"<thinking>",
|
||||
"</thinking>",
|
||||
"<tool_calls>",
|
||||
"</tool_calls>",
|
||||
};
|
||||
|
||||
// build grammar for tool call
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<tool_calls>[";
|
||||
form.tool_start = "{\"name\": \"";
|
||||
form.tool_sep = "\", \"arguments\": {";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}, ";
|
||||
form.scope_end = "]</tool_calls>";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
form.last_tool_end = "}";
|
||||
return form;
|
||||
})();
|
||||
build_grammar_xml_tool_call(data, params.tools, form);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<tool_calls>[";
|
||||
form.tool_start = "{\"name\": \"";
|
||||
form.tool_sep = "\", \"arguments\": {";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}, ";
|
||||
form.scope_end = "]</tool_calls>";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
form.last_tool_end = "}";
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<thinking>", "</thinking>");
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
data.prompt = apply(tmpl, params);
|
||||
data.format = COMMON_CHAT_FORMAT_XIAOMI_MIMO;
|
||||
|
||||
data.preserved_tokens = {
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
};
|
||||
|
||||
// build grammar for tool call
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "\n";
|
||||
form.tool_start = "<tool_call>\n{\"name\": \"";
|
||||
form.tool_sep = "\", \"arguments\": {";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}\n</tool_call>";
|
||||
form.scope_end = "";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
return form;
|
||||
})();
|
||||
build_grammar_xml_tool_call(data, params.tools, form);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "";
|
||||
form.tool_start = "<tool_call>\n{\"name\": \"";
|
||||
form.tool_sep = "\", \"arguments\": {";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}\n</tool_call>";
|
||||
form.scope_end = "";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
@@ -2319,100 +2041,6 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_glm_4_5(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = inputs.tools.is_array() && !inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
std::string prompt = apply(tmpl, inputs);
|
||||
|
||||
// match the existing trimming behavior
|
||||
if (inputs.add_bos && string_starts_with(prompt, tmpl.bos_token())) {
|
||||
prompt.erase(0, tmpl.bos_token().size());
|
||||
}
|
||||
if (inputs.add_eos && string_ends_with(prompt, tmpl.eos_token())) {
|
||||
prompt.erase(prompt.size() - tmpl.eos_token().size());
|
||||
}
|
||||
if (string_ends_with(prompt, "<think>")) {
|
||||
if (!inputs.enable_thinking) {
|
||||
prompt += "</think>";
|
||||
} else {
|
||||
data.thinking_forced_open = true;
|
||||
}
|
||||
}
|
||||
|
||||
// add GLM preserved tokens
|
||||
data.preserved_tokens = {
|
||||
"<|endoftext|>",
|
||||
"[MASK]",
|
||||
"[gMASK]",
|
||||
"[sMASK]",
|
||||
"<sop>",
|
||||
"<eop>",
|
||||
"<|system|>",
|
||||
"<|user|>",
|
||||
"<|assistant|>",
|
||||
"<|observation|>",
|
||||
"<|begin_of_image|>",
|
||||
"<|end_of_image|>",
|
||||
"<|begin_of_video|>",
|
||||
"<|end_of_video|>",
|
||||
"<|begin_of_audio|>",
|
||||
"<|end_of_audio|>",
|
||||
"<|begin_of_transcription|>",
|
||||
"<|end_of_transcription|>",
|
||||
"<|code_prefix|>",
|
||||
"<|code_middle|>",
|
||||
"<|code_suffix|>",
|
||||
"/nothink",
|
||||
"<think>",
|
||||
"</think>",
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
"<arg_key>",
|
||||
"</arg_key>",
|
||||
"<arg_value>",
|
||||
"</arg_value>"
|
||||
};
|
||||
|
||||
// extra GLM 4.5 stop word
|
||||
data.additional_stops.insert(data.additional_stops.end(), {
|
||||
"<|user|>",
|
||||
"<|observation|>"
|
||||
});
|
||||
|
||||
// build grammar for tool call
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "",
|
||||
/* form.tool_start = */ "\n<tool_call>",
|
||||
/* form.tool_sep = */ "\n",
|
||||
/* form.key_start = */ "<arg_key>",
|
||||
/* form.key_val_sep = */ "</arg_key>\n<arg_value>",
|
||||
/* form.val_end = */ "</arg_value>\n",
|
||||
/* form.tool_end = */ "</tool_call>\n",
|
||||
/* form.scope_end = */ "",
|
||||
};
|
||||
build_grammar_xml_tool_call(data, inputs.tools, form);
|
||||
|
||||
data.prompt = prompt;
|
||||
data.format = COMMON_CHAT_FORMAT_GLM_4_5;
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "",
|
||||
/* form.tool_start = */ "<tool_call>",
|
||||
/* form.tool_sep = */ "",
|
||||
/* form.key_start = */ "<arg_key>",
|
||||
/* form.key_val_sep = */ "</arg_key>",
|
||||
/* form.val_end = */ "</arg_value>",
|
||||
/* form.tool_end = */ "</tool_call>",
|
||||
/* form.scope_end = */ "",
|
||||
/* form.key_val_sep2 = */ "<arg_value>",
|
||||
};
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
common_chat_params data;
|
||||
@@ -3076,17 +2704,91 @@ static void common_chat_parse_lfm2(common_chat_msg_parser & builder) {
|
||||
}
|
||||
|
||||
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "<seed:tool_call>",
|
||||
/* form.tool_start = */ "<function=",
|
||||
/* form.tool_sep = */ ">",
|
||||
/* form.key_start = */ "<parameter=",
|
||||
/* form.key_val_sep = */ ">",
|
||||
/* form.val_end = */ "</parameter>",
|
||||
/* form.tool_end = */ "</function>",
|
||||
/* form.scope_end = */ "</seed:tool_call>",
|
||||
};
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<seed:think>", "</seed:think>");
|
||||
// Parse thinking tags first - this handles the main reasoning content
|
||||
builder.try_parse_reasoning("<seed:think>", "</seed:think>");
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse tool calls - Seed-OSS uses <seed:tool_call> format
|
||||
static const common_regex tool_call_begin_regex("<seed:tool_call>");
|
||||
static const common_regex tool_call_end_regex("</seed:tool_call>");
|
||||
static const common_regex function_regex("<function=([^>]+)>");
|
||||
static const common_regex param_regex("<parameter=([^>]+)>");
|
||||
|
||||
while (auto tool_res = builder.try_find_regex(tool_call_begin_regex)) {
|
||||
builder.consume_spaces(); // Consume whitespace after <seed:tool_call>
|
||||
|
||||
// Look for function call inside tool call, ignore any content before it
|
||||
if (auto func_res = builder.try_find_regex(function_regex, std::string::npos, false)) {
|
||||
auto function_name = builder.str(func_res->groups[1]);
|
||||
|
||||
// Parse Seed-OSS parameters <parameter=name>value</parameter>
|
||||
json args = json::object();
|
||||
// Parse all parameters
|
||||
while (auto param_res = builder.try_find_regex(param_regex, std::string::npos, false)) {
|
||||
// again, ignore noise around parameters
|
||||
auto param_name = builder.str(param_res->groups[1]);
|
||||
builder.move_to(param_res->groups[0].end);
|
||||
builder.consume_spaces(); // Consume whitespace after parameter
|
||||
auto savedPos = builder.pos();
|
||||
if (auto param_parse = builder.try_find_literal("</parameter>")) {
|
||||
auto param = param_parse->prelude;
|
||||
builder.move_to(savedPos);
|
||||
try {
|
||||
if (auto param_res = builder.try_consume_json()) {
|
||||
args[param_name] = param_res->json;
|
||||
} else {
|
||||
args[param_name] = param;
|
||||
}
|
||||
} catch (json::exception &) {
|
||||
args[param_name] = param;
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool parameter");
|
||||
}
|
||||
}
|
||||
// Look for closing function tag
|
||||
auto end_func = builder.try_find_literal("</function>");
|
||||
if (end_func) {
|
||||
builder.move_to(end_func->groups[0].end);
|
||||
builder.consume_spaces(); // Consume whitespace after </function>
|
||||
|
||||
// Add the tool call with parsed arguments, but only if we REALLY got the literal
|
||||
auto eaten_fragment = builder.input().substr(end_func->groups[0].begin, end_func->groups[0].end);
|
||||
auto funlen = std::string("</function>").length();
|
||||
if (eaten_fragment.length() >= funlen && eaten_fragment.substr(0, funlen) == std::string("</function>")) {
|
||||
if (!builder.add_tool_call(function_name, "", args.dump())) {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
// Look for closing tool call tag
|
||||
if (auto end_tool = builder.try_find_regex(tool_call_end_regex, std::string::npos, false)) {
|
||||
builder.move_to(end_tool->groups[0].end);
|
||||
builder.consume_spaces(); // Consume trailing whitespace after tool call
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
} else {
|
||||
// No function found - don't consume content here, let it be handled at the end
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Consume any remaining whitespace after all tool call processing
|
||||
builder.consume_spaces();
|
||||
auto remaining = builder.consume_rest();
|
||||
// If there's any non-whitespace content remaining, add it as content
|
||||
if (!string_strip(remaining).empty()) {
|
||||
builder.add_content(remaining);
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
@@ -3225,35 +2927,6 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
return common_chat_params_init_granite(tmpl, params);
|
||||
}
|
||||
|
||||
// GLM 4.5: detect by <arg_key> and <arg_value> tags (check before Hermes since both use <tool_call>)
|
||||
if (src.find("[gMASK]<sop>") != std::string::npos &&
|
||||
src.find("<arg_key>") != std::string::npos &&
|
||||
src.find("<arg_value>") != std::string::npos &&
|
||||
params.json_schema.is_null()) {
|
||||
return common_chat_params_init_glm_4_5(tmpl, params);
|
||||
}
|
||||
|
||||
// Qwen3-Coder XML format detection (must come before Hermes 2 Pro)
|
||||
// Detect via explicit XML markers unique to Qwen3-Coder to avoid false positives in other templates.
|
||||
// Require presence of <tool_call>, <function=...>, and <parameter=...> blocks.
|
||||
if (src.find("<tool_call>") != std::string::npos &&
|
||||
src.find("<function>") != std::string::npos &&
|
||||
src.find("<function=") != std::string::npos &&
|
||||
src.find("<parameters>") != std::string::npos &&
|
||||
src.find("<parameter=") != std::string::npos) {
|
||||
return common_chat_params_init_qwen3_coder_xml(tmpl, params);
|
||||
}
|
||||
|
||||
// Xiaomi MiMo format detection (must come before Hermes 2 Pro)
|
||||
if (src.find("<tools>") != std::string::npos &&
|
||||
src.find("# Tools") != std::string::npos &&
|
||||
src.find("</tools>") != std::string::npos &&
|
||||
src.find("<tool_calls>") != std::string::npos &&
|
||||
src.find("</tool_calls>") != std::string::npos &&
|
||||
src.find("<tool_response>") != std::string::npos) {
|
||||
return common_chat_params_init_xiaomi_mimo(tmpl, params);
|
||||
}
|
||||
|
||||
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
||||
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
|
||||
return common_chat_params_init_hermes_2_pro(tmpl, params);
|
||||
@@ -3285,29 +2958,6 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
return common_chat_params_init_lfm2(tmpl, params);
|
||||
}
|
||||
|
||||
// MiniMax-M2 format detection
|
||||
if (src.find("]~!b[") != std::string::npos && src.find("]~b]") != std::string::npos) {
|
||||
return common_chat_params_init_minimax_m2(tmpl, params);
|
||||
}
|
||||
|
||||
// Kimi K2 format detection
|
||||
if (src.find("<|im_system|>tool_declare<|im_middle|>") != std::string::npos &&
|
||||
src.find("<|tool_calls_section_begin|>") != std::string::npos &&
|
||||
src.find("## Return of") != std::string::npos) {
|
||||
return common_chat_params_init_kimi_k2(tmpl, params);
|
||||
}
|
||||
|
||||
// Apriel 1.5 format detection
|
||||
if (src.find("<thinking>") != std::string::npos &&
|
||||
src.find("</thinking>") != std::string::npos &&
|
||||
src.find("<available_tools>") != std::string::npos &&
|
||||
src.find("<|assistant|>") != std::string::npos &&
|
||||
src.find("<|tool_result|>") != std::string::npos &&
|
||||
src.find("<tool_calls>[") != std::string::npos &&
|
||||
src.find("]</tool_calls>") != std::string::npos) {
|
||||
return common_chat_params_init_apriel_1_5(tmpl, params);
|
||||
}
|
||||
|
||||
// Use generic handler when mixing tools + JSON schema.
|
||||
// TODO: support that mix in handlers below.
|
||||
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
||||
@@ -3359,7 +3009,7 @@ static common_chat_params common_chat_templates_apply_legacy(
|
||||
const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs)
|
||||
{
|
||||
size_t alloc_size = 0;
|
||||
int alloc_size = 0;
|
||||
std::vector<llama_chat_message> chat;
|
||||
std::vector<std::string> contents;
|
||||
|
||||
@@ -3381,8 +3031,7 @@ static common_chat_params common_chat_templates_apply_legacy(
|
||||
const auto & msg = inputs.messages[i];
|
||||
const auto & content = contents[i];
|
||||
chat.push_back({msg.role.c_str(), content.c_str()});
|
||||
size_t msg_size = msg.role.size() + content.size();
|
||||
alloc_size += msg_size + (msg_size / 4); // == msg_size * 1.25 but avoiding float ops
|
||||
alloc_size += (msg.role.size() + content.size()) * 1.25;
|
||||
}
|
||||
|
||||
std::vector<char> buf(alloc_size);
|
||||
@@ -3404,11 +3053,6 @@ static common_chat_params common_chat_templates_apply_legacy(
|
||||
res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
|
||||
}
|
||||
|
||||
// for safety, we check the result again
|
||||
if (res < 0 || (size_t) res > buf.size()) {
|
||||
throw std::runtime_error("failed to apply chat template, try using --jinja");
|
||||
}
|
||||
|
||||
common_chat_params params;
|
||||
params.prompt = std::string(buf.data(), res);
|
||||
if (!inputs.json_schema.empty()) {
|
||||
@@ -3495,24 +3139,6 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
|
||||
common_chat_parse_lfm2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_MINIMAX_M2:
|
||||
common_chat_parse_minimax_m2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_GLM_4_5:
|
||||
common_chat_parse_glm_4_5(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_KIMI_K2:
|
||||
common_chat_parse_kimi_k2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML:
|
||||
common_chat_parse_qwen3_coder_xml(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_APRIEL_1_5:
|
||||
common_chat_parse_apriel_1_5(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_XIAOMI_MIMO:
|
||||
common_chat_parse_xiaomi_mimo(builder);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
|
||||
}
|
||||
|
||||
@@ -117,12 +117,6 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_NEMOTRON_V2,
|
||||
COMMON_CHAT_FORMAT_APERTUS,
|
||||
COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
|
||||
COMMON_CHAT_FORMAT_GLM_4_5,
|
||||
COMMON_CHAT_FORMAT_MINIMAX_M2,
|
||||
COMMON_CHAT_FORMAT_KIMI_K2,
|
||||
COMMON_CHAT_FORMAT_QWEN3_CODER_XML,
|
||||
COMMON_CHAT_FORMAT_APRIEL_1_5,
|
||||
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
@@ -59,14 +60,6 @@
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
common_time_meas::common_time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
|
||||
|
||||
common_time_meas::~common_time_meas() {
|
||||
if (t_start_us >= 0) {
|
||||
t_acc += ggml_time_us() - t_start_us;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// CPU utils
|
||||
//
|
||||
@@ -362,7 +355,11 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
|
||||
}
|
||||
|
||||
void common_init() {
|
||||
llama_log_set(common_log_default_callback, NULL);
|
||||
llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
|
||||
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
|
||||
common_log_add(common_log_main(), level, "%s", text);
|
||||
}
|
||||
}, NULL);
|
||||
|
||||
#ifdef NDEBUG
|
||||
const char * build_type = "";
|
||||
@@ -911,39 +908,6 @@ std::string fs_get_cache_file(const std::string & filename) {
|
||||
return cache_directory + filename;
|
||||
}
|
||||
|
||||
std::vector<common_file_info> fs_list_files(const std::string & path) {
|
||||
std::vector<common_file_info> files;
|
||||
if (path.empty()) return files;
|
||||
|
||||
std::filesystem::path dir(path);
|
||||
if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) {
|
||||
return files;
|
||||
}
|
||||
|
||||
for (const auto & entry : std::filesystem::directory_iterator(dir)) {
|
||||
try {
|
||||
// Only include regular files (skip directories)
|
||||
const auto & p = entry.path();
|
||||
if (std::filesystem::is_regular_file(p)) {
|
||||
common_file_info info;
|
||||
info.path = p.string();
|
||||
info.name = p.filename().string();
|
||||
try {
|
||||
info.size = static_cast<size_t>(std::filesystem::file_size(p));
|
||||
} catch (const std::filesystem::filesystem_error &) {
|
||||
info.size = 0;
|
||||
}
|
||||
files.push_back(std::move(info));
|
||||
}
|
||||
} catch (const std::filesystem::filesystem_error &) {
|
||||
// skip entries we cannot inspect
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return files;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Model utils
|
||||
|
||||
@@ -2,15 +2,17 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ggml-opt.h"
|
||||
#include "llama-cpp.h"
|
||||
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <cmath>
|
||||
|
||||
#include "ggml-opt.h"
|
||||
#include "llama-cpp.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
#define DIRECTORY_SEPARATOR '\\'
|
||||
@@ -28,15 +30,6 @@
|
||||
|
||||
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
|
||||
|
||||
struct common_time_meas {
|
||||
common_time_meas(int64_t & t_acc, bool disable = false);
|
||||
~common_time_meas();
|
||||
|
||||
const int64_t t_start_us;
|
||||
|
||||
int64_t & t_acc;
|
||||
};
|
||||
|
||||
struct common_adapter_lora_info {
|
||||
std::string path;
|
||||
float scale;
|
||||
@@ -467,8 +460,7 @@ struct common_params {
|
||||
float slot_prompt_similarity = 0.1f;
|
||||
|
||||
// batched-bench params
|
||||
bool is_pp_shared = false;
|
||||
bool is_tg_separate = false;
|
||||
bool is_pp_shared = false;
|
||||
|
||||
std::vector<int32_t> n_pp;
|
||||
std::vector<int32_t> n_tg;
|
||||
@@ -619,13 +611,6 @@ bool fs_create_directory_with_parents(const std::string & path);
|
||||
std::string fs_get_cache_directory();
|
||||
std::string fs_get_cache_file(const std::string & filename);
|
||||
|
||||
struct common_file_info {
|
||||
std::string path;
|
||||
std::string name;
|
||||
size_t size = 0; // in bytes
|
||||
};
|
||||
std::vector<common_file_info> fs_list_files(const std::string & path);
|
||||
|
||||
//
|
||||
// Model utils
|
||||
//
|
||||
|
||||
1072
common/download.cpp
1072
common/download.cpp
File diff suppressed because it is too large
Load Diff
@@ -1,55 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
struct common_params_model;
|
||||
|
||||
//
|
||||
// download functionalities
|
||||
//
|
||||
|
||||
struct common_cached_model_info {
|
||||
std::string manifest_path;
|
||||
std::string user;
|
||||
std::string model;
|
||||
std::string tag;
|
||||
size_t size = 0; // GGUF size in bytes
|
||||
std::string to_string() const {
|
||||
return user + "/" + model + ":" + tag;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_hf_file_res {
|
||||
std::string repo; // repo name with ":tag" removed
|
||||
std::string ggufFile;
|
||||
std::string mmprojFile;
|
||||
};
|
||||
|
||||
/**
|
||||
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
|
||||
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
|
||||
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
|
||||
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
|
||||
* Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
|
||||
*
|
||||
* Return pair of <repo, file> (with "repo" already having tag removed)
|
||||
*
|
||||
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
|
||||
*/
|
||||
common_hf_file_res common_get_hf_file(
|
||||
const std::string & hf_repo_with_tag,
|
||||
const std::string & bearer_token,
|
||||
bool offline);
|
||||
|
||||
// returns true if download succeeded
|
||||
bool common_download_model(
|
||||
const common_params_model & model,
|
||||
const std::string & bearer_token,
|
||||
bool offline);
|
||||
|
||||
// returns list of cached models
|
||||
std::vector<common_cached_model_info> common_list_cached_models();
|
||||
|
||||
// resolve and download model from Docker registry
|
||||
// return local path to downloaded model file
|
||||
std::string common_docker_resolve_model(const std::string & docker);
|
||||
@@ -297,25 +297,8 @@ bool common_json_parse(
|
||||
it = temptative_end;
|
||||
return true;
|
||||
}
|
||||
// handle unclosed top-level primitive
|
||||
if (err_loc.position != 0 && !healing_marker.empty() && err_loc.stack.empty()) {
|
||||
std::string str(it, temptative_end);
|
||||
const auto & magic_seed = out.healing_marker.marker = healing_marker;
|
||||
if (can_parse(str + "\"")) {
|
||||
// Was inside an string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"";
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"")) {
|
||||
// Was inside an string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"";
|
||||
} else {
|
||||
// TODO: handle more unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
|
||||
// fprintf(stderr, "Closing: TODO\n");
|
||||
return false;
|
||||
}
|
||||
out.json = json::parse(str);
|
||||
it = temptative_end;
|
||||
return true;
|
||||
}
|
||||
// TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
|
||||
// fprintf(stderr, "Closing: TODO\n");
|
||||
return false;
|
||||
}
|
||||
out.json = json::parse(it, end);
|
||||
|
||||
@@ -303,8 +303,6 @@ static std::string format_literal(const std::string & literal) {
|
||||
return "\"" + escaped + "\"";
|
||||
}
|
||||
|
||||
std::string gbnf_format_literal(const std::string & literal) { return format_literal(literal); }
|
||||
|
||||
class SchemaConverter {
|
||||
private:
|
||||
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
|
||||
|
||||
@@ -18,6 +18,4 @@ struct common_grammar_options {
|
||||
bool dotall = false;
|
||||
};
|
||||
|
||||
std::string gbnf_format_literal(const std::string & literal);
|
||||
|
||||
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
|
||||
|
||||
@@ -442,9 +442,3 @@ void common_log_set_prefix(struct common_log * log, bool prefix) {
|
||||
void common_log_set_timestamps(struct common_log * log, bool timestamps) {
|
||||
log->set_timestamps(timestamps);
|
||||
}
|
||||
|
||||
void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
|
||||
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
|
||||
common_log_add(common_log_main(), level, "%s", text);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,8 +36,6 @@ extern int common_log_verbosity_thold;
|
||||
|
||||
void common_log_set_verbosity_thold(int verbosity); // not thread-safe
|
||||
|
||||
void common_log_default_callback(enum ggml_log_level level, const char * text, void * user_data);
|
||||
|
||||
// the common_log uses an internal worker thread to print/write log messages
|
||||
// when the worker thread is paused, incoming log messages are discarded
|
||||
struct common_log;
|
||||
|
||||
@@ -3,10 +3,9 @@
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <unordered_map>
|
||||
#include <algorithm>
|
||||
|
||||
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
||||
// TODO: deduplicate with llama-impl.h
|
||||
@@ -113,13 +112,6 @@ struct common_sampler {
|
||||
|
||||
llama_token_data_array cur_p;
|
||||
|
||||
void reset() {
|
||||
prev.clear();
|
||||
|
||||
llama_sampler_reset(grmr);
|
||||
llama_sampler_reset(chain);
|
||||
}
|
||||
|
||||
void set_logits(struct llama_context * ctx, int idx) {
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
|
||||
@@ -136,12 +128,6 @@ struct common_sampler {
|
||||
|
||||
cur_p = { cur.data(), cur.size(), -1, false };
|
||||
}
|
||||
|
||||
common_time_meas tm() {
|
||||
return common_time_meas(t_total_us, params.no_perf);
|
||||
}
|
||||
|
||||
mutable int64_t t_total_us = 0;
|
||||
};
|
||||
|
||||
std::string common_params_sampling::print() const {
|
||||
@@ -312,8 +298,6 @@ void common_sampler_free(struct common_sampler * gsmpl) {
|
||||
}
|
||||
|
||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
||||
const auto tm = gsmpl->tm();
|
||||
|
||||
if (accept_grammar) {
|
||||
llama_sampler_accept(gsmpl->grmr, token);
|
||||
}
|
||||
@@ -324,7 +308,9 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
|
||||
}
|
||||
|
||||
void common_sampler_reset(struct common_sampler * gsmpl) {
|
||||
gsmpl->reset();
|
||||
llama_sampler_reset(gsmpl->grmr);
|
||||
|
||||
llama_sampler_reset(gsmpl->chain);
|
||||
}
|
||||
|
||||
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
||||
@@ -341,54 +327,16 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
||||
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
|
||||
// TODO: measure grammar performance
|
||||
|
||||
const double t_sampling_ms = gsmpl ? 1e-3*gsmpl->t_total_us : 0;
|
||||
|
||||
llama_perf_sampler_data data_smpl;
|
||||
llama_perf_context_data data_ctx;
|
||||
|
||||
memset(&data_smpl, 0, sizeof(data_smpl));
|
||||
memset(&data_ctx, 0, sizeof(data_ctx));
|
||||
|
||||
if (gsmpl) {
|
||||
auto & data = data_smpl;
|
||||
|
||||
data = llama_perf_sampler(gsmpl->chain);
|
||||
|
||||
// note: the sampling time includes the samplers time + extra time spent in common/sampling
|
||||
LOG_INF("%s: sampling time = %10.2f ms\n", __func__, t_sampling_ms);
|
||||
LOG_INF("%s: samplers time = %10.2f ms / %5d tokens\n", __func__, data.t_sample_ms, data.n_sample);
|
||||
llama_perf_sampler_print(gsmpl->chain);
|
||||
}
|
||||
|
||||
if (ctx) {
|
||||
auto & data = data_ctx;
|
||||
|
||||
data = llama_perf_context(ctx);
|
||||
|
||||
const double t_end_ms = 1e-3 * ggml_time_us();
|
||||
|
||||
const double t_total_ms = t_end_ms - data.t_start_ms;
|
||||
const double t_unacc_ms = t_total_ms - (t_sampling_ms + data.t_p_eval_ms + data.t_eval_ms);
|
||||
const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms;
|
||||
|
||||
LOG_INF("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
|
||||
LOG_INF("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
__func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
|
||||
LOG_INF("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
|
||||
LOG_INF("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
|
||||
LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %% (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc);
|
||||
LOG_INF("%s: graphs reused = %10d\n", __func__, data.n_reused);
|
||||
|
||||
llama_perf_context_print(ctx);
|
||||
llama_memory_breakdown_print(ctx);
|
||||
}
|
||||
}
|
||||
|
||||
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
||||
llama_synchronize(ctx);
|
||||
|
||||
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
|
||||
const auto tm = gsmpl->tm();
|
||||
|
||||
gsmpl->set_logits(ctx, idx);
|
||||
|
||||
auto & grmr = gsmpl->grmr;
|
||||
@@ -480,8 +428,6 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
||||
// helpers
|
||||
|
||||
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
|
||||
const auto tm = gsmpl->tm();
|
||||
|
||||
auto * res = &gsmpl->cur_p;
|
||||
|
||||
if (do_sort && !res->sorted) {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -139,7 +139,6 @@ models = [
|
||||
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
|
||||
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
|
||||
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
|
||||
{"name": "afmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/Trinity-Tokenizer", },
|
||||
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
|
||||
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
|
||||
{"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
|
||||
|
||||
@@ -277,15 +277,10 @@ def parse_args() -> argparse.Namespace:
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_hparams_from_hf(hf_model_id: str) -> tuple[dict[str, Any], Path | None]:
|
||||
from huggingface_hub import try_to_load_from_cache
|
||||
|
||||
def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
|
||||
# normally, adapter does not come with base model config, we need to load it from AutoConfig
|
||||
config = AutoConfig.from_pretrained(hf_model_id)
|
||||
cache_dir = try_to_load_from_cache(hf_model_id, "config.json")
|
||||
cache_dir = Path(cache_dir).parent if isinstance(cache_dir, str) else None
|
||||
|
||||
return config.to_dict(), cache_dir
|
||||
return config.to_dict()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -330,13 +325,13 @@ if __name__ == '__main__':
|
||||
# load base model
|
||||
if base_model_id is not None:
|
||||
logger.info(f"Loading base model from Hugging Face: {base_model_id}")
|
||||
hparams, dir_base_model = load_hparams_from_hf(base_model_id)
|
||||
hparams = load_hparams_from_hf(base_model_id)
|
||||
elif dir_base_model is None:
|
||||
if "base_model_name_or_path" in lparams:
|
||||
model_id = lparams["base_model_name_or_path"]
|
||||
logger.info(f"Loading base model from Hugging Face: {model_id}")
|
||||
try:
|
||||
hparams, dir_base_model = load_hparams_from_hf(model_id)
|
||||
hparams = load_hparams_from_hf(model_id)
|
||||
except OSError as e:
|
||||
logger.error(f"Failed to load base model config: {e}")
|
||||
logger.error("Please try downloading the base model and add its path to --base")
|
||||
@@ -485,7 +480,6 @@ if __name__ == '__main__':
|
||||
dir_lora_model=dir_lora,
|
||||
lora_alpha=alpha,
|
||||
hparams=hparams,
|
||||
remote_hf_model_id=base_model_id,
|
||||
)
|
||||
|
||||
logger.info("Exporting model...")
|
||||
|
||||
@@ -313,12 +313,7 @@ Converting the matmul weight format from ND to NZ to improve performance. Enable
|
||||
|
||||
### GGML_CANN_ACL_GRAPH
|
||||
|
||||
Operators are executed using ACL graph execution, rather than in op-by-op (eager) mode. Enabled by default. This option is only effective if `USE_ACL_GRAPH` was enabled at compilation time. To enable it, recompile using:
|
||||
|
||||
```sh
|
||||
cmake -B build -DGGML_CANN=on -DCMAKE_BUILD_TYPE=release -DUSE_ACL_GRAPH=ON
|
||||
cmake --build build --config release
|
||||
```
|
||||
Operators are executed using ACL graph execution, rather than in op-by-op (eager) mode. Enabled by default.
|
||||
|
||||
### GGML_CANN_GRAPH_CACHE_CAPACITY
|
||||
|
||||
|
||||
109
docs/ops.md
109
docs/ops.md
@@ -14,108 +14,103 @@ Legend:
|
||||
|
||||
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | zDNN |
|
||||
|-----------|------|------|------|------|------|------|------|------|------|
|
||||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ADD_ID | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ |
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ |
|
||||
| CONV_2D | ❌ | ❌ | ✅ | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ |
|
||||
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CONV_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CUMSUM | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| IM2COL_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | 🟡 | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_MUL_ADD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | 🟡 | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
|
||||
21200
docs/ops/CPU.csv
21200
docs/ops/CPU.csv
File diff suppressed because it is too large
Load Diff
23118
docs/ops/CUDA.csv
23118
docs/ops/CUDA.csv
File diff suppressed because it is too large
Load Diff
7174
docs/ops/SYCL.csv
7174
docs/ops/SYCL.csv
File diff suppressed because it is too large
Load Diff
18908
docs/ops/Vulkan.csv
18908
docs/ops/Vulkan.csv
File diff suppressed because it is too large
Load Diff
@@ -6,54 +6,8 @@ More Info:
|
||||
- https://github.com/ggml-org/llama.cpp/pull/14644
|
||||
- https://github.com/ggml-org/llama.cpp/pull/14771
|
||||
|
||||
## Parameters
|
||||
The diffusion CLI supports various parameters to control the generation process:
|
||||
|
||||
### Core Diffusion Parameters
|
||||
- `--diffusion-steps`: Number of diffusion steps (default: 256)
|
||||
- `--diffusion-algorithm`: Algorithm for token selection
|
||||
- `0`: ORIGIN - Token will be generated in a purely random order from https://arxiv.org/abs/2107.03006.
|
||||
- `1`: ENTROPY_BASED - Entropy-based selection
|
||||
- `2`: MARGIN_BASED - Margin-based selection
|
||||
- `3`: RANDOM - Random selection
|
||||
- `4`: CONFIDENCE_BASED - Confidence-based selection (default)
|
||||
- More documentation here https://github.com/DreamLM/Dream
|
||||
- `--diffusion-visual`: Enable live visualization during generation
|
||||
Example of using Dream architechture: `llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual`
|
||||
|
||||
### Scheduling Parameters
|
||||
Choose one of the following scheduling methods:
|
||||
Example of using LLaDA architechture: `llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual`
|
||||
|
||||
**Timestep-based scheduling:**
|
||||
- `--diffusion-eps`: Epsilon value for timestep scheduling (e.g., 0.001)
|
||||
|
||||
**Block-based scheduling:**
|
||||
- `--diffusion-block-length`: Block size for block-based scheduling (e.g., 32)
|
||||
|
||||
### Sampling Parameters
|
||||
- `--temp`: Temperature for sampling (0.0 = greedy/deterministic, higher = more random)
|
||||
- `--top-k`: Top-k filtering for sampling
|
||||
- `--top-p`: Top-p (nucleus) filtering for sampling
|
||||
- `--seed`: Random seed for reproducibility
|
||||
|
||||
### Model Parameters
|
||||
- `-m`: Path to the GGUF model file
|
||||
- `-p`: Input prompt text
|
||||
- `-ub`: Maximum sequence length (ubatch size)
|
||||
- `-c`: Context size
|
||||
- `-b`: Batch size
|
||||
|
||||
### Examples
|
||||
#### Dream architechture:
|
||||
```
|
||||
llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual
|
||||
```
|
||||
|
||||
#### LLaDA architechture:
|
||||
```
|
||||
llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual
|
||||
```
|
||||
|
||||
#### RND1 architecture:
|
||||
```
|
||||
llama-diffusion-cli -m RND1-Base-0910.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-algorithm 1 --diffusion-steps 256 --diffusion-visual --temp 0.5 --diffusion-eps 0.001
|
||||
```
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
#include "llama.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
|
||||
/**
|
||||
* This the arbitrary data which will be passed to each callback.
|
||||
@@ -37,23 +37,23 @@ static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
|
||||
return u.f;
|
||||
}
|
||||
|
||||
static float ggml_get_float_value(const uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
|
||||
static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
|
||||
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
|
||||
float v;
|
||||
if (type == GGML_TYPE_F16) {
|
||||
v = ggml_fp16_to_fp32(*(const ggml_fp16_t *) &data[i]);
|
||||
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
|
||||
} else if (type == GGML_TYPE_F32) {
|
||||
v = *(const float *) &data[i];
|
||||
v = *(float *) &data[i];
|
||||
} else if (type == GGML_TYPE_I64) {
|
||||
v = (float) *(const int64_t *) &data[i];
|
||||
v = (float) *(int64_t *) &data[i];
|
||||
} else if (type == GGML_TYPE_I32) {
|
||||
v = (float) *(const int32_t *) &data[i];
|
||||
v = (float) *(int32_t *) &data[i];
|
||||
} else if (type == GGML_TYPE_I16) {
|
||||
v = (float) *(const int16_t *) &data[i];
|
||||
v = (float) *(int16_t *) &data[i];
|
||||
} else if (type == GGML_TYPE_I8) {
|
||||
v = (float) *(const int8_t *) &data[i];
|
||||
v = (float) *(int8_t *) &data[i];
|
||||
} else if (type == GGML_TYPE_BF16) {
|
||||
v = ggml_compute_bf16_to_fp32(*(const ggml_bf16_t *) &data[i]);
|
||||
v = ggml_compute_bf16_to_fp32(*(ggml_bf16_t *) &data[i]);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
@@ -25,17 +25,16 @@ if(GIT_EXE)
|
||||
)
|
||||
endif()
|
||||
|
||||
# Build the version string with optional dirty flag
|
||||
set(GGML_VERSION "${GGML_VERSION_BASE}")
|
||||
if(GGML_GIT_DIRTY AND NOT GGML_GIT_DIRTY EQUAL 0)
|
||||
set(GGML_VERSION "${GGML_VERSION}-dirty")
|
||||
endif()
|
||||
|
||||
if(NOT GGML_BUILD_COMMIT)
|
||||
set(GGML_BUILD_COMMIT "unknown")
|
||||
endif()
|
||||
|
||||
# Build the commit string with optional dirty flag
|
||||
if(DEFINED GGML_GIT_DIRTY AND GGML_GIT_DIRTY EQUAL 1)
|
||||
set(GGML_BUILD_COMMIT "${GGML_BUILD_COMMIT}-dirty")
|
||||
endif()
|
||||
|
||||
include(CheckIncludeFileCXX)
|
||||
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
@@ -169,7 +168,7 @@ option(GGML_RV_ZFH "ggml: enable riscv zfh" ON)
|
||||
option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
|
||||
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
|
||||
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
|
||||
option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE})
|
||||
option(GGML_VXE "ggml: enable vxe" ON)
|
||||
|
||||
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
|
||||
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
|
||||
|
||||
@@ -475,7 +475,6 @@ extern "C" {
|
||||
GGML_OP_COS,
|
||||
GGML_OP_SUM,
|
||||
GGML_OP_SUM_ROWS,
|
||||
GGML_OP_CUMSUM,
|
||||
GGML_OP_MEAN,
|
||||
GGML_OP_ARGMAX,
|
||||
GGML_OP_COUNT_EQUAL,
|
||||
@@ -531,8 +530,6 @@ extern "C" {
|
||||
GGML_OP_TIMESTEP_EMBEDDING,
|
||||
GGML_OP_ARGSORT,
|
||||
GGML_OP_LEAKY_RELU,
|
||||
GGML_OP_TRI,
|
||||
GGML_OP_FILL,
|
||||
|
||||
GGML_OP_FLASH_ATTN_EXT,
|
||||
GGML_OP_FLASH_ATTN_BACK,
|
||||
@@ -545,7 +542,6 @@ extern "C" {
|
||||
GGML_OP_RWKV_WKV6,
|
||||
GGML_OP_GATED_LINEAR_ATTN,
|
||||
GGML_OP_RWKV_WKV7,
|
||||
GGML_OP_SOLVE_TRI,
|
||||
|
||||
GGML_OP_UNARY,
|
||||
|
||||
@@ -580,8 +576,6 @@ extern "C" {
|
||||
GGML_UNARY_OP_HARDSWISH,
|
||||
GGML_UNARY_OP_HARDSIGMOID,
|
||||
GGML_UNARY_OP_EXP,
|
||||
GGML_UNARY_OP_EXPM1,
|
||||
GGML_UNARY_OP_SOFTPLUS,
|
||||
GGML_UNARY_OP_GELU_ERF,
|
||||
GGML_UNARY_OP_XIELU,
|
||||
GGML_UNARY_OP_FLOOR,
|
||||
@@ -626,13 +620,6 @@ extern "C" {
|
||||
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
|
||||
};
|
||||
|
||||
enum ggml_tri_type {
|
||||
GGML_TRI_TYPE_UPPER_DIAG = 0,
|
||||
GGML_TRI_TYPE_UPPER = 1,
|
||||
GGML_TRI_TYPE_LOWER_DIAG = 2,
|
||||
GGML_TRI_TYPE_LOWER = 3
|
||||
};
|
||||
|
||||
struct ggml_init_params {
|
||||
// memory pool
|
||||
size_t mem_size; // bytes
|
||||
@@ -970,22 +957,6 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_expm1(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_expm1_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_softplus(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_softplus_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_sin(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
@@ -1012,10 +983,6 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_cumsum(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// mean along rows
|
||||
GGML_API struct ggml_tensor * ggml_mean(
|
||||
struct ggml_context * ctx,
|
||||
@@ -2220,23 +2187,6 @@ extern "C" {
|
||||
int shift2,
|
||||
int shift3);
|
||||
|
||||
// Convert matrix into a triangular one (upper, strict upper, lower or strict lower) by writing
|
||||
// zeroes everywhere outside the masked area
|
||||
GGML_API struct ggml_tensor * ggml_tri(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_tri_type type);
|
||||
|
||||
// Fill tensor a with constant c
|
||||
GGML_API struct ggml_tensor * ggml_fill(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float c);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_fill_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float c);
|
||||
|
||||
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
|
||||
// timesteps: [N,]
|
||||
@@ -2406,27 +2356,6 @@ extern "C" {
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * state);
|
||||
|
||||
/* Solves a specific equation of the form Ax=B, where A is a triangular matrix
|
||||
* without zeroes on the diagonal (i.e. invertible).
|
||||
* B can have any number of columns, but must have the same number of rows as A
|
||||
* If A is [n, n] and B is [n, m], then the result will be [n, m] as well
|
||||
* Has O(n^3) complexity (unlike most matrix ops out there), so use on cases
|
||||
* where n > 100 sparingly, pre-chunk if necessary.
|
||||
*
|
||||
* If left = false, solves xA=B instead
|
||||
* If lower = false, assumes upper triangular instead
|
||||
* If uni = true, assumes diagonal of A to be all ones (will override actual values)
|
||||
*
|
||||
* TODO: currently only lower, right, non-unitriangular variant is implemented
|
||||
*/
|
||||
GGML_API struct ggml_tensor * ggml_solve_tri(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
bool left,
|
||||
bool lower,
|
||||
bool uni);
|
||||
|
||||
// custom operators
|
||||
|
||||
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
|
||||
|
||||
@@ -211,11 +211,6 @@ add_library(ggml-base
|
||||
ggml-quants.h
|
||||
gguf.cpp)
|
||||
|
||||
set_target_properties(ggml-base PROPERTIES
|
||||
VERSION ${GGML_VERSION}
|
||||
SOVERSION ${GGML_VERSION_MAJOR}
|
||||
)
|
||||
|
||||
target_include_directories(ggml-base PRIVATE .)
|
||||
if (GGML_BACKEND_DL)
|
||||
target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL)
|
||||
@@ -225,11 +220,6 @@ add_library(ggml
|
||||
ggml-backend-reg.cpp)
|
||||
add_library(ggml::ggml ALIAS ggml)
|
||||
|
||||
set_target_properties(ggml PROPERTIES
|
||||
VERSION ${GGML_VERSION}
|
||||
SOVERSION ${GGML_VERSION_MAJOR}
|
||||
)
|
||||
|
||||
if (GGML_BACKEND_DIR)
|
||||
if (NOT GGML_BACKEND_DL)
|
||||
message(FATAL_ERROR "GGML_BACKEND_DIR requires GGML_BACKEND_DL")
|
||||
@@ -269,12 +259,6 @@ function(ggml_add_backend_library backend)
|
||||
target_compile_definitions(${backend} PUBLIC GGML_BACKEND_SHARED)
|
||||
endif()
|
||||
|
||||
# Set versioning properties for all backend libraries
|
||||
set_target_properties(${backend} PROPERTIES
|
||||
VERSION ${GGML_VERSION}
|
||||
SOVERSION ${GGML_VERSION_MAJOR}
|
||||
)
|
||||
|
||||
if(NOT GGML_AVAILABLE_BACKENDS)
|
||||
set(GGML_AVAILABLE_BACKENDS "${backend}"
|
||||
CACHE INTERNAL "List of backends for cmake package")
|
||||
@@ -328,14 +312,6 @@ function(ggml_add_cpu_backend_variant tag_name)
|
||||
set(GGML_INTERNAL_${feat} OFF)
|
||||
endforeach()
|
||||
|
||||
foreach (feat ${ARGN})
|
||||
set(GGML_INTERNAL_${feat} ON)
|
||||
endforeach()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64")
|
||||
foreach (feat RVV)
|
||||
set(GGML_INTERNAL_${feat} OFF)
|
||||
endforeach()
|
||||
|
||||
foreach (feat ${ARGN})
|
||||
set(GGML_INTERNAL_${feat} ON)
|
||||
endforeach()
|
||||
@@ -410,13 +386,6 @@ if (GGML_CPU_ALL_VARIANTS)
|
||||
else()
|
||||
message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64")
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
ggml_add_cpu_backend_variant(riscv64_0)
|
||||
ggml_add_cpu_backend_variant(riscv64_v RVV)
|
||||
else()
|
||||
message(FATAL_ERROR "Unsupported RISC-V target OS: ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
else()
|
||||
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
|
||||
@@ -1698,6 +1698,8 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
|
||||
GGML_ASSERT(sched);
|
||||
GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
|
||||
|
||||
ggml_backend_sched_reset(sched);
|
||||
|
||||
ggml_backend_sched_synchronize(sched);
|
||||
|
||||
ggml_backend_sched_split_graph(sched, measure_graph);
|
||||
|
||||
2579
ggml/src/ggml-cann/Doxyfile
Executable file
2579
ggml/src/ggml-cann/Doxyfile
Executable file
File diff suppressed because it is too large
Load Diff
@@ -48,14 +48,15 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
|
||||
default:
|
||||
return ACL_DT_UNDEFINED;
|
||||
}
|
||||
return ACL_DT_UNDEFINED;
|
||||
}
|
||||
|
||||
acl_tensor_ptr ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne,
|
||||
size_t * nb,
|
||||
int64_t dims,
|
||||
aclFormat format,
|
||||
size_t offset) {
|
||||
aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne,
|
||||
size_t * nb,
|
||||
int64_t dims,
|
||||
aclFormat format,
|
||||
size_t offset) {
|
||||
// If tensor is bcasted, Up to GGML_MAX_DIMS additional dimensions will be
|
||||
// added.
|
||||
int64_t acl_ne[GGML_MAX_DIMS * 2], acl_stride[GGML_MAX_DIMS * 2];
|
||||
@@ -86,20 +87,10 @@ acl_tensor_ptr ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
std::reverse(acl_ne, acl_ne + final_dims);
|
||||
std::reverse(acl_stride, acl_stride + final_dims);
|
||||
|
||||
aclTensor * raw = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride, elem_offset,
|
||||
format, &acl_storage_len, 1, tensor->data);
|
||||
aclTensor * acl_tensor = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
|
||||
elem_offset, format, &acl_storage_len, 1, tensor->data);
|
||||
|
||||
return acl_tensor_ptr(raw);
|
||||
}
|
||||
|
||||
acl_int_array_ptr ggml_cann_create_int_array(const int64_t * value, uint64_t size) {
|
||||
aclIntArray * raw = aclCreateIntArray(value, size);
|
||||
return acl_int_array_ptr(raw);
|
||||
}
|
||||
|
||||
acl_scalar_ptr ggml_cann_create_scalar(void * value, aclDataType dataType) {
|
||||
aclScalar * raw = aclCreateScalar(value, dataType);
|
||||
return acl_scalar_ptr(raw);
|
||||
return acl_tensor;
|
||||
}
|
||||
|
||||
bool ggml_cann_need_bcast(const ggml_tensor * t0, const ggml_tensor * t1) {
|
||||
|
||||
@@ -23,13 +23,12 @@
|
||||
#ifndef CANN_ACL_TENSOR_H
|
||||
#define CANN_ACL_TENSOR_H
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#include <aclnn/aclnn_base.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
||||
#include <aclnn/aclnn_base.h>
|
||||
#include "common.h"
|
||||
|
||||
/**
|
||||
* @brief Maps a ggml_type to its corresponding aclDataType.
|
||||
*
|
||||
@@ -44,20 +43,6 @@
|
||||
*/
|
||||
aclDataType ggml_cann_type_mapping(ggml_type type);
|
||||
|
||||
// Deleter for acl objects.
|
||||
template <typename T, aclError (*DestroyFunc)(const T *)> struct acl_deleter {
|
||||
void operator()(T * ptr) const noexcept {
|
||||
if (ptr) {
|
||||
ACL_CHECK(DestroyFunc(ptr));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using acl_tensor_ptr = std::unique_ptr<aclTensor, acl_deleter<aclTensor, aclDestroyTensor>>;
|
||||
using acl_int_array_ptr = std::unique_ptr<aclIntArray, acl_deleter<aclIntArray, aclDestroyIntArray>>;
|
||||
using acl_scalar_ptr = std::unique_ptr<aclScalar, acl_deleter<aclScalar, aclDestroyScalar>>;
|
||||
using acl_tensor_list_ptr = std::unique_ptr<aclTensorList, acl_deleter<aclTensorList, aclDestroyTensorList>>;
|
||||
|
||||
/**
|
||||
* @brief Creates an ACL tensor from a ggml_tensor with optional shape.
|
||||
*
|
||||
@@ -77,12 +62,12 @@ using acl_tensor_list_ptr = std::unique_ptr<aclTensorList, acl_deleter<aclTensor
|
||||
* @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
|
||||
* @return Pointer to the created ACL tensor.
|
||||
*/
|
||||
acl_tensor_ptr ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne = nullptr,
|
||||
size_t * nb = nullptr,
|
||||
int64_t dims = 0,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0);
|
||||
aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne = nullptr,
|
||||
size_t * nb = nullptr,
|
||||
int64_t dims = 0,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0);
|
||||
|
||||
/**
|
||||
* @brief Template for creating an ACL tensor from provided parameters. typename TYPE
|
||||
@@ -105,14 +90,14 @@ acl_tensor_ptr ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
* @return Pointer to the created ACL tensor.
|
||||
*/
|
||||
template <typename TYPE>
|
||||
acl_tensor_ptr ggml_cann_create_tensor(void * data_ptr,
|
||||
aclDataType dtype,
|
||||
TYPE type_size,
|
||||
int64_t * ne,
|
||||
TYPE * nb,
|
||||
int64_t dims,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0) {
|
||||
aclTensor * ggml_cann_create_tensor(void * data_ptr,
|
||||
aclDataType dtype,
|
||||
TYPE type_size,
|
||||
int64_t * ne,
|
||||
TYPE * nb,
|
||||
int64_t dims,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0) {
|
||||
int64_t tmp_ne[GGML_MAX_DIMS * 2];
|
||||
int64_t tmp_stride[GGML_MAX_DIMS * 2];
|
||||
|
||||
@@ -129,75 +114,10 @@ acl_tensor_ptr ggml_cann_create_tensor(void * data_ptr,
|
||||
std::reverse(tmp_ne, tmp_ne + dims);
|
||||
std::reverse(tmp_stride, tmp_stride + dims);
|
||||
|
||||
aclTensor * raw =
|
||||
aclTensor * acl_tensor =
|
||||
aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size, format, &acl_storage_len, 1, data_ptr);
|
||||
|
||||
return acl_tensor_ptr(raw);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create an ACL int array resource wrapped in a smart pointer.
|
||||
*
|
||||
* This function constructs an aclIntArray from the provided int64_t values
|
||||
* and returns it as an acl_int_array_ptr (a std::unique_ptr with a custom
|
||||
* deleter). The returned pointer owns the ACL resource and will automatically
|
||||
* destroy it via aclDestroyIntArray().
|
||||
*
|
||||
* @param value Pointer to the int64_t elements.
|
||||
* @param size Number of elements in value.
|
||||
*
|
||||
* @return A smart pointer managing the created ACL int array.
|
||||
*/
|
||||
acl_int_array_ptr ggml_cann_create_int_array(const int64_t * value, uint64_t size);
|
||||
|
||||
/**
|
||||
* @brief Create an ACL scalar resource wrapped in a smart pointer.
|
||||
*
|
||||
* This function constructs an aclScalar from the raw value pointer and ACL
|
||||
* data type, then returns it as an acl_scalar_ptr (a std::unique_ptr with
|
||||
* a custom deleter). The returned pointer owns the ACL scalar and will
|
||||
* automatically destroy it via aclDestroyScalar().
|
||||
*
|
||||
* @param value Pointer to the raw scalar memory.
|
||||
* @param dataType ACL data type of the scalar.
|
||||
*
|
||||
* @return A smart pointer managing the created ACL scalar.
|
||||
*/
|
||||
acl_scalar_ptr ggml_cann_create_scalar(void * value, aclDataType dataType);
|
||||
|
||||
/**
|
||||
* @brief Create an ACL tensor list from multiple tensor smart pointers.
|
||||
*
|
||||
* This function accepts a variadic list of acl_tensor_ptr (a unique_ptr with
|
||||
* custom deleter) and produces an aclTensorList using aclCreateTensorList().
|
||||
*
|
||||
* The lifecycle management of the tensor objects changes as follows:
|
||||
* - aclCreateTensorList() takes ownership of the tensors
|
||||
* - Each input smart pointer releases ownership using release()
|
||||
* - As a result, the tensors will NOT be destroyed by unique_ptr
|
||||
* - Instead, they will be destroyed when aclDestroyTensorList() is called
|
||||
*
|
||||
* This ensures correct ownership transfer and prevents double-free situations.
|
||||
*
|
||||
* @param acl_tensor_ptr Variadic template parameter; each argument must be
|
||||
* a unique_ptr-like type supporting get() and release().
|
||||
*
|
||||
* @param tensors Variadic list of acl_tensor_ptr objects. Ownership of
|
||||
* each tensor is transferred away from these smart pointers.
|
||||
*
|
||||
* @return A smart pointer (acl_tensor_list_ptr) owning the created ACL tensor list.
|
||||
*
|
||||
* @note This implementation is C++11 compatible. The ownership-release process is
|
||||
* executed using a pack expansion inside an initializer list.
|
||||
*/
|
||||
template <typename... acl_tensor_ptr> acl_tensor_list_ptr ggml_cann_create_tensor_list(acl_tensor_ptr &&... tensors) {
|
||||
aclTensor * raw_tensors[] = { tensors.get()... };
|
||||
aclTensorList * raw = aclCreateTensorList(raw_tensors, sizeof...(tensors));
|
||||
// aclTensor will release by aclTensorList, so release ownership without
|
||||
// destroying the tensor
|
||||
int dummy[] = { (tensors.release(), 0)... };
|
||||
GGML_UNUSED(dummy);
|
||||
return acl_tensor_list_ptr(raw);
|
||||
return acl_tensor;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,35 +23,31 @@
|
||||
#ifndef CANN_ACLNN_OPS
|
||||
#define CANN_ACLNN_OPS
|
||||
|
||||
#include "acl_tensor.h"
|
||||
#include "common.h"
|
||||
|
||||
#include <unordered_set>
|
||||
#include <functional>
|
||||
#include <aclnnop/aclnn_abs.h>
|
||||
#include <aclnnop/aclnn_neg.h>
|
||||
#include <aclnnop/aclnn_exp.h>
|
||||
#include <aclnnop/aclnn_arange.h>
|
||||
#include <aclnnop/aclnn_argsort.h>
|
||||
#include <aclnnop/aclnn_cat.h>
|
||||
#include <aclnnop/aclnn_clamp.h>
|
||||
#include <aclnnop/aclnn_cos.h>
|
||||
#include <aclnnop/aclnn_exp.h>
|
||||
#include <aclnnop/aclnn_gelu.h>
|
||||
#include <aclnnop/aclnn_gelu_v2.h>
|
||||
#include <aclnnop/aclnn_sigmoid.h>
|
||||
#include <aclnnop/aclnn_hardsigmoid.h>
|
||||
#include <aclnnop/aclnn_hardswish.h>
|
||||
#include <aclnnop/aclnn_leaky_relu.h>
|
||||
#include <aclnnop/aclnn_log.h>
|
||||
#include <aclnnop/aclnn_logsoftmax.h>
|
||||
#include <aclnnop/aclnn_neg.h>
|
||||
#include <aclnnop/aclnn_norm.h>
|
||||
#include <aclnnop/aclnn_relu.h>
|
||||
#include <aclnnop/aclnn_sigmoid.h>
|
||||
#include <aclnnop/aclnn_sign.h>
|
||||
#include <aclnnop/aclnn_silu.h>
|
||||
#include <aclnnop/aclnn_sin.h>
|
||||
#include <aclnnop/aclnn_sqrt.h>
|
||||
#include <aclnnop/aclnn_tanh.h>
|
||||
|
||||
#include <functional>
|
||||
#include <unordered_set>
|
||||
#include <aclnnop/aclnn_sqrt.h>
|
||||
#include <aclnnop/aclnn_sin.h>
|
||||
#include <aclnnop/aclnn_cos.h>
|
||||
#include <aclnnop/aclnn_log.h>
|
||||
#include <aclnnop/aclnn_sign.h>
|
||||
#include "acl_tensor.h"
|
||||
#include "common.h"
|
||||
|
||||
/**
|
||||
* @brief Repeats a ggml tensor along each dimension to match the dimensions
|
||||
@@ -191,66 +187,6 @@ void ggml_cann_argsort(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
*/
|
||||
void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the L2 Normalization for a ggml tensor using the CANN
|
||||
* backend.
|
||||
*
|
||||
* @details This function applies the L2 Normalization operation on the
|
||||
* input tensor `src` and stores the result in the destination tensor
|
||||
* `dst`. L2 Normalization scales the input tensor such that the
|
||||
* L2 norm along the specified dimension equals 1. This operation
|
||||
* is commonly used in neural networks for feature normalization
|
||||
* and vector scaling.
|
||||
* The operation is defined as:
|
||||
* \f[
|
||||
* \text{out} = \frac{x}{\sqrt{\sum{x^2}}}
|
||||
* \f]
|
||||
* The normalization is performed along the last dimension by default.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor where the normalized values will be stored.
|
||||
* @attention The normalization is performed along the last dimension of the
|
||||
* input tensor by default.
|
||||
*/
|
||||
void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the Cross Entropy Loss for a ggml tensor using the CANN
|
||||
* backend.
|
||||
*
|
||||
* @details This function computes the cross entropy loss between the predicted
|
||||
* logits and target probability distributions. The operation follows
|
||||
* the same computation pattern as the CPU implementation:
|
||||
* 1. Applies log_softmax to the logits along the class dimension
|
||||
* 2. Element-wise multiplication with target distributions
|
||||
* 3. Summation along the class dimension to get per-sample losses
|
||||
* 4. Global summation and scaling by -1/nr to get final loss
|
||||
*
|
||||
* The computation can be expressed as:
|
||||
* \f[
|
||||
* \text{loss} = -\frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{C} y_{ij} \cdot \log(\text{softmax}(x_{ij}))
|
||||
* \f]
|
||||
* where \f$N\f$ is the total number of samples, \f$C\f$ is the number
|
||||
* of classes, \f$x\f$ are the logits, and \f$y\f$ are the target
|
||||
* probability distributions.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor where the computed loss will be stored.
|
||||
* This should be a scalar tensor containing the final loss value.
|
||||
*
|
||||
* @note This implementation computes cross entropy between probability
|
||||
* distributions, not the typical classification cross entropy that
|
||||
* expects class indices as targets. Both input tensors (src0 and src1)
|
||||
* should have the same shape and represent probability distributions
|
||||
* over the class dimension.
|
||||
* @note The function expects two source tensors:
|
||||
* - dst->src[0]: Logits tensor (before softmax)
|
||||
* - dst->src[1]: Target probability distributions tensor
|
||||
* @note The computation is performed using CANN backend operators including
|
||||
* LogSoftmax, Mul, ReduceSum, and Muls for the final scaling.
|
||||
*/
|
||||
void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the Group Normalization for a ggml tensor using the CANN
|
||||
* backend.
|
||||
@@ -690,12 +626,12 @@ void aclnn_sin(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor *
|
||||
* @param acl_src1 Output pointer to the created ACL tensor corresponding to src1.
|
||||
* @param acl_dst Output pointer to the created ACL tensor corresponding to dst.
|
||||
*/
|
||||
void bcast_shape(ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst,
|
||||
acl_tensor_ptr & acl_src0,
|
||||
acl_tensor_ptr & acl_src1,
|
||||
acl_tensor_ptr & acl_dst);
|
||||
void bcast_shape(ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst,
|
||||
aclTensor ** acl_src0,
|
||||
aclTensor ** acl_src1,
|
||||
aclTensor ** acl_dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the 1D transposed convolution (deconvolution) of a ggml
|
||||
@@ -875,6 +811,83 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_reso
|
||||
(vec.emplace_back(make_acl_resource(args)), ...);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Task class that wraps the execution of an aclnn function call.
|
||||
*/
|
||||
class aclnn_task : public cann_task {
|
||||
public:
|
||||
aclnn_task(aclnn_func_t aclnn_func,
|
||||
void * workspace_addr,
|
||||
uint64_t workspace_size,
|
||||
aclOpExecutor * executor,
|
||||
aclrtStream stream) :
|
||||
aclnn_func_(aclnn_func),
|
||||
workspace_addr_(workspace_addr),
|
||||
workspace_size_(workspace_size),
|
||||
executor_(executor),
|
||||
stream_(stream) {}
|
||||
|
||||
virtual void run_task() override { ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_)); }
|
||||
private:
|
||||
aclnn_func_t aclnn_func_;
|
||||
void * workspace_addr_;
|
||||
uint64_t workspace_size_;
|
||||
aclOpExecutor * executor_;
|
||||
aclrtStream stream_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Task class that releases ACL resources after usage.
|
||||
*/
|
||||
class release_resource_task : public cann_task {
|
||||
public:
|
||||
release_resource_task(std::vector<any_acl_resource> && resources) { resource_ = std::move(resources); }
|
||||
|
||||
virtual void run_task() override { resource_.clear(); }
|
||||
private:
|
||||
std::vector<any_acl_resource> resource_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Task class for performing asynchronous memory copy operations.
|
||||
*/
|
||||
class async_memcpy_task : public cann_task {
|
||||
public:
|
||||
async_memcpy_task(void * dst, const void * src, size_t size, aclrtMemcpyKind kind, aclrtStream stream) :
|
||||
dst_(dst),
|
||||
src_(src),
|
||||
size_(size),
|
||||
kind_(kind),
|
||||
stream_(stream) {}
|
||||
|
||||
virtual void run_task() override { ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_)); }
|
||||
private:
|
||||
void * dst_;
|
||||
const void * src_;
|
||||
size_t size_;
|
||||
aclrtMemcpyKind kind_;
|
||||
aclrtStream stream_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Task class for performing asynchronous memory set operations.
|
||||
*/
|
||||
class async_memset_task : public cann_task {
|
||||
public:
|
||||
async_memset_task(void * buffer, size_t size, int32_t value, aclrtStream stream) :
|
||||
buffer_(buffer),
|
||||
size_(size),
|
||||
value_(value),
|
||||
stream_(stream) {}
|
||||
|
||||
virtual void run_task() override { ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_)); }
|
||||
private:
|
||||
void * buffer_;
|
||||
size_t size_;
|
||||
int32_t value_;
|
||||
aclrtStream stream_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Launches an asynchronous task using the memory allocator.
|
||||
*
|
||||
@@ -893,20 +906,95 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_reso
|
||||
* same stream are executed in queue order.
|
||||
*/
|
||||
|
||||
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
|
||||
do { \
|
||||
uint64_t workspaceSize = 0; \
|
||||
aclOpExecutor * executor; \
|
||||
void * workspaceAddr = nullptr; \
|
||||
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
|
||||
/* workspace should alloced in main thread to keep malloc order when using vmm. */ \
|
||||
if (workspaceSize > 0) { \
|
||||
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
|
||||
workspaceAddr = workspace_allocator.get(); \
|
||||
} \
|
||||
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
|
||||
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
|
||||
do { \
|
||||
uint64_t workspaceSize = 0; \
|
||||
aclOpExecutor * executor; \
|
||||
void * workspaceAddr = nullptr; \
|
||||
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
|
||||
/* workspace should alloced in main thread to keep malloc order when using vmm. */ \
|
||||
if (workspaceSize > 0) { \
|
||||
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
|
||||
workspaceAddr = workspace_allocator.get(); \
|
||||
} \
|
||||
if (CTX.async_mode) { \
|
||||
auto task = \
|
||||
std::make_unique<aclnn_task>(aclnn##OP_NAME, workspaceAddr, workspaceSize, executor, CTX.stream()); \
|
||||
CTX.task_queue.submit_task(std::move(task)); \
|
||||
} else { \
|
||||
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
/**
|
||||
* @brief Registers and releases multiple ACL resources, optionally deferring the release
|
||||
* using a task.
|
||||
*
|
||||
* @tparam Args Types of the ACL resources.
|
||||
* @param ctx Backend context which manages task submission and async mode.
|
||||
* @param args Pointers to ACL resources to be released.
|
||||
*/
|
||||
template <typename... Args> void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
|
||||
std::vector<any_acl_resource> resources;
|
||||
register_acl_resources(resources, std::forward<Args>(args)...);
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<release_resource_task>(std::move(resources));
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs an asynchronous memory copy operation, optionally deferred via task submission.
|
||||
*
|
||||
* @param ctx Backend context containing stream and async configuration.
|
||||
* @param dst Destination memory address.
|
||||
* @param src Source memory address.
|
||||
* @param len Size of memory to copy (in bytes).
|
||||
* @param kind Type of memory copy (host-to-device, device-to-host, etc).
|
||||
*/
|
||||
inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx,
|
||||
void * dst,
|
||||
const void * src,
|
||||
size_t len,
|
||||
aclrtMemcpyKind kind) {
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx.stream());
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
} else {
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx.stream()));
|
||||
}
|
||||
}
|
||||
|
||||
inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx,
|
||||
void * dst,
|
||||
const void * src,
|
||||
size_t len,
|
||||
aclrtMemcpyKind kind) {
|
||||
if (ctx->async_mode) {
|
||||
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx->stream());
|
||||
ctx->task_queue.submit_task(std::move(task));
|
||||
} else {
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx->stream()));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs an asynchronous memory set operation, optionally deferred via task submission.
|
||||
*
|
||||
* @param ctx Backend context containing stream and async configuration.
|
||||
* @param buffer Memory buffer to be set.
|
||||
* @param size Size of the memory buffer (in bytes).
|
||||
* @param value Value to set in the buffer.
|
||||
*/
|
||||
inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer, size_t size, int value) {
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<async_memset_task>(buffer, size, value, ctx.stream());
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
} else {
|
||||
ACL_CHECK(aclrtMemsetAsync(buffer, size, value, size, ctx.stream()));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs sparse expert-based matrix multiplication using the CANN backend.
|
||||
*
|
||||
@@ -979,11 +1067,15 @@ template <auto binary_op> void ggml_cann_binary_op(ggml_backend_cann_context & c
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
acl_tensor_ptr acl_src0, acl_src1, acl_dst;
|
||||
aclTensor * acl_src0;
|
||||
aclTensor * acl_src1;
|
||||
aclTensor * acl_dst;
|
||||
|
||||
// Need bcast
|
||||
bcast_shape(src0, src1, dst, acl_src0, acl_src1, acl_dst);
|
||||
binary_op(ctx, acl_src0.get(), acl_src1.get(), acl_dst.get());
|
||||
bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst);
|
||||
binary_op(ctx, acl_src0, acl_src1, acl_dst);
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -993,7 +1085,7 @@ template <auto binary_op> void ggml_cann_binary_op(ggml_backend_cann_context & c
|
||||
* and stores the result in the destination tensor.
|
||||
*
|
||||
* @tparam unary_op A callable with the signature:
|
||||
* void(ggml_backend_cann_context&, aclTensor *, aclTensor *)
|
||||
* void(ggml_backend_cann_context&, aclTensor*, aclTensor*)
|
||||
* where the first aclTensor is the source and the second is the destination.
|
||||
* @param ctx The CANN backend context for managing resources and execution.
|
||||
* @param dst The destination tensor. Its src[0] is treated as the input tensor.
|
||||
@@ -1002,10 +1094,11 @@ template <void unary_op(ggml_backend_cann_context &, aclTensor *, aclTensor *)>
|
||||
void ggml_cann_op_unary(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src = dst->src[0];
|
||||
|
||||
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
|
||||
aclTensor * acl_src = ggml_cann_create_tensor(src);
|
||||
aclTensor * acl_dst = ggml_cann_create_tensor(dst);
|
||||
|
||||
unary_op(ctx, acl_src.get(), acl_dst.get());
|
||||
unary_op(ctx, acl_src, acl_dst);
|
||||
ggml_cann_release_resources(ctx, acl_src, acl_dst);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -23,26 +23,26 @@
|
||||
#ifndef CANN_COMMON_H
|
||||
#define CANN_COMMON_H
|
||||
|
||||
#include "../ggml-impl.h"
|
||||
#include "../include/ggml-cann.h"
|
||||
#include "../include/ggml.h"
|
||||
|
||||
#include <acl/acl.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <cstdio>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <unistd.h>
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <list>
|
||||
|
||||
#include "../include/ggml-cann.h"
|
||||
#include "../include/ggml.h"
|
||||
#include "../ggml-impl.h"
|
||||
|
||||
#define MATRIX_ROW_PADDING 512
|
||||
#define GGML_CANN_MAX_STREAMS 8
|
||||
@@ -214,6 +214,130 @@ struct ggml_cann_pool_alloc {
|
||||
ggml_cann_pool_alloc & operator=(ggml_cann_pool_alloc &&) = delete;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Function pointer type for ACLNN operator calls.
|
||||
*/
|
||||
using aclnn_func_t = aclnnStatus (*)(void *, uint64_t, aclOpExecutor *, aclrtStream);
|
||||
|
||||
/**
|
||||
* @brief Base class for all CANN tasks to be submitted to the task queue.
|
||||
*
|
||||
* Users should override the run_task() method with actual task logic.
|
||||
*/
|
||||
class cann_task {
|
||||
public:
|
||||
virtual void run_task() {}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief A lock-free ring-buffer based task queue for asynchronously executing cann_task instances.
|
||||
*/
|
||||
class cann_task_queue {
|
||||
public:
|
||||
/**
|
||||
* @brief Constructs a task queue with a fixed power-of-two capacity for a specific device.
|
||||
*
|
||||
* @param capacity Queue capacity. Must be a power of 2.
|
||||
* @param device Target device ID (used for context setting).
|
||||
*/
|
||||
explicit cann_task_queue(size_t capacity, int32_t device) :
|
||||
buffer_(capacity),
|
||||
capacity_(capacity),
|
||||
head_(0),
|
||||
tail_(0),
|
||||
running_(false),
|
||||
device_(device) {
|
||||
GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2");
|
||||
mask_ = capacity_ - 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Attempts to enqueue a task into the queue.
|
||||
*
|
||||
* @param item Unique pointer to the task.
|
||||
* @return true if the task was successfully enqueued, false if the queue was full.
|
||||
*/
|
||||
bool enqueue(std::unique_ptr<cann_task> && item) {
|
||||
size_t next_tail = (tail_ + 1) & mask_;
|
||||
|
||||
if (next_tail == head_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
buffer_[tail_] = std::move(item);
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
tail_ = next_tail;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Submits a task to the queue, and starts the worker thread if not already running.
|
||||
*
|
||||
* @param task Task to be submitted.
|
||||
*/
|
||||
void submit_task(std::unique_ptr<cann_task> && task) {
|
||||
while (!enqueue(std::move(task))) {
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!running_) {
|
||||
running_ = true;
|
||||
thread_ = std::thread(&cann_task_queue::execute, this);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Waits until the queue is completely empty and no tasks are being processed.
|
||||
*/
|
||||
void wait() {
|
||||
while (running_ && head_ != tail_) {
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Stops the task queue and joins the worker thread.
|
||||
*/
|
||||
void stop() {
|
||||
running_ = false;
|
||||
if (thread_.joinable()) {
|
||||
thread_.join();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Worker thread function that continuously dequeues and executes tasks.
|
||||
*/
|
||||
void execute() {
|
||||
ggml_cann_set_device(device_);
|
||||
|
||||
while (running_) {
|
||||
if (head_ == tail_) {
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
|
||||
std::atomic_thread_fence(std::memory_order_acquire);
|
||||
buffer_[head_]->run_task();
|
||||
buffer_[head_].reset();
|
||||
head_ = (head_ + 1) & mask_;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<cann_task>> buffer_;
|
||||
const size_t capacity_;
|
||||
size_t mask_;
|
||||
size_t head_;
|
||||
size_t tail_;
|
||||
bool running_;
|
||||
std::thread thread_;
|
||||
int32_t device_;
|
||||
};
|
||||
|
||||
#ifdef USE_ACL_GRAPH
|
||||
struct ggml_graph_node_properties {
|
||||
// dst tensor
|
||||
@@ -350,6 +474,7 @@ struct ggml_backend_cann_context {
|
||||
ggml_cann_graph_lru_cache graph_lru_cache;
|
||||
bool acl_graph_mode = true;
|
||||
#endif
|
||||
cann_task_queue task_queue;
|
||||
bool async_mode;
|
||||
// Rope Cache
|
||||
ggml_cann_rope_cache rope_cache;
|
||||
@@ -363,10 +488,15 @@ struct ggml_backend_cann_context {
|
||||
* @brief Constructor for initializing the context with a given device.
|
||||
* @param device Device ID.
|
||||
*/
|
||||
explicit ggml_backend_cann_context(int device) : device(device), name("CANN" + std::to_string(device)) {
|
||||
explicit ggml_backend_cann_context(int device) :
|
||||
device(device),
|
||||
name("CANN" + std::to_string(device)),
|
||||
task_queue(1024, device) {
|
||||
ggml_cann_set_device(device);
|
||||
description = aclrtGetSocName();
|
||||
|
||||
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
|
||||
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, device, async_mode ? "ON" : "OFF");
|
||||
#ifdef USE_ACL_GRAPH
|
||||
acl_graph_mode = parse_bool(get_env("GGML_CANN_ACL_GRAPH").value_or("on"));
|
||||
GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER",
|
||||
@@ -379,6 +509,7 @@ struct ggml_backend_cann_context {
|
||||
*/
|
||||
~ggml_backend_cann_context() {
|
||||
ggml_cann_set_device(device);
|
||||
task_queue.stop();
|
||||
if (copy_event != nullptr) {
|
||||
ACL_CHECK(aclrtDestroyEvent(copy_event));
|
||||
}
|
||||
|
||||
@@ -22,24 +22,24 @@
|
||||
|
||||
#include "ggml-cann.h"
|
||||
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-cann/aclnn_ops.h"
|
||||
#include "ggml-cann/common.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <acl/acl.h>
|
||||
#include <aclnnop/aclnn_trans_matmul_weight.h>
|
||||
#include <stdarg.h>
|
||||
#include <aclnnop/aclnn_trans_matmul_weight.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <queue>
|
||||
#include <chrono>
|
||||
#include <unordered_set>
|
||||
#include <optional>
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-cann/aclnn_ops.h"
|
||||
#include "ggml-cann/common.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
|
||||
@@ -1177,18 +1177,19 @@ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
|
||||
* across calls. This reduces overhead from repeated memory allocation and deallocation.
|
||||
*/
|
||||
static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) {
|
||||
acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
|
||||
uint64_t workspaceSize = 0;
|
||||
aclTensor * weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor * executor;
|
||||
|
||||
// TransMatmulWeight
|
||||
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed.get(), &workspaceSize, &executor));
|
||||
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed, &workspaceSize, &executor));
|
||||
// Avoid frequent malloc/free of the workspace.
|
||||
g_nz_workspaces[device].realloc(workspaceSize);
|
||||
|
||||
void * g_nz_workspace = g_nz_workspaces[device].get();
|
||||
|
||||
ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
|
||||
ACL_CHECK(aclDestroyTensor(weightTransposed));
|
||||
}
|
||||
|
||||
// TODO: need handle tensor which has paddings.
|
||||
@@ -1640,7 +1641,7 @@ ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
|
||||
/* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
|
||||
},
|
||||
/* .device = */
|
||||
ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
|
||||
ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
|
||||
/* .context = */ nullptr,
|
||||
};
|
||||
|
||||
@@ -1776,12 +1777,6 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
|
||||
case GGML_OP_GROUP_NORM:
|
||||
ggml_cann_group_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_L2_NORM:
|
||||
ggml_cann_l2_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
ggml_cann_cross_entropy_loss(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONCAT:
|
||||
ggml_cann_concat(ctx, dst);
|
||||
break;
|
||||
@@ -1948,8 +1943,7 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
|
||||
GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
|
||||
GGML_ASSERT(!ggml_is_quantized(tensor->type));
|
||||
|
||||
ACL_CHECK(aclrtMemcpyAsync((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE,
|
||||
cann_ctx->stream()));
|
||||
ggml_cann_async_memcpy(cann_ctx, (char *) tensor->data + offset, data, size, ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1974,8 +1968,7 @@ static void ggml_backend_cann_get_tensor_async(ggml_backend_t backend,
|
||||
GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
|
||||
GGML_ASSERT(!ggml_is_quantized(tensor->type));
|
||||
|
||||
ACL_CHECK(aclrtMemcpyAsync(data, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST,
|
||||
cann_ctx->stream()));
|
||||
ggml_cann_async_memcpy(cann_ctx, data, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -2036,6 +2029,7 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
|
||||
ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
|
||||
|
||||
// wait for task_queue empty to keep task order.
|
||||
cann_ctx_src->task_queue.wait();
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,
|
||||
cann_ctx_src->stream()));
|
||||
// record event on src stream after the copy
|
||||
@@ -2068,6 +2062,7 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
|
||||
*/
|
||||
static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
|
||||
ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
|
||||
cann_ctx->task_queue.wait();
|
||||
ggml_cann_set_device(cann_ctx->device);
|
||||
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
|
||||
}
|
||||
@@ -2246,7 +2241,8 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
|
||||
bool & use_cann_graph,
|
||||
bool & cann_graph_update_required) {
|
||||
#ifdef USE_ACL_GRAPH
|
||||
if (use_cann_graph && cann_graph_update_required) { // Begin CANN graph capture
|
||||
ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
|
||||
if (use_cann_graph && cann_graph_update_required) {
|
||||
ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
|
||||
}
|
||||
#endif // USE_ACL_GRAPH
|
||||
@@ -2270,14 +2266,12 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
|
||||
}
|
||||
|
||||
#ifdef USE_ACL_GRAPH
|
||||
if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture
|
||||
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
|
||||
}
|
||||
|
||||
if (use_cann_graph) {
|
||||
ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
|
||||
|
||||
if (cann_graph_update_required) { // End CANN graph capture
|
||||
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
|
||||
}
|
||||
|
||||
// Execute CANN graph
|
||||
// Execute graph
|
||||
ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream()));
|
||||
}
|
||||
#endif // USE_ACL_GRAPH
|
||||
@@ -2303,9 +2297,9 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
|
||||
// calculate rope cache for fist layer in current device.
|
||||
cann_ctx->rope_cache.cached = false;
|
||||
|
||||
bool cann_graph_update_required = false;
|
||||
#ifdef USE_ACL_GRAPH
|
||||
bool use_cann_graph = true;
|
||||
bool cann_graph_update_required = false;
|
||||
|
||||
static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
|
||||
if (!prefill_use_graph) {
|
||||
@@ -2336,6 +2330,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
|
||||
}
|
||||
#else
|
||||
bool use_cann_graph = false;
|
||||
bool cann_graph_update_required = false;
|
||||
#endif // USE_ACL_GRAPH
|
||||
evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, cann_graph_update_required);
|
||||
|
||||
@@ -2484,9 +2479,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
if (mode & GGML_ROPE_TYPE_VISION) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] > 896) {
|
||||
return false;
|
||||
}
|
||||
#ifdef ASCEND_310P
|
||||
if (!ggml_is_contiguous(op->src[0])) {
|
||||
return false;
|
||||
@@ -2523,11 +2515,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
// value of paddingW should be at most half of kernelW
|
||||
return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
|
||||
}
|
||||
case GGML_OP_SUM:
|
||||
return ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_REPEAT:
|
||||
|
||||
@@ -126,48 +126,36 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
)
|
||||
if (NOT ARM_MCPU_RESULT)
|
||||
string(REGEX MATCH "-mcpu=[^ ']+" ARM_MCPU_FLAG "${ARM_MCPU}")
|
||||
string(REGEX MATCH "-march=[^ ']+" ARM_MARCH_FLAG "${ARM_MCPU}")
|
||||
|
||||
# on some old GCC we need to read -march=
|
||||
if (ARM_MARCH_FLAG AND NOT "${ARM_MARCH_FLAG}" STREQUAL "-march=native")
|
||||
set(ARM_NATIVE_FLAG "${ARM_MARCH_FLAG}")
|
||||
elseif(ARM_MCPU_FLAG AND NOT "${ARM_MCPU_FLAG}" STREQUAL "-mcpu=native")
|
||||
set(ARM_NATIVE_FLAG "${ARM_MCPU_FLAG}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if ("${ARM_NATIVE_FLAG}" STREQUAL "")
|
||||
set(ARM_NATIVE_FLAG -mcpu=native)
|
||||
message(WARNING "ARM -march/-mcpu not found, -mcpu=native will be used")
|
||||
else()
|
||||
message(STATUS "ARM detected flags: ${ARM_NATIVE_FLAG}")
|
||||
if ("${ARM_MCPU_FLAG}" STREQUAL "")
|
||||
set(ARM_MCPU_FLAG -mcpu=native)
|
||||
message(STATUS "ARM -mcpu not found, -mcpu=native will be used")
|
||||
endif()
|
||||
|
||||
include(CheckCXXSourceRuns)
|
||||
|
||||
macro(check_arm_feature tag feature code)
|
||||
function(check_arm_feature tag code)
|
||||
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
|
||||
set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+${tag}")
|
||||
set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}")
|
||||
check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag})
|
||||
if (GGML_MACHINE_SUPPORTS_${tag})
|
||||
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+${tag}")
|
||||
set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE)
|
||||
else()
|
||||
set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+no${tag}")
|
||||
set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+no${tag}")
|
||||
check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag})
|
||||
if (GGML_MACHINE_SUPPORTS_no${tag})
|
||||
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+no${tag}")
|
||||
list(APPEND ARCH_FLAGS -U__ARM_FEATURE_${feature})
|
||||
set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
|
||||
endif()
|
||||
endif()
|
||||
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
|
||||
endmacro()
|
||||
endfunction()
|
||||
|
||||
check_arm_feature(dotprod DOTPROD "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
|
||||
check_arm_feature(i8mm MATMUL_INT8 "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
|
||||
check_arm_feature(sve SVE "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
|
||||
check_arm_feature(sme SME "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
|
||||
check_arm_feature(dotprod "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
|
||||
check_arm_feature(i8mm "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
|
||||
check_arm_feature(sve "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
|
||||
check_arm_feature(sme "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
|
||||
|
||||
list(APPEND ARCH_FLAGS "${ARM_NATIVE_FLAG}${ARM_NATIVE_FLAG_FIX}")
|
||||
list(APPEND ARCH_FLAGS "${ARM_MCPU_FLAG}${ARM_MCPU_FLAG_FIX}")
|
||||
else()
|
||||
if (GGML_CPU_ARM_ARCH)
|
||||
list(APPEND ARCH_FLAGS -march=${GGML_CPU_ARM_ARCH})
|
||||
@@ -217,27 +205,35 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
message(STATUS "Checking for ARM features using flags:")
|
||||
foreach(flag IN LISTS ARCH_FLAGS)
|
||||
message(STATUS " ${flag}")
|
||||
endforeach()
|
||||
# show enabled features
|
||||
if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
|
||||
set(FEAT_INPUT_FILE "NUL")
|
||||
else()
|
||||
set(FEAT_INPUT_FILE "/dev/null")
|
||||
endif()
|
||||
|
||||
include(CheckCXXSourceCompiles)
|
||||
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
|
||||
set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS}")
|
||||
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
|
||||
set(ARM_FEATURE "HAVE_${feature}")
|
||||
check_cxx_source_compiles(
|
||||
"
|
||||
#if !defined(__ARM_FEATURE_${feature})
|
||||
# error \"Feature ${feature} is not defined\"
|
||||
#endif
|
||||
int main() { return 0; }
|
||||
"
|
||||
${ARM_FEATURE}
|
||||
)
|
||||
endforeach()
|
||||
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_C_COMPILER} ${ARCH_FLAGS} -dM -E -
|
||||
INPUT_FILE ${FEAT_INPUT_FILE}
|
||||
OUTPUT_VARIABLE ARM_FEATURE
|
||||
RESULT_VARIABLE ARM_FEATURE_RESULT
|
||||
)
|
||||
if (ARM_FEATURE_RESULT)
|
||||
message(WARNING "Failed to get ARM features")
|
||||
else()
|
||||
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
|
||||
string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
|
||||
if (NOT ${feature_pos} EQUAL -1)
|
||||
# Special handling for MATMUL_INT8 when machine doesn't support i8mm
|
||||
if ("${feature}" STREQUAL "MATMUL_INT8" AND GGML_MACHINE_SUPPORTS_noi8mm)
|
||||
message(STATUS "ARM feature ${feature} detected but unsetting due to machine not supporting i8mm")
|
||||
list(APPEND ARCH_FLAGS -U__ARM_FEATURE_MATMUL_INT8)
|
||||
else()
|
||||
message(STATUS "ARM feature ${feature} enabled")
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
endif()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "x86")
|
||||
message(STATUS "x86 detected")
|
||||
@@ -392,9 +388,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
|
||||
|
||||
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power10)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64)
|
||||
elseif (EXTRACTED_NUMBER EQUAL 9)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power9)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64)
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
|
||||
list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
|
||||
else()
|
||||
@@ -452,35 +448,22 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
ggml-cpu/spacemit/ime_kernels.h
|
||||
)
|
||||
endif()
|
||||
if(NOT GGML_CPU_ALL_VARIANTS)
|
||||
set(MARCH_STR "rv64gc")
|
||||
if (GGML_RV_ZFH)
|
||||
string(APPEND MARCH_STR "_zfh")
|
||||
endif()
|
||||
if (GGML_XTHEADVECTOR)
|
||||
string(APPEND MARCH_STR "_xtheadvector")
|
||||
elseif (GGML_RVV)
|
||||
string(APPEND MARCH_STR "_v")
|
||||
if (GGML_RV_ZVFH)
|
||||
string(APPEND MARCH_STR "_zvfh")
|
||||
endif()
|
||||
endif()
|
||||
if (GGML_RV_ZICBOP)
|
||||
string(APPEND MARCH_STR "_zicbop")
|
||||
endif()
|
||||
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
|
||||
else()
|
||||
# Begin with the lowest baseline
|
||||
set(ARCH_DEFINITIONS "")
|
||||
|
||||
if (GGML_INTERNAL_RVV)
|
||||
message(STATUS "RVV enabled")
|
||||
list(APPEND ARCH_DEFINITIONS GGML_USE_RVV)
|
||||
list(APPEND ARCH_FLAGS -march=rv64gc_v -mabi=lp64d)
|
||||
endif()
|
||||
|
||||
ggml_add_cpu_backend_features(${GGML_CPU_NAME} riscv ${ARCH_DEFINITIONS})
|
||||
set(MARCH_STR "rv64gc")
|
||||
if (GGML_RV_ZFH)
|
||||
string(APPEND MARCH_STR "_zfh")
|
||||
endif()
|
||||
if (GGML_XTHEADVECTOR)
|
||||
string(APPEND MARCH_STR "_xtheadvector")
|
||||
elseif (GGML_RVV)
|
||||
string(APPEND MARCH_STR "_v")
|
||||
if (GGML_RV_ZVFH)
|
||||
string(APPEND MARCH_STR "_zvfh")
|
||||
endif()
|
||||
endif()
|
||||
if (GGML_RV_ZICBOP)
|
||||
string(APPEND MARCH_STR "_zicbop")
|
||||
endif()
|
||||
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
|
||||
message(STATUS "s390x detected")
|
||||
list(APPEND GGML_CPU_SOURCES
|
||||
@@ -596,7 +579,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
${KLEIDIAI_SRC}/kai/ukernels/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
|
||||
|
||||
@@ -615,34 +597,23 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c)
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
|
||||
|
||||
if (NOT DOTPROD_ENABLED MATCHES -1)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c)
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
|
||||
endif()
|
||||
|
||||
if (NOT I8MM_ENABLED MATCHES -1)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c)
|
||||
endif()
|
||||
|
||||
if (NOT SME_ENABLED MATCHES -1)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
|
||||
|
||||
@@ -51,8 +51,10 @@
|
||||
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
|
||||
|
||||
@@ -2044,26 +2044,6 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
|
||||
}
|
||||
|
||||
#ifdef __ARM_FEATURE_SVE
|
||||
static inline svuint32_t ggml_decode_q4scales_and_mins_for_mmla(const uint32_t * vx_scales) {
|
||||
const svbool_t pg_all = svptrue_pat_b32(SV_VL4);
|
||||
const svbool_t pg_false = svpfalse_b(); // 0x0000
|
||||
const svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8); // 0x00ff
|
||||
const svbool_t pg_odd = svzip1_b32(pg_false, pg_lo_8);
|
||||
|
||||
svuint32_t vutmp_hi, vutmp_lo;
|
||||
svuint32_t vx01 = svld1_u32(pg_lo_8, vx_scales);
|
||||
vutmp_hi = svzip1_u32(vx01, vx01);
|
||||
vutmp_hi = svlsr_n_u32_m(pg_odd, vutmp_hi, 2);
|
||||
vutmp_hi = svreinterpret_u32_u64(svand_n_u64_x(pg_all, svreinterpret_u64_u32(vutmp_hi), UINT64_C(0x303030303f3f3f3f)));
|
||||
const svuint32_t vx2 = svdup_u32(vx_scales[2]);
|
||||
vutmp_lo = svlsr_u32_x(pg_all, vx2, svreinterpret_u32_s32(svindex_s32(-2, 2)));
|
||||
vutmp_lo = svand_n_u32_z(pg_odd, vutmp_lo, UINT32_C(0x0f0f0f0f));
|
||||
svuint32_t vutmp = svorr_u32_z(pg_all, vutmp_hi, vutmp_lo);
|
||||
return vutmp;
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
#ifdef __ARM_FEATURE_MATMUL_INT8
|
||||
@@ -2086,220 +2066,8 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
static const uint32_t kmask3 = 0x03030303;
|
||||
|
||||
uint32_t utmp[4];
|
||||
#ifdef __ARM_FEATURE_SVE
|
||||
const int vector_length = ggml_cpu_get_sve_cnt()*8;
|
||||
#endif
|
||||
|
||||
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
if (nrc == 2) {
|
||||
svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
|
||||
|
||||
const block_q4_K * GGML_RESTRICT vx0 = vx;
|
||||
const block_q8_K * GGML_RESTRICT vy0 = vy;
|
||||
const block_q4_K * GGML_RESTRICT vx1 = (const block_q4_K *) ((const uint8_t*)vx + bx);
|
||||
const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
|
||||
|
||||
union {
|
||||
uint32_t u32[8];
|
||||
uint64_t u64[4];
|
||||
} new_utmp;
|
||||
|
||||
svfloat32_t sumf1 = svdup_n_f32(0);
|
||||
|
||||
switch (vector_length) {
|
||||
case 128:
|
||||
{
|
||||
svbool_t pg_false = svpfalse_b();
|
||||
svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8);
|
||||
svbool_t vmins_mask1= svzip1_b32(pg_lo_8, pg_false);
|
||||
svbool_t vmins_mask2 = svzip1_b32(pg_false, pg_lo_8);
|
||||
svbool_t pg128_all = svptrue_pat_b8(SV_VL16);
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
|
||||
svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
|
||||
svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
|
||||
svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));
|
||||
svfloat32_t vy_dmins = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
|
||||
svfloat32_t svdmins = svmul_n_f32_x(pg128_all, svmul_f32_x(pg128_all, vy_dmins, vx_dmins), -1);
|
||||
const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;
|
||||
const int8_t * GGML_RESTRICT q8_0 = vy0[i].qs;
|
||||
const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;
|
||||
const int8_t * GGML_RESTRICT q8_1 = vy1[i].qs;
|
||||
svint16_t lo = svld1_s16(pg128_all, vy0[i].bsums + 0);
|
||||
svint16_t hi = svld1_s16(pg128_all, vy0[i].bsums + 8);
|
||||
svint16_t sum_tmp1 = svuzp1_s16(lo, hi);
|
||||
svint16_t sum_tmp2 = svuzp2_s16(lo, hi);
|
||||
svint16_t svq8sums_0 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
|
||||
lo = svld1_s16(pg128_all, vy1[i].bsums + 0);
|
||||
hi = svld1_s16(pg128_all, vy1[i].bsums + 8);
|
||||
sum_tmp1 = svuzp1(lo, hi);
|
||||
sum_tmp2 = svuzp2(lo, hi);
|
||||
svint16_t svq8sums_1 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
|
||||
svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
|
||||
svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
|
||||
svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
|
||||
svst2_u32(pg128_all, new_utmp.u32, decoded_scales);
|
||||
svint16_t svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp1_u32(svld1_u32(vmins_mask1, new_utmp.u32+4), svdup_n_u32(0)))));
|
||||
svint16_t svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp2_u32(svld1_u32(vmins_mask2, new_utmp.u32+4), svdup_n_u32(0)))));
|
||||
svint32_t svsumfs_tmp1 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_0));
|
||||
svint32_t svsumfs_tmp2 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_1));
|
||||
svint32_t svsumfs_tmp3 = svtrn1_s32(svsumfs_tmp1, svsumfs_tmp2);
|
||||
svint32_t svsumfs_tmp4 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_0));
|
||||
svint32_t svsumfs_tmp5 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_1));
|
||||
svint32_t svsumfs_tmp6 = svtrn1_s32(svsumfs_tmp4, svsumfs_tmp5);
|
||||
svint32_t svsumfs_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
|
||||
svint32_t svsumfs_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
|
||||
svint32_t svsumfs_tmp = svadd_s32_x(pg128_all, svsumfs_tmp7, svsumfs_tmp8);
|
||||
svint32_t svscales, sumi1, sumi2;
|
||||
svint32_t acc_sumif1 = svdup_n_s32(0);
|
||||
svint32_t acc_sumif2 = svdup_n_s32(0);
|
||||
svint8_t q4bytes_0_l, q4bytes_0_h, q4bytes_1_l, q4bytes_1_h, l0, l1, l2, l3,
|
||||
q8bytes_0_h, q8bytes_0_l, q8bytes_1_h, q8bytes_1_l, r0, r1, r2, r3;
|
||||
#pragma GCC unroll 1
|
||||
for (int j = 0; j < QK_K/64; ++j) {
|
||||
q4bytes_0_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 0xf));
|
||||
q4bytes_1_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 0xf));
|
||||
q4bytes_0_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 0xf));
|
||||
q4bytes_1_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 0xf));
|
||||
l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
|
||||
l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
|
||||
l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
|
||||
l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
|
||||
q8bytes_0_h = svld1_s8(pg128_all, q8_0);
|
||||
q8bytes_1_h = svld1_s8(pg128_all, q8_1);
|
||||
q8bytes_0_l = svld1_s8(pg128_all, q8_0+16);
|
||||
q8bytes_1_l = svld1_s8(pg128_all, q8_1+16);
|
||||
r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
|
||||
r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
|
||||
r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
|
||||
r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
|
||||
sumi1 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
|
||||
svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
|
||||
acc_sumif1 = svmla_s32_x(pg128_all, acc_sumif1, svscales, sumi1);
|
||||
|
||||
q4bytes_0_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 4));
|
||||
q4bytes_1_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 4));
|
||||
q4bytes_0_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 4));
|
||||
q4bytes_1_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 4));
|
||||
l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
|
||||
l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
|
||||
l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
|
||||
l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
|
||||
q8bytes_0_h = svld1_s8(pg128_all, q8_0+32);
|
||||
q8bytes_1_h = svld1_s8(pg128_all, q8_1+32);
|
||||
q8bytes_0_l = svld1_s8(pg128_all, q8_0+48);
|
||||
q8bytes_1_l = svld1_s8(pg128_all, q8_1+48);
|
||||
r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
|
||||
r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
|
||||
r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
|
||||
r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
|
||||
sumi2 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
|
||||
svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
|
||||
acc_sumif2 = svmla_s32_x(pg128_all, acc_sumif2, svscales, sumi2);
|
||||
q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
|
||||
}
|
||||
sumf1 = svmla_f32_x(pg128_all,
|
||||
svmla_f32_x(pg128_all,
|
||||
sumf1,
|
||||
svcvt_f32_x(pg128_all,
|
||||
svadd_s32_x(pg128_all, acc_sumif1, acc_sumif2)),
|
||||
svsuper_block_scales),
|
||||
svdmins,
|
||||
svcvt_f32_s32_x(pg128_all, svsumfs_tmp));
|
||||
} //end of for nb
|
||||
} // end of case 128
|
||||
break;
|
||||
case 256:
|
||||
case 512:
|
||||
{
|
||||
const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
|
||||
const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
|
||||
const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;
|
||||
const int8_t * GGML_RESTRICT q8_0 = vy0[i].qs;
|
||||
const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;
|
||||
const int8_t * GGML_RESTRICT q8_1 = vy1[i].qs;
|
||||
svint32_t svscales, sumi1, sumi2;
|
||||
svint32_t acc_sumif1 = svdup_n_s32(0);
|
||||
svint32_t acc_sumif2 = svdup_n_s32(0);
|
||||
svint8_t l0, l1, l2, l3, r0, r1, r2, r3;
|
||||
svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
|
||||
svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
|
||||
svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
|
||||
svfloat32_t svsuper_block_scales = svmul_f32_z(pg32_4, vy_d, vx_d);
|
||||
svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));
|
||||
svfloat64_t vy_dmins_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
|
||||
svfloat32_t vy_dmins = svreinterpret_f32_f64(svuzp1_f64(vy_dmins_tmp, vy_dmins_tmp));
|
||||
svfloat32_t svdmins = svmul_n_f32_x(pg32_4, svmul_f32_x(pg32_4, vx_dmins, vy_dmins), -1);
|
||||
svint16_t rc1 = svuzp1_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
|
||||
svint16_t rc2 = svuzp2_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
|
||||
svint16_t svq8sums = svadd_s16_x(pg256_all, rc1, rc2);
|
||||
svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
|
||||
svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
|
||||
svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
|
||||
svst2_u32(pg8_16, new_utmp.u32, decoded_scales);
|
||||
svint16_t new_svq8sums_0 = svreinterpret_s16_u64(svtrn1_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
|
||||
svint16_t new_svq8sums_1 = svreinterpret_s16_u64(svtrn2_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
|
||||
svuint64_t new_mins_0 = svdup_u64(new_utmp.u64[2]);
|
||||
svuint64_t new_mins_1 = svdup_u64(new_utmp.u64[3]);
|
||||
svint16_t new_svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_0)));
|
||||
svint16_t new_svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_1)));
|
||||
svint64_t dot_prod_0 = svdot_s64(svdup_s64(0), new_svmins8_0, new_svq8sums_0);
|
||||
svint64_t dot_prod_1 = svdot_s64(dot_prod_0, new_svmins8_1, new_svq8sums_1);
|
||||
svfloat32_t converted_dot_prod_1 = svcvt_f32_s64_x(pg256_all, dot_prod_1);
|
||||
svfloat32_t svsumfs_tmp = svuzp1_f32(converted_dot_prod_1, converted_dot_prod_1);
|
||||
|
||||
#pragma GCC unroll 1
|
||||
for (int j = 0; j < QK_K/64; ++j) {
|
||||
svuint8_t q4bytes_0 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 0xf);
|
||||
svuint8_t q4bytes_1 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 0xf);
|
||||
svuint8_t q4bytes_2 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 4);
|
||||
svuint8_t q4bytes_3 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 4);
|
||||
l0 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
|
||||
l1 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
|
||||
l2 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
|
||||
l3 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
|
||||
svint8_t q8bytes_0 = svld1_s8(pg256_all, q8_0);
|
||||
svint8_t q8bytes_1 = svld1_s8(pg256_all, q8_1);
|
||||
svint8_t q8bytes_2 = svld1_s8(pg256_all, q8_0+32);
|
||||
svint8_t q8bytes_3 = svld1_s8(pg256_all, q8_1+32);
|
||||
r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
|
||||
r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
|
||||
r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
|
||||
r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
|
||||
sumi1 = svmmla(svmmla(svdup_n_s32(0), r0, l0), r1, l1);
|
||||
svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
|
||||
acc_sumif1 = svmla_s32_x(pg256_all, acc_sumif1, svscales, sumi1);
|
||||
sumi2 = svmmla(svmmla(svdup_n_s32(0), r2, l2), r3, l3);
|
||||
svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
|
||||
acc_sumif2 = svmla_s32_x(pg256_all, acc_sumif2, svscales, sumi2);
|
||||
q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
|
||||
}
|
||||
svint32_t acc_sumif = svadd_s32_x(pg256_all, acc_sumif1, acc_sumif2);
|
||||
svint32_t swap_acc_sumif = svext_s32(acc_sumif, acc_sumif, 4);
|
||||
acc_sumif = svadd_s32_x(pg32_4, acc_sumif, swap_acc_sumif);
|
||||
sumf1 = svmla_f32_x(pg32_4,
|
||||
svmla_f32_x(pg32_4,
|
||||
sumf1,
|
||||
svcvt_f32_x(pg32_4, acc_sumif),
|
||||
svsuper_block_scales),
|
||||
svdmins,
|
||||
svsumfs_tmp);
|
||||
} // end of for nb
|
||||
} // end of case 256-512
|
||||
break;
|
||||
default:
|
||||
assert(false && "Unsupported vector length");
|
||||
break;
|
||||
}
|
||||
|
||||
svst1_f32(pg32_2, s, sumf1);
|
||||
svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sumf1), svdup_n_u8(0), 8)));
|
||||
|
||||
return;
|
||||
}
|
||||
#elif defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
if (nrc == 2) {
|
||||
const block_q4_K * GGML_RESTRICT x0 = x;
|
||||
const block_q4_K * GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx);
|
||||
@@ -2467,6 +2235,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
const uint8_t * GGML_RESTRICT q4 = x[i].qs;
|
||||
const int8_t * GGML_RESTRICT q8 = y[i].qs;
|
||||
|
||||
const int vector_length = ggml_cpu_get_sve_cnt()*8;
|
||||
const svuint8_t m4b = svdup_n_u8(0xf);
|
||||
const svint32_t mzero = svdup_n_s32(0);
|
||||
svint32_t sumi1 = svdup_n_s32(0);
|
||||
@@ -2711,201 +2480,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
#ifdef __ARM_FEATURE_SVE
|
||||
const int vector_length = ggml_cpu_get_sve_cnt()*8;
|
||||
#endif
|
||||
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
if (nrc == 2) {
|
||||
const svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
|
||||
|
||||
svfloat32_t sum = svdup_n_f32(0);
|
||||
|
||||
const block_q6_K * GGML_RESTRICT vx0 = vx;
|
||||
const block_q8_K * GGML_RESTRICT vy0 = vy;
|
||||
const block_q6_K * GGML_RESTRICT vx1 = (const block_q6_K *) ((const uint8_t*)vx + bx);
|
||||
const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
|
||||
|
||||
switch (vector_length) {
|
||||
case 128:
|
||||
{
|
||||
const svbool_t pg128_all = svptrue_pat_b8(SV_ALL);
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;
|
||||
const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;
|
||||
const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;
|
||||
const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;
|
||||
const int8_t * GGML_RESTRICT q80 = vy0[i].qs;
|
||||
const int8_t * GGML_RESTRICT q81 = vy1[i].qs;
|
||||
|
||||
const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;
|
||||
const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;
|
||||
|
||||
svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
|
||||
svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
|
||||
svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
|
||||
// process q8sum summation 128 bit route
|
||||
const svint16_t q8sums_01 = svld1_s16(pg128_all, vy0[i].bsums);
|
||||
const svint16_t q8sums_02 = svld1_s16(pg128_all, vy0[i].bsums + 8);
|
||||
const svint16_t q8sums_11 = svld1_s16(pg128_all, vy1[i].bsums);
|
||||
const svint16_t q8sums_12 = svld1_s16(pg128_all, vy1[i].bsums + 8);
|
||||
const svint64x2_t q6scales_0_tmp = svld2_s64(pg128_all, (const int64_t *)scale0);
|
||||
const svint16_t q6scales_01 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 0)));
|
||||
const svint16_t q6scales_02 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 1)));
|
||||
const svint64x2_t q6scales_1_tmp = svld2_s64(pg128_all, (const int64_t *)scale1);
|
||||
const svint16_t q6scales_11 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 0)));
|
||||
const svint16_t q6scales_12 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 1)));
|
||||
const svint64_t prod = svdup_n_s64(0);
|
||||
|
||||
svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_01), q8sums_02, q6scales_02));
|
||||
svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_11), q8sums_02, q6scales_12));
|
||||
svint32_t isum_tmp3 = svtrn1_s32(isum_tmp1, isum_tmp2);
|
||||
svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_01), q8sums_12, q6scales_02));
|
||||
svint32_t isum_tmp5 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_11), q8sums_12, q6scales_12));
|
||||
svint32_t isum_tmp6 = svtrn1_s32(isum_tmp4, isum_tmp5);
|
||||
svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
|
||||
svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
|
||||
svint32_t svisum_mins = svadd_s32_x(pg128_all, isum_tmp7, isum_tmp8);
|
||||
|
||||
// process mmla
|
||||
svint8_t l0, l1, r0, r1;
|
||||
svint32_t isum_tmp = svdup_n_s32(0);
|
||||
for (int j = 0; j < QK_K/128; ++j) {
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
svuint8_t qhbits_0 = svld1_u8(pg128_all, qh0+16*(k%2));
|
||||
svuint8_t qhbits_1 = svld1_u8(pg128_all, qh1+16*(k%2));
|
||||
svuint8_t q6bits_0 = svld1_u8(pg128_all, ql0+16*(k%4));
|
||||
svuint8_t q6bits_1 = svld1_u8(pg128_all, ql1+16*(k%4));
|
||||
const int ql_pos = (k/4)*4;
|
||||
svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_0, 4);
|
||||
svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_1, 4);
|
||||
const int qh_pos = (k/2)*2;
|
||||
svuint8_t q6bytes_0_hi = svand_n_u8_x(pg128_all, qhbits_0, 0x3 << qh_pos);
|
||||
svuint8_t q6bytes_1_hi = svand_n_u8_x(pg128_all, qhbits_1, 0x3 << qh_pos);
|
||||
svint8_t q6bytes_0, q6bytes_1;
|
||||
if (qh_pos <= 4) {
|
||||
q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
|
||||
q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
|
||||
} else {
|
||||
q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_0_lo, svlsr_n_u8_x(pg128_all, q6bytes_0_hi, (qh_pos - 4))));
|
||||
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_1_lo, svlsr_n_u8_x(pg128_all, q6bytes_1_hi, (qh_pos - 4))));
|
||||
}
|
||||
svint8_t q8bytes_0 = svld1_s8(pg128_all, q80+16*(k%8));
|
||||
svint8_t q8bytes_1 = svld1_s8(pg128_all, q81+16*(k%8));
|
||||
l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
|
||||
l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
|
||||
r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
|
||||
r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
|
||||
svint32_t svscale = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
|
||||
isum_tmp = svmla_s32_x(pg128_all, isum_tmp, svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), svscale);
|
||||
}
|
||||
qh0 += 32; qh1 += 32;
|
||||
ql0 += 64; ql1 += 64;
|
||||
q80 += 128; q81 += 128;
|
||||
scale0 += 8; scale1 += 8;
|
||||
}
|
||||
sum = svmla_f32_x(pg128_all, sum,
|
||||
svcvt_f32_x(pg128_all, svmla_s32_x(pg128_all, isum_tmp,
|
||||
svisum_mins, svdup_n_s32(-32))),
|
||||
svsuper_block_scales);
|
||||
}
|
||||
} // end of case 128
|
||||
break;
|
||||
case 256:
|
||||
case 512:
|
||||
{
|
||||
const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
|
||||
const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;
|
||||
const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;
|
||||
const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;
|
||||
const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;
|
||||
const int8_t * GGML_RESTRICT q80 = vy0[i].qs;
|
||||
const int8_t * GGML_RESTRICT q81 = vy1[i].qs;
|
||||
|
||||
const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;
|
||||
const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;
|
||||
svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
|
||||
svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
|
||||
svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
|
||||
svfloat32_t svsuper_block_scales = svmul_f32_x(pg32_4, vy_d, vx_d);
|
||||
// process q8sum summation 256 bit route
|
||||
const svint16_t q8sums_0 = svld1_s16(pg256_all, vy0[i].bsums);
|
||||
const svint16_t q8sums_1 = svld1_s16(pg256_all, vy1[i].bsums);
|
||||
const svint16_t q6scales_0 = svunpklo_s16(svld1_s8(pg256_all, scale0));
|
||||
const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(pg256_all, scale1));
|
||||
const svint64_t prod = svdup_n_s64(0);
|
||||
svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_0));
|
||||
svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_1));
|
||||
svint32_t isum_tmp3 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_0));
|
||||
svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_1));
|
||||
svint32_t isum_tmp5 = svtrn1_s32(isum_tmp1, isum_tmp2);
|
||||
svint32_t isum_tmp6 = svtrn1_s32(isum_tmp3, isum_tmp4);
|
||||
svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
|
||||
svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
|
||||
svint32_t isum_tmp9 = svadd_s32_x(pg256_all, isum_tmp7, isum_tmp8);
|
||||
svint32_t isum_tmp10 = svreinterpret_s32_u8(svext_u8(svreinterpret_u8_s32(isum_tmp9), svreinterpret_u8_s32(isum_tmp9), 16));
|
||||
svint32_t svisum_mins = svadd_s32_z(pg32_4, isum_tmp9, isum_tmp10);
|
||||
|
||||
// process mmla
|
||||
svint8_t l0, l1, r0, r1;
|
||||
svint32_t isum_tmp = svdup_n_s32(0);
|
||||
for (int j = 0; j < QK_K/128; ++j) {
|
||||
for (int k = 0; k < 8; k+=2) { // process 2 block
|
||||
svuint8_t qhbits_0 = svld1_u8(pg256_all, qh0);
|
||||
svuint8_t qhbits_1 = svld1_u8(pg256_all, qh1);
|
||||
svuint8_t q6bits_0 = svld1_u8(pg256_all, ql0+32*((k%4)/2));
|
||||
svuint8_t q6bits_1 = svld1_u8(pg256_all, ql1+32*((k%4)/2));
|
||||
const int ql_pos = (k/4)*4;
|
||||
svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_0, 4);
|
||||
svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_1, 4);
|
||||
const int qh_pos = (k/2)*2;
|
||||
svuint8_t q6bytes_0_hi = svand_n_u8_x(pg256_all, qhbits_0, 0x3 << qh_pos);
|
||||
svuint8_t q6bytes_1_hi = svand_n_u8_x(pg256_all, qhbits_1, 0x3 << qh_pos);
|
||||
svint8_t q6bytes_0, q6bytes_1;
|
||||
if (qh_pos <= 4) {
|
||||
q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
|
||||
q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
|
||||
} else {
|
||||
q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_0_lo, svlsr_n_u8_x(pg256_all, q6bytes_0_hi, (qh_pos - 4))));
|
||||
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_1_lo, svlsr_n_u8_x(pg256_all, q6bytes_1_hi, (qh_pos - 4))));
|
||||
}
|
||||
svint8_t q8bytes_0 = svld1_s8(pg256_all, q80+32*(k/2));
|
||||
svint8_t q8bytes_1 = svld1_s8(pg256_all, q81+32*(k/2));
|
||||
l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
|
||||
l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
|
||||
r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
|
||||
r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
|
||||
svint32_t svscale0 = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
|
||||
svint32_t svscale1 = svzip1_s32(svdup_n_s32(scale0[k+1]), svdup_n_s32(scale1[k+1]));
|
||||
isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r0, l0), svscale0);
|
||||
isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r1, l1), svscale1);
|
||||
}
|
||||
qh0 += 32; qh1 += 32;
|
||||
ql0 += 64; ql1 += 64;
|
||||
q80 += 128; q81 += 128;
|
||||
scale0 += 8; scale1 += 8;
|
||||
} // end of for
|
||||
svint32_t swap_isum_tmp = svext_s32(isum_tmp, isum_tmp, 4);
|
||||
isum_tmp = svadd_s32_x(pg32_4, isum_tmp, swap_isum_tmp);
|
||||
sum = svmla_f32_x(pg32_4, sum,
|
||||
svcvt_f32_x(pg32_4, svmla_s32_x(pg32_4, isum_tmp,
|
||||
svisum_mins, svdup_n_s32(-32))),
|
||||
svsuper_block_scales);
|
||||
}
|
||||
} // end of case 256
|
||||
break;
|
||||
default:
|
||||
assert(false && "Unsupported vector length");
|
||||
break;
|
||||
} // end of switch
|
||||
|
||||
svst1_f32(pg32_2, s, sum);
|
||||
svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sum), svdup_n_u8(0), 8)));
|
||||
|
||||
return;
|
||||
}
|
||||
#elif defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
if (nrc == 2) {
|
||||
const block_q6_K * GGML_RESTRICT x0 = x;
|
||||
const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
|
||||
@@ -3019,6 +2594,27 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
// adjust bias, apply superblock scale
|
||||
{
|
||||
int32_t bias[4];
|
||||
#ifdef __ARM_FEATURE_SVE
|
||||
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
|
||||
const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
|
||||
const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
|
||||
const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
|
||||
const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
|
||||
const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
|
||||
const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
|
||||
const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
|
||||
const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
|
||||
const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
|
||||
const svint64_t zero = svdup_n_s64(0);
|
||||
bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
|
||||
svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
|
||||
bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
|
||||
svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
|
||||
bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
|
||||
svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
|
||||
bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
|
||||
svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
|
||||
#else
|
||||
// NEON doesn't support int16 dot product, fallback to separated mul and add
|
||||
const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
|
||||
const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
|
||||
@@ -3050,6 +2646,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
|
||||
bias[3] = vaddvq_s32(prod);
|
||||
|
||||
#endif
|
||||
const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
|
||||
|
||||
const float32x4_t superblock_scale = {
|
||||
@@ -3075,6 +2672,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
#endif
|
||||
|
||||
#ifdef __ARM_FEATURE_SVE
|
||||
const int vector_length = ggml_cpu_get_sve_cnt()*8;
|
||||
float sum = 0;
|
||||
svuint8_t m4b = svdup_n_u8(0xf);
|
||||
svint32_t vzero = svdup_n_s32(0);
|
||||
|
||||
@@ -24,29 +24,6 @@
|
||||
|
||||
#define UNUSED GGML_UNUSED
|
||||
|
||||
static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
|
||||
int16x8_t * out_mins,
|
||||
int8_t * out_scales) {
|
||||
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
||||
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
||||
constexpr uint32_t kmask3 = 0x03030303;
|
||||
constexpr uint8_t scales_size = 12;
|
||||
|
||||
uint32_t sm[3];
|
||||
memcpy(sm, scales_in, scales_size);
|
||||
|
||||
const uint32_t mins_0_3 = sm[1] & kmask1;
|
||||
const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
|
||||
const uint32x2_t mins_u32 = { mins_0_3, mins_4_7 };
|
||||
|
||||
*out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32)));
|
||||
|
||||
uint32_t scales_u32[2];
|
||||
scales_u32[0] = sm[0] & kmask1;
|
||||
scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
|
||||
memcpy(out_scales, scales_u32, 8);
|
||||
}
|
||||
|
||||
void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
||||
assert(QK8_0 == 32);
|
||||
assert(k % QK8_0 == 0);
|
||||
@@ -497,162 +474,6 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q4_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON)
|
||||
constexpr int col_pairs = ncols_interleaved / 2;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
|
||||
// 1x8 tile = 2 x 4
|
||||
float32x4_t acc_f32[ncols_interleaved / 4];
|
||||
|
||||
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < ncols_interleaved / 4; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // d0 d1 d2 d3
|
||||
float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7
|
||||
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
|
||||
float32x4_t sb_scale_0 = vmulq_f32(q4_d_0, q8_d);
|
||||
float32x4_t sb_scale_1 = vmulq_f32(q4_d_1, q8_d);
|
||||
float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin 0..3
|
||||
float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7
|
||||
float32x4_t sb_min_0 = vmulq_f32(q4_dmin_0, q8_d);
|
||||
float32x4_t sb_min_1 = vmulq_f32(q4_dmin_1, q8_d);
|
||||
|
||||
// interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
|
||||
int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
|
||||
// 2 sb each iteration
|
||||
int32x4_t acc_lo[col_pairs];
|
||||
int32x4_t acc_hi[col_pairs];
|
||||
|
||||
// Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
|
||||
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
|
||||
int16_t bsums_arr[8];
|
||||
vst1q_s16(bsums_arr, bsums);
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
for (int i = 0; i < col_pairs; i++) {
|
||||
acc_lo[i] = vdupq_n_s32(0);
|
||||
acc_hi[i] = vdupq_n_s32(0);
|
||||
}
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later
|
||||
int16x8_t q4sb_scales[2];
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q4sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
||||
}
|
||||
|
||||
const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;
|
||||
|
||||
// Load the 64 quants from q8K duplicated to use vecdots with the interelaved columns
|
||||
// but still need the qs to use the low and hi bits from q4
|
||||
const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
|
||||
int8x16_t q8_qs[8];
|
||||
for (int i = 0; i < 8; i++) {
|
||||
q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
|
||||
}
|
||||
|
||||
// Q4s columns iterated in pairs (01, 23, 45, 67)
|
||||
for (int cp = 0; cp < col_pairs; cp++) {
|
||||
uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_base + 16 * cp);
|
||||
uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_base + 16 * cp + 64);
|
||||
uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_base + 16 * cp + 128);
|
||||
uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_base + 16 * cp + 192);
|
||||
|
||||
acc_lo[cp] =
|
||||
ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), q8_qs[0]); // 0 .. 7
|
||||
acc_lo[cp] =
|
||||
ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), q8_qs[1]); // 8 ..15
|
||||
acc_lo[cp] =
|
||||
ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), q8_qs[2]); // 16..23
|
||||
acc_lo[cp] =
|
||||
ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), q8_qs[3]); // 24..31
|
||||
|
||||
acc_hi[cp] =
|
||||
ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), q8_qs[4]); // 32..39
|
||||
acc_hi[cp] =
|
||||
ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), q8_qs[5]); // 40..47
|
||||
acc_hi[cp] =
|
||||
ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), q8_qs[6]); // 48..55
|
||||
acc_hi[cp] =
|
||||
ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), q8_qs[7]); // 56..63
|
||||
}
|
||||
|
||||
// Iterates over a pair of column pairs (4 columns) to use a single 128 register
|
||||
// p = 0 -> 0123 p2 -> 4567
|
||||
for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
|
||||
int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q4sb_scales[0]) : vget_high_s16(q4sb_scales[0]);
|
||||
int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q4sb_scales[1]) : vget_high_s16(q4sb_scales[1]);
|
||||
float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
|
||||
|
||||
// 0123 or 4567
|
||||
// TODO: Single superblock mul at the end of the superblock
|
||||
float32x4_t sumf_0 =
|
||||
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
|
||||
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
|
||||
|
||||
float32x4_t sumf_1 =
|
||||
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
|
||||
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
|
||||
}
|
||||
|
||||
// Multiply Acc bsum + mins
|
||||
// Each pair of subblocks share the same bsums
|
||||
// Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
|
||||
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
|
||||
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
|
||||
|
||||
// cols 0-3 bias
|
||||
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
|
||||
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
|
||||
|
||||
// cols 4-7 bias
|
||||
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
|
||||
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
|
||||
} // for sb
|
||||
|
||||
acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0);
|
||||
acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1);
|
||||
} // for b
|
||||
|
||||
int base = x * ncols_interleaved;
|
||||
vst1q_f32(s + base, acc_f32[0]);
|
||||
vst1q_f32(s + base + 4, acc_f32[1]);
|
||||
} // for x
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON)
|
||||
ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
@@ -2068,212 +1889,3 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
|
||||
ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
|
||||
// 8 accumulators: 2 row pairs × 4 col pairs
|
||||
float32x4_t acc_f32[blocklen];
|
||||
|
||||
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
// bsums pairs belongs to the same q8_k subblock
|
||||
const int16x8_t bsums[4]{
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
||||
};
|
||||
int16_t bsums_arr[4][8];
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
|
||||
}
|
||||
|
||||
int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results
|
||||
int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
|
||||
int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
|
||||
for (int i = 0; i < 8; i++) {
|
||||
acc[i] = vdupq_n_s32(0);
|
||||
bias_acc[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
int8_t q4sb_scales[2][8];
|
||||
int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later
|
||||
for (int i = 0; i < 2; i++) {
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
|
||||
}
|
||||
|
||||
// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
|
||||
const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
|
||||
|
||||
int8x16_t q8_qs_01[8];
|
||||
int8x16_t q8_qs_23[8];
|
||||
|
||||
// Load 32-byte per row pair, 1 subblock each time
|
||||
for (int i = 0; i < 8; i++) {
|
||||
const int offset = i * 32; // 16 for row 01, 16 for row 23
|
||||
q8_qs_01[i] = vld1q_s8(q8_base + offset);
|
||||
q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
|
||||
}
|
||||
|
||||
const int8x16_t q8s[2][8] = {
|
||||
{ q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
|
||||
q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
|
||||
{ q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
|
||||
q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
|
||||
};
|
||||
|
||||
// Q4s columns iterated in pairs (01, 23, 45, 67)
|
||||
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
sb_acc[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
|
||||
uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
|
||||
uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
|
||||
uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
|
||||
const int8x16_t q4_nibbles[2][4] = {
|
||||
{
|
||||
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
|
||||
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
|
||||
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
|
||||
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
|
||||
},
|
||||
{
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
|
||||
}
|
||||
};
|
||||
|
||||
// Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8
|
||||
// for each of the internal 32 qs subblock (blk)
|
||||
for (int rp = 0; rp < 2; rp++) {
|
||||
for (int blk = 0; blk < 2; blk++) {
|
||||
const int8x16_t * q8 = &q8s[rp][4 * blk];
|
||||
const int8x16_t * q4 = q4_nibbles[blk];
|
||||
int32x4_t acc = sb_acc[2 * rp + blk];
|
||||
// mul add for each qs in the same subblock
|
||||
for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
|
||||
acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
|
||||
}
|
||||
sb_acc[2 * rp + blk] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
// Scales[i] corresponds to column i
|
||||
const int scale_offset = cp * 2;
|
||||
for (int blk = 0; blk < 2; blk++) {
|
||||
const int32x4_t block_scale = {
|
||||
(int32_t) q4sb_scales[blk][scale_offset],
|
||||
(int32_t) q4sb_scales[blk][scale_offset],
|
||||
(int32_t) q4sb_scales[blk][scale_offset + 1],
|
||||
(int32_t) q4sb_scales[blk][scale_offset + 1],
|
||||
};
|
||||
acc[cp] = vmlaq_s32(acc[cp], sb_acc[blk], block_scale);
|
||||
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale);
|
||||
}
|
||||
}
|
||||
|
||||
// Multiply Acc bsum + mins
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
// Each pair of subblocks share the same bsums
|
||||
// Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
|
||||
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
|
||||
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
|
||||
|
||||
bias_acc[2 * q8_row] =
|
||||
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
|
||||
bias_acc[2 * q8_row] =
|
||||
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
|
||||
bias_acc[2 * q8_row + 1] =
|
||||
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
|
||||
bias_acc[2 * q8_row + 1] =
|
||||
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
|
||||
}
|
||||
} // for sb
|
||||
|
||||
// Reorder of i8mm output with bias and output layout
|
||||
for (int i = 0; i < 8; i++) {
|
||||
int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
|
||||
acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
|
||||
}
|
||||
int32x4_t reorder_acc[8] = {
|
||||
vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
|
||||
vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
|
||||
vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
|
||||
vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
|
||||
vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
|
||||
vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
|
||||
vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
|
||||
vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
|
||||
};
|
||||
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
|
||||
float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
|
||||
const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d);
|
||||
|
||||
float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
|
||||
const float32x4_t scale = vmulq_f32(q4_d, q8_d);
|
||||
|
||||
acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
|
||||
acc_f32[2 * i + j] =
|
||||
vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
|
||||
}
|
||||
}
|
||||
} // for b
|
||||
|
||||
// With the previous reorder, the tile is already in the correct memory layout.
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
int row = y * q8_k_blocklen + i;
|
||||
for (int j = 0; j < 2; j++) {
|
||||
int col = x * ncols_interleaved + j * 4;
|
||||
int offset = row * bs + col;
|
||||
vst1q_f32(s + offset, acc_f32[2 * i + j]);
|
||||
}
|
||||
}
|
||||
} // for x
|
||||
} // for y
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
#include "ggml-backend-impl.h"
|
||||
|
||||
#if defined(__riscv) && __riscv_xlen == 64
|
||||
#include <sys/auxv.h>
|
||||
|
||||
//https://github.com/torvalds/linux/blob/master/arch/riscv/include/uapi/asm/hwcap.h#L24
|
||||
#ifndef COMPAT_HWCAP_ISA_V
|
||||
#define COMPAT_HWCAP_ISA_V (1 << ('V' - 'A'))
|
||||
#endif
|
||||
|
||||
struct riscv64_features {
|
||||
bool has_rvv = false;
|
||||
|
||||
riscv64_features() {
|
||||
uint32_t hwcap = getauxval(AT_HWCAP);
|
||||
|
||||
has_rvv = !!(hwcap & COMPAT_HWCAP_ISA_V);
|
||||
}
|
||||
};
|
||||
|
||||
static int ggml_backend_cpu_riscv64_score() {
|
||||
int score = 1;
|
||||
riscv64_features rf;
|
||||
|
||||
#ifdef GGML_USE_RVV
|
||||
if (!rf.has_rvv) { return 0; }
|
||||
score += 1 << 1;
|
||||
#endif
|
||||
|
||||
return score;
|
||||
}
|
||||
|
||||
GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_riscv64_score)
|
||||
|
||||
#endif // __riscv && __riscv_xlen == 64
|
||||
@@ -646,7 +646,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
|
||||
int64_t xstart = 0;
|
||||
int anr = nr - nr%16; // Used to align nr with boundary of 16
|
||||
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||
#ifdef __AVX512F__
|
||||
int anc = nc - nc%16; // Used to align nc with boundary of 16
|
||||
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
|
||||
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
|
||||
@@ -1041,7 +1041,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||
xstart = anc/8;
|
||||
y = 0;
|
||||
}
|
||||
#endif // __AVX512BW__ && __AVX512DQ__
|
||||
#endif // __AVX512F__
|
||||
|
||||
// Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
|
||||
|
||||
@@ -1989,7 +1989,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
|
||||
int64_t xstart = 0;
|
||||
int anr = nr - nr % 16;; // Used to align nr with boundary of 16
|
||||
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||
#ifdef __AVX512F__
|
||||
int anc = nc - nc % 16; // Used to align nc with boundary of 16
|
||||
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
|
||||
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
|
||||
@@ -2727,7 +2727,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
xstart = anc/8;
|
||||
y = 0;
|
||||
}
|
||||
#endif // __AVX512BW__ && __AVX512DQ__
|
||||
#endif //AVX512F
|
||||
|
||||
// Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
|
||||
for (; y < anr / 4; y += 4) {
|
||||
@@ -3467,7 +3467,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
__m256i scalesmask2 = _mm256_castsi128_si256(scalesmask2_sse);
|
||||
scalesmask2 = _mm256_permute2f128_si256(scalesmask2, scalesmask2, 0);
|
||||
|
||||
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||
#ifdef __AVX512F__
|
||||
|
||||
int anc = nc - nc % 16; // Used to align nc with boundary of 16
|
||||
|
||||
@@ -4947,7 +4947,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
y = 0;
|
||||
}
|
||||
|
||||
#endif // __AVX512BW__ && __AVX512DQ__
|
||||
#endif //AVX512F
|
||||
|
||||
// Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
|
||||
for (; y < anr / 4; y += 4) {
|
||||
|
||||
@@ -1731,10 +1731,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_sum_rows(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_CUMSUM:
|
||||
{
|
||||
ggml_compute_forward_cumsum(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MEAN:
|
||||
{
|
||||
ggml_compute_forward_mean(params, tensor);
|
||||
@@ -1811,6 +1807,22 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_cont(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_RESHAPE:
|
||||
{
|
||||
ggml_compute_forward_reshape(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_VIEW:
|
||||
{
|
||||
ggml_compute_forward_view(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_PERMUTE:
|
||||
{
|
||||
ggml_compute_forward_permute(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_TRANSPOSE:
|
||||
{
|
||||
ggml_compute_forward_transpose(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
ggml_compute_forward_get_rows(params, tensor);
|
||||
@@ -1931,14 +1943,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_leaky_relu(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_TRI:
|
||||
{
|
||||
ggml_compute_forward_tri(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_FILL:
|
||||
{
|
||||
ggml_compute_forward_fill(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
ggml_compute_forward_flash_attn_ext(params, tensor);
|
||||
@@ -1994,10 +1998,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_rwkv_wkv7(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
{
|
||||
ggml_compute_forward_solve_tri(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MAP_CUSTOM1:
|
||||
{
|
||||
ggml_compute_forward_map_custom1(params, tensor);
|
||||
@@ -2042,22 +2042,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
// nop
|
||||
} break;
|
||||
case GGML_OP_RESHAPE:
|
||||
{
|
||||
// nop
|
||||
} break;
|
||||
case GGML_OP_PERMUTE:
|
||||
{
|
||||
// nop
|
||||
} break;
|
||||
case GGML_OP_VIEW:
|
||||
{
|
||||
// nop
|
||||
} break;
|
||||
case GGML_OP_TRANSPOSE:
|
||||
{
|
||||
// nop
|
||||
} break;
|
||||
case GGML_OP_COUNT:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -2156,9 +2140,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_ADD_ID:
|
||||
case GGML_OP_ADD1:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_CUMSUM:
|
||||
case GGML_OP_TRI:
|
||||
case GGML_OP_FILL:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
@@ -2176,7 +2157,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
n_tasks = 1;
|
||||
} break;
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
@@ -2199,8 +2179,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_UNARY_OP_HARDSWISH:
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
@@ -2906,11 +2884,6 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
|
||||
struct ggml_tensor * node = cgraph->nodes[node_n];
|
||||
|
||||
if (ggml_op_is_empty(node->op)) {
|
||||
// skip NOPs
|
||||
continue;
|
||||
}
|
||||
|
||||
ggml_compute_forward(¶ms, node);
|
||||
|
||||
if (state->ith == 0 && cplan->abort_callback &&
|
||||
@@ -3296,13 +3269,6 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) {
|
||||
__m128 y_vec = _mm_cvtph_ps(x_vec);
|
||||
_mm_storeu_ps(y + i, y_vec);
|
||||
}
|
||||
#elif defined(__riscv_zvfh)
|
||||
for (int vl; i < n; i += vl) {
|
||||
vl = __riscv_vsetvl_e16m1(n - i);
|
||||
vfloat16m1_t vx = __riscv_vle16_v_f16m1((_Float16 *)&x[i], vl);
|
||||
vfloat32m2_t vy = __riscv_vfwcvt_f_f_v_f32m2(vx, vl);
|
||||
__riscv_vse32_v_f32m2(&y[i], vy, vl);
|
||||
}
|
||||
#endif
|
||||
|
||||
for (; i < n; ++i) {
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
|
||||
// KleidiAI micro-kernels
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
|
||||
#include "kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
|
||||
@@ -12,34 +11,23 @@
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
|
||||
#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
|
||||
#include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
|
||||
#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
|
||||
#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
|
||||
|
||||
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
|
||||
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
|
||||
#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
|
||||
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
|
||||
#include "kai_lhs_quant_pack_qai8dxp_f32.h"
|
||||
|
||||
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
|
||||
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
|
||||
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
|
||||
#include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
|
||||
|
||||
#include "kai_common.h"
|
||||
|
||||
#include "simd-mappings.h"
|
||||
|
||||
#define GGML_COMMON_DECL_CPP
|
||||
#include "ggml-common.h"
|
||||
|
||||
#include "kernels.h"
|
||||
|
||||
#define NELEMS(x) (sizeof(x) / sizeof(*x))
|
||||
#define NELEMS(x) sizeof(x) / sizeof(*x)
|
||||
|
||||
template<size_t(*Fn)(size_t,size_t,size_t)>
|
||||
static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) {
|
||||
@@ -67,14 +55,6 @@ static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
|
||||
Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
|
||||
}
|
||||
|
||||
template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
|
||||
static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
|
||||
const void* lhs, const void* rhs, void* dst,
|
||||
size_t dst_stride_row, size_t dst_stride_col,
|
||||
float clamp_min, float clamp_max) {
|
||||
Fn(m, n, k, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
|
||||
}
|
||||
|
||||
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
|
||||
static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
|
||||
return Fn(m, k, bl, mr, kr, sr);
|
||||
@@ -113,12 +93,6 @@ static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t m
|
||||
Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
|
||||
}
|
||||
|
||||
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>
|
||||
static inline void lhs_pack_float_fn9_no_bl(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
|
||||
size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed) {
|
||||
Fn(m, k, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);
|
||||
}
|
||||
|
||||
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
|
||||
static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
|
||||
return Fn(n, k, nr, kr, bl);
|
||||
@@ -150,18 +124,6 @@ static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t n
|
||||
static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
|
||||
}
|
||||
|
||||
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const int8_t*,const float*,const float*,void*,size_t,const struct kai_rhs_pack_qsi8cx_params*)>
|
||||
static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
|
||||
size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
|
||||
void* rhs_packed, size_t extra_bytes, const void* params) {
|
||||
Fn(num_groups, n, k, nr, kr, sr,
|
||||
static_cast<const int8_t*>(rhs),
|
||||
static_cast<const float*>(bias),
|
||||
static_cast<const float*>(scale),
|
||||
rhs_packed, extra_bytes,
|
||||
static_cast<const kai_rhs_pack_qsi8cx_params*>(params));
|
||||
}
|
||||
|
||||
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
|
||||
static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
|
||||
size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
|
||||
@@ -251,57 +213,6 @@ static void dequantize_row_qsi4c32ps1s0scalef16(
|
||||
GGML_UNUSED(kr);
|
||||
}
|
||||
|
||||
static void dequantize_row_qsi8cxp(
|
||||
const void *packed_data,
|
||||
int32_t row_idx,
|
||||
int64_t k,
|
||||
float *out,
|
||||
size_t nr,
|
||||
size_t packed_row_stride,
|
||||
size_t kr,
|
||||
size_t bl,
|
||||
size_t num_bytes_multiplier
|
||||
) {
|
||||
GGML_UNUSED(bl);
|
||||
GGML_UNUSED(num_bytes_multiplier);
|
||||
|
||||
const size_t k_internal = ((size_t) k + QK8_0 - 1) / QK8_0 * QK8_0;
|
||||
const size_t group_idx = row_idx / nr;
|
||||
const size_t row_in_group = row_idx % nr;
|
||||
|
||||
const uint8_t * group_ptr = static_cast<const uint8_t *>(packed_data) + group_idx * packed_row_stride;
|
||||
const int8_t * data_base = reinterpret_cast<const int8_t *>(group_ptr);
|
||||
|
||||
const size_t num_blocks = k_internal / kr;
|
||||
|
||||
for (size_t block = 0; block < num_blocks; ++block) {
|
||||
const int8_t * block_ptr = data_base + (block * nr + row_in_group) * kr;
|
||||
for (size_t i = 0; i < kr; ++i) {
|
||||
const size_t k_idx = block * kr + i;
|
||||
if (k_idx < (size_t) k) {
|
||||
out[k_idx] = static_cast<float>(block_ptr[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const uint8_t * sums_ptr = group_ptr + nr * k_internal;
|
||||
GGML_UNUSED(sums_ptr);
|
||||
|
||||
const float * scale_ptr = reinterpret_cast<const float *>(sums_ptr + nr * sizeof(int32_t));
|
||||
const float scale = scale_ptr[row_in_group];
|
||||
|
||||
if (scale == 0.0f) {
|
||||
for (size_t i = 0; i < (size_t) k; ++i) {
|
||||
out[i] = 0.0f;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < (size_t) k; ++i) {
|
||||
out[i] *= scale;
|
||||
}
|
||||
}
|
||||
|
||||
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
#if defined(__ARM_FEATURE_SME)
|
||||
{
|
||||
@@ -635,176 +546,6 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
},
|
||||
#endif
|
||||
#endif
|
||||
{ /* Sentinel */ }
|
||||
};
|
||||
|
||||
static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
|
||||
#if defined(__ARM_FEATURE_SME)
|
||||
{
|
||||
/* SME GEMM */
|
||||
{
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
|
||||
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
||||
},
|
||||
/* SME GEMV */
|
||||
{
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
|
||||
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
|
||||
},
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
|
||||
/* .to_float = */ dequantize_row_qsi8cxp,
|
||||
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_SME,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q8_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
},
|
||||
#endif
|
||||
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
{
|
||||
/* I8MM GEMM */
|
||||
{
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
|
||||
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
||||
},
|
||||
/* I8MM GEMV (dotprod fallback) */
|
||||
{
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
|
||||
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
|
||||
},
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
|
||||
/* .to_float = */ dequantize_row_qsi8cxp,
|
||||
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q8_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
},
|
||||
#endif
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
{
|
||||
/* DOTPROD GEMM */
|
||||
{
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
|
||||
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
||||
},
|
||||
/* DOTPROD GEMV */
|
||||
{
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
|
||||
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
|
||||
},
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
|
||||
/* .to_float = */ dequantize_row_qsi8cxp,
|
||||
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q8_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
},
|
||||
#endif
|
||||
{ /* Sentinel */ }
|
||||
};
|
||||
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
|
||||
@@ -812,7 +553,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
||||
|
||||
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
|
||||
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
||||
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
|
||||
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
|
||||
gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
|
||||
@@ -821,21 +562,6 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!kernel) {
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
|
||||
if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
|
||||
gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
|
||||
gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
|
||||
gemm_gemv_kernels_q8[i].op_type == tensor->type) {
|
||||
kernel = &gemm_gemv_kernels_q8[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(gemm_gemv_kernels);
|
||||
GGML_UNUSED(gemm_gemv_kernels_q8);
|
||||
GGML_UNUSED(cpu_features);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -846,31 +572,12 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features)
|
||||
ggml_kleidiai_kernels * kernels = nullptr;
|
||||
|
||||
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
||||
if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
|
||||
kernels = &gemm_gemv_kernels[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(features);
|
||||
#endif
|
||||
|
||||
return kernels;
|
||||
}
|
||||
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features) {
|
||||
ggml_kleidiai_kernels * kernels = nullptr;
|
||||
|
||||
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
|
||||
if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) {
|
||||
kernels = &gemm_gemv_kernels_q8[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(features);
|
||||
#endif
|
||||
|
||||
return kernels;
|
||||
|
||||
@@ -87,4 +87,3 @@ struct ggml_kleidiai_kernels {
|
||||
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features);
|
||||
|
||||
@@ -5,13 +5,10 @@
|
||||
#include <assert.h>
|
||||
#include <atomic>
|
||||
#include <cfloat>
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <stdint.h>
|
||||
#include <string.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#if defined(__linux__)
|
||||
#include <asm/hwcap.h>
|
||||
#include <sys/auxv.h>
|
||||
@@ -41,9 +38,8 @@
|
||||
|
||||
struct ggml_kleidiai_context {
|
||||
cpu_feature features;
|
||||
ggml_kleidiai_kernels * kernels_q4;
|
||||
ggml_kleidiai_kernels * kernels_q8;
|
||||
} static ctx = { CPU_FEATURE_NONE, NULL, NULL };
|
||||
ggml_kleidiai_kernels * kernels;
|
||||
} static ctx = { CPU_FEATURE_NONE, NULL };
|
||||
|
||||
static const char* cpu_feature_to_string(cpu_feature f) {
|
||||
switch (f) {
|
||||
@@ -77,14 +73,10 @@ static void init_kleidiai_context(void) {
|
||||
if (sme_enabled != 0) {
|
||||
ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
|
||||
}
|
||||
ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
|
||||
ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
|
||||
ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
|
||||
#ifndef NDEBUG
|
||||
if (ctx.kernels_q4) {
|
||||
GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
|
||||
}
|
||||
if (ctx.kernels_q8) {
|
||||
GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
|
||||
if (ctx.kernels) {
|
||||
GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -138,9 +130,6 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
|
||||
if (!lhs_info->packed_size_ex) return false;
|
||||
size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
|
||||
} else if (kernels->rhs_type == GGML_TYPE_Q8_0) {
|
||||
if (!lhs_info->packed_size_ex) return false;
|
||||
size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr);
|
||||
} else if (kernels->rhs_type == GGML_TYPE_F16) {
|
||||
if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
|
||||
const int64_t lhs_batch_size0 = op->src[1]->ne[2];
|
||||
@@ -160,13 +149,11 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
if (dst->op == GGML_OP_MUL_MAT) {
|
||||
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
||||
return compute_forward_q4_0(params, dst);
|
||||
} else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
|
||||
return compute_forward_q8_0(params, dst);
|
||||
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||
return compute_forward_fp16(params, dst);
|
||||
}
|
||||
} else if (dst->op == GGML_OP_GET_ROWS) {
|
||||
if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
|
||||
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
||||
return compute_forward_get_rows(params, dst);
|
||||
}
|
||||
}
|
||||
@@ -413,120 +400,19 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0);
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
|
||||
if (!kernels) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_gemv = src1->ne[1] == 1;
|
||||
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
||||
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
||||
|
||||
if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
|
||||
!kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth_raw = params->nth;
|
||||
const int nth = nth_raw > 0 ? nth_raw : 1;
|
||||
|
||||
const size_t k = ne00;
|
||||
const size_t m = ne11;
|
||||
const size_t n = ne01;
|
||||
|
||||
size_t mr = kernel->get_mr();
|
||||
size_t kr = kernel->get_kr();
|
||||
size_t sr = kernel->get_sr();
|
||||
|
||||
const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
|
||||
uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
|
||||
const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
|
||||
|
||||
const size_t n_step = kernel->get_n_step();
|
||||
const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
|
||||
const size_t n_start = ith * num_n_per_thread;
|
||||
|
||||
size_t n_to_process = 0;
|
||||
if (n_start < n) {
|
||||
n_to_process = num_n_per_thread;
|
||||
if ((n_start + n_to_process) > n) {
|
||||
n_to_process = n - n_start;
|
||||
}
|
||||
}
|
||||
|
||||
const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
|
||||
const size_t m_start = ith * num_m_per_thread;
|
||||
size_t m_to_process = num_m_per_thread;
|
||||
if ((m_start + m_to_process) > m) {
|
||||
m_to_process = m - m_start;
|
||||
}
|
||||
|
||||
if (m_start < m) {
|
||||
const size_t src_stride = src1->nb[1];
|
||||
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
|
||||
const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
|
||||
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
|
||||
|
||||
lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
const size_t dst_stride = dst->nb[1];
|
||||
const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
|
||||
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
|
||||
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
|
||||
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
|
||||
const void * lhs_ptr = static_cast<const void *>(lhs_packed + lhs_packed_offset);
|
||||
float * dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
|
||||
|
||||
if (n_to_process > 0) {
|
||||
kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
|
||||
sizeof(float), -FLT_MAX, FLT_MAX);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
|
||||
if (!ctx.kernels) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
ggml_kleidiai_kernels * kernels = nullptr;
|
||||
size_t block_len = 0;
|
||||
size_t num_bytes_multiplier = 0;
|
||||
|
||||
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
||||
if (!ctx.kernels_q4) {
|
||||
return false;
|
||||
}
|
||||
kernels = ctx.kernels_q4;
|
||||
block_len = QK4_0;
|
||||
num_bytes_multiplier = sizeof(uint16_t);
|
||||
} else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
|
||||
if (!ctx.kernels_q8) {
|
||||
return false;
|
||||
}
|
||||
kernels = ctx.kernels_q8;
|
||||
block_len = QK8_0;
|
||||
num_bytes_multiplier = sizeof(float);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
rhs_packing_info * rhs_info = &kernels->rhs_info;
|
||||
kernel_info * kernel = &kernels->gemm;
|
||||
rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
|
||||
kernel_info * kernel = &ctx.kernels->gemm;
|
||||
if (!rhs_info->to_float || !kernel->get_nr) {
|
||||
return false;
|
||||
}
|
||||
@@ -537,7 +423,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
const size_t block_rows = kernel->get_nr();
|
||||
const size_t kr = kernel->get_kr();
|
||||
|
||||
const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, block_len);
|
||||
const size_t num_bytes_multiplier = sizeof(uint16_t);
|
||||
const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
@@ -552,7 +439,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
|
||||
|
||||
float *out = (float *)((char *)dst->data + i * nb1);
|
||||
rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
|
||||
rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -560,91 +447,21 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
|
||||
public:
|
||||
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
|
||||
GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
|
||||
GGML_ASSERT(ctx.kernels);
|
||||
const size_t n = tensor->ne[1];
|
||||
const size_t k = tensor->ne[0];
|
||||
size_t nr = ctx.kernels->gemm.get_nr();
|
||||
size_t kr = ctx.kernels->gemm.get_kr();
|
||||
size_t sr = ctx.kernels->gemm.get_sr();
|
||||
|
||||
if (tensor->type == GGML_TYPE_Q4_0) {
|
||||
if (!ctx.kernels_q4) {
|
||||
return -1;
|
||||
}
|
||||
size_t nr = ctx.kernels_q4->gemm.get_nr();
|
||||
size_t kr = ctx.kernels_q4->gemm.get_kr();
|
||||
size_t sr = ctx.kernels_q4->gemm.get_sr();
|
||||
|
||||
struct kai_rhs_pack_qs4cxs1s0_param params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
|
||||
static_cast<const uint8_t *>(data),
|
||||
nullptr, nullptr, tensor->data, 0, ¶ms);
|
||||
GGML_UNUSED(data_size);
|
||||
return 0;
|
||||
} else if (tensor->type == GGML_TYPE_Q8_0) {
|
||||
if (!ctx.kernels_q8) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
const size_t row_stride = tensor->nb[1];
|
||||
const size_t k_blocks = (k + QK8_0 - 1) / QK8_0;
|
||||
|
||||
std::vector<int8_t> qdata(n * k, 0);
|
||||
std::vector<float> scales(n, 0.0f);
|
||||
|
||||
for (size_t row = 0; row < n; ++row) {
|
||||
const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
|
||||
static_cast<const uint8_t *>(data) + row * row_stride);
|
||||
|
||||
float max_abs = 0.0f;
|
||||
for (size_t block = 0; block < k_blocks; ++block) {
|
||||
const block_q8_0 & blk = row_blocks[block];
|
||||
const float d = GGML_FP16_TO_FP32(blk.d);
|
||||
for (size_t l = 0; l < QK8_0; ++l) {
|
||||
const size_t linear_idx = block * QK8_0 + l;
|
||||
if (linear_idx >= k) {
|
||||
break;
|
||||
}
|
||||
const float value = d * blk.qs[l];
|
||||
max_abs = std::max(max_abs, std::fabs(value));
|
||||
}
|
||||
}
|
||||
|
||||
float scale = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;
|
||||
scales[row] = scale;
|
||||
const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;
|
||||
|
||||
for (size_t block = 0; block < k_blocks; ++block) {
|
||||
const block_q8_0 & blk = row_blocks[block];
|
||||
const float d = GGML_FP16_TO_FP32(blk.d);
|
||||
for (size_t l = 0; l < QK8_0; ++l) {
|
||||
const size_t linear_idx = block * QK8_0 + l;
|
||||
if (linear_idx >= k) {
|
||||
break;
|
||||
}
|
||||
const float value = d * blk.qs[l];
|
||||
int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
|
||||
q = std::clamp(q, -127, 127);
|
||||
qdata[row * k + linear_idx] = static_cast<int8_t>(q);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t nr = ctx.kernels_q8->gemm.get_nr();
|
||||
size_t kr = ctx.kernels_q8->gemm.get_kr();
|
||||
size_t sr = ctx.kernels_q8->gemm.get_sr();
|
||||
|
||||
struct kai_rhs_pack_qsi8cx_params params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.scale_multiplier = 1.0f;
|
||||
|
||||
ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
|
||||
qdata.data(), nullptr, scales.data(),
|
||||
tensor->data, 0, ¶ms);
|
||||
GGML_UNUSED(data_size);
|
||||
return 0;
|
||||
}
|
||||
struct kai_rhs_pack_qs4cxs1s0_param params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
ctx.kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, (const uint8_t*)data, nullptr, nullptr, tensor->data, 0, ¶ms);
|
||||
|
||||
return 0;
|
||||
GGML_UNUSED(data_size);
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -701,45 +518,27 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
|
||||
}
|
||||
|
||||
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
|
||||
GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
|
||||
GGML_ASSERT(ctx.kernels);
|
||||
|
||||
const size_t n = tensor->ne[1];
|
||||
const size_t k = tensor->ne[0];
|
||||
const size_t nr = ctx.kernels->gemm.get_nr();
|
||||
const size_t kr = ctx.kernels->gemm.get_kr();
|
||||
|
||||
return ctx.kernels->rhs_info.packed_size_ex(n, k, nr, kr, QK4_0);
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
|
||||
const size_t n = tensor->ne[1];
|
||||
const size_t k = tensor->ne[0];
|
||||
|
||||
ggml_kleidiai_kernels * kernels = nullptr;
|
||||
size_t block_len = 0;
|
||||
|
||||
if (tensor->type == GGML_TYPE_Q4_0) {
|
||||
GGML_ASSERT(ctx.kernels_q4);
|
||||
kernels = ctx.kernels_q4;
|
||||
block_len = QK4_0;
|
||||
} else if (tensor->type == GGML_TYPE_Q8_0) {
|
||||
GGML_ASSERT(ctx.kernels_q8);
|
||||
kernels = ctx.kernels_q8;
|
||||
block_len = QK8_0;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const size_t nr = kernels->gemm.get_nr();
|
||||
const size_t kr = kernels->gemm.get_kr();
|
||||
const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len);
|
||||
const size_t raw = ggml_nbytes(tensor);
|
||||
|
||||
return packed > raw ? packed : raw;
|
||||
}
|
||||
|
||||
namespace ggml::cpu::kleidiai {
|
||||
class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
||||
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
|
||||
if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
|
||||
(op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
|
||||
op->src[0]->type == GGML_TYPE_Q4_0 &&
|
||||
op->src[0]->buffer &&
|
||||
(ggml_n_dims(op->src[0]) == 2) &&
|
||||
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
|
||||
if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) {
|
||||
return false;
|
||||
}
|
||||
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
|
||||
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -7,10 +7,8 @@
|
||||
#include "unary-ops.h"
|
||||
#include "vec.h"
|
||||
|
||||
#include <cfloat>
|
||||
#include <float.h>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
|
||||
// ggml_compute_forward_dup
|
||||
|
||||
@@ -1396,56 +1394,6 @@ void ggml_compute_forward_sum(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_cumsum
|
||||
|
||||
static void ggml_compute_forward_cumsum_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(dst->nb[0] == sizeof(float));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT(ne0 == ne00);
|
||||
GGML_ASSERT(ne1 == ne01);
|
||||
GGML_ASSERT(ne2 == ne02);
|
||||
GGML_ASSERT(ne3 == ne03);
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, src0);
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
float * dst_row = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
ggml_vec_cumsum_f32(ne00, dst_row, src_row);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_cumsum(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_cumsum_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_sum_rows
|
||||
|
||||
static void ggml_compute_forward_sum_rows_f32(
|
||||
@@ -2192,83 +2140,6 @@ static void ggml_compute_forward_gelu(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_fill
|
||||
|
||||
static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const float c = ggml_get_op_params_f32(dst, 0);
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, dst);
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne2*ne1);
|
||||
const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
|
||||
const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
|
||||
|
||||
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
|
||||
|
||||
ggml_vec_set_f32(ne0, dst_ptr, c);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
ggml_compute_forward_fill_f32(params, dst);
|
||||
}
|
||||
|
||||
// ggml_compute_tri
|
||||
|
||||
static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, src0);
|
||||
|
||||
bool (*bipred)(int, int);
|
||||
|
||||
switch (ttype) {
|
||||
case GGML_TRI_TYPE_LOWER: bipred = [](int i, int r) { return i < r; }; break;
|
||||
case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
|
||||
case GGML_TRI_TYPE_UPPER: bipred = [](int i, int r) { return i > r; }; break;
|
||||
case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
|
||||
default: GGML_ABORT("invalid tri type");
|
||||
}
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
float * dst_ptr = ( float *) (( char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
|
||||
|
||||
for (int i0 = 0; i0 < ne0; ++i0) {
|
||||
dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_tri_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_gelu_erf
|
||||
|
||||
static void ggml_compute_forward_gelu_erf_f32(
|
||||
@@ -4584,6 +4455,46 @@ void ggml_compute_forward_cont(
|
||||
ggml_compute_forward_dup(params, dst);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_reshape
|
||||
|
||||
void ggml_compute_forward_reshape(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
// NOP
|
||||
GGML_UNUSED(params);
|
||||
GGML_UNUSED(dst);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_view
|
||||
|
||||
void ggml_compute_forward_view(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
// NOP
|
||||
GGML_UNUSED(params);
|
||||
GGML_UNUSED(dst);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_permute
|
||||
|
||||
void ggml_compute_forward_permute(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
// NOP
|
||||
GGML_UNUSED(params);
|
||||
GGML_UNUSED(dst);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_transpose
|
||||
|
||||
void ggml_compute_forward_transpose(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
// NOP
|
||||
GGML_UNUSED(params);
|
||||
GGML_UNUSED(dst);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_get_rows
|
||||
|
||||
static void ggml_compute_forward_get_rows_q(
|
||||
@@ -5632,28 +5543,7 @@ static void ggml_mrope_cache_init(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
|
||||
for (int64_t i0 = 0; i0 < n; i0 += 2) {
|
||||
const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
|
||||
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
const T * const src = src_data + ic;
|
||||
T * dst = dst_data + ic;
|
||||
|
||||
const float x0 = type_conversion_table<T>::to_f32(src[0]);
|
||||
const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
|
||||
|
||||
dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
|
||||
dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T> //float or ggml_fp16_t
|
||||
static void ggml_compute_forward_rope_flt(
|
||||
static void ggml_compute_forward_rope_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst,
|
||||
const bool forward) {
|
||||
@@ -5662,9 +5552,6 @@ static void ggml_compute_forward_rope_flt(
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const ggml_tensor * src2 = dst->src[2];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
||||
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
int sections[4];
|
||||
|
||||
@@ -5687,8 +5574,7 @@ static void ggml_compute_forward_rope_flt(
|
||||
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
||||
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
||||
|
||||
GGML_ASSERT(nb0 == nb00);
|
||||
GGML_ASSERT(nb0 == sizeof(T));
|
||||
GGML_ASSERT(nb00 == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
@@ -5713,11 +5599,12 @@ static void ggml_compute_forward_rope_flt(
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
|
||||
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
|
||||
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (mrope_used) {
|
||||
if (is_mrope) {
|
||||
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
||||
}
|
||||
|
||||
@@ -5743,7 +5630,7 @@ static void ggml_compute_forward_rope_flt(
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
||||
|
||||
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
||||
if (!mrope_used) {
|
||||
if (!is_mrope) {
|
||||
const int64_t p = pos[i2];
|
||||
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
@@ -5761,36 +5648,269 @@ static void ggml_compute_forward_rope_flt(
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir > ir1) break;
|
||||
|
||||
T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
|
||||
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
|
||||
if (is_neox || is_mrope) {
|
||||
if (is_vision){
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
switch (mode) {
|
||||
case GGML_ROPE_TYPE_NORMAL:
|
||||
rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
|
||||
break;
|
||||
case GGML_ROPE_TYPE_NEOX:
|
||||
case GGML_ROPE_TYPE_MROPE:
|
||||
case GGML_ROPE_TYPE_IMROPE:
|
||||
rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
|
||||
break;
|
||||
case GGML_ROPE_TYPE_VISION:
|
||||
rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("rope type not supported");
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
} else {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[1];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_vision) {
|
||||
if (is_vision) {
|
||||
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
} else {
|
||||
// fill the remain channels with data from src tensor
|
||||
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
||||
const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
} //attn-heads
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: deduplicate f16/f32 code
|
||||
static void ggml_compute_forward_rope_f16(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst,
|
||||
const bool forward) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const ggml_tensor * src2 = dst->src[2];
|
||||
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
int sections[4];
|
||||
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
||||
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
||||
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
||||
|
||||
GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int nr = ggml_nrows(dst);
|
||||
|
||||
GGML_ASSERT(n_dims <= ne0);
|
||||
GGML_ASSERT(n_dims % 2 == 0);
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
// row index used to determine which thread to use
|
||||
int ir = 0;
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_mrope) {
|
||||
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
||||
}
|
||||
|
||||
if (is_vision) {
|
||||
GGML_ASSERT(n_dims == ne0/2);
|
||||
}
|
||||
|
||||
const float * freq_factors = NULL;
|
||||
if (src2 != NULL) {
|
||||
GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
|
||||
// backward process uses inverse rotation by cos and sin.
|
||||
// cos and sin build a rotation matrix, where the inverse is the transpose.
|
||||
// this essentially just switches the sign of sin.
|
||||
const float sin_sign = forward ? 1.0f : -1.0f;
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
||||
|
||||
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
||||
if (!is_mrope) {
|
||||
const int64_t p = pos[i2];
|
||||
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
else {
|
||||
const int64_t p_t = pos[i2];
|
||||
const int64_t p_h = pos[i2 + ne2];
|
||||
const int64_t p_w = pos[i2 + ne2 * 2];
|
||||
const int64_t p_e = pos[i2 + ne2 * 3];
|
||||
ggml_mrope_cache_init(
|
||||
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
|
||||
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir > ir1) break;
|
||||
|
||||
if (is_neox || is_mrope) {
|
||||
if (is_vision) {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
||||
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
|
||||
|
||||
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
||||
dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
} else {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
||||
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
|
||||
|
||||
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
||||
dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
||||
const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
|
||||
|
||||
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
||||
dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
}
|
||||
|
||||
if (is_vision) {
|
||||
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
||||
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
|
||||
|
||||
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
||||
dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
} else {
|
||||
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5804,11 +5924,11 @@ void ggml_compute_forward_rope(
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
|
||||
ggml_compute_forward_rope_f16(params, dst, true);
|
||||
} break;
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_rope_flt<float>(params, dst, true);
|
||||
ggml_compute_forward_rope_f32(params, dst, true);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
@@ -5828,11 +5948,11 @@ void ggml_compute_forward_rope_back(
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
|
||||
ggml_compute_forward_rope_f16(params, dst, false);
|
||||
} break;
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_rope_flt<float>(params, dst, false);
|
||||
ggml_compute_forward_rope_f32(params, dst, false);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
@@ -7793,18 +7913,6 @@ void ggml_compute_forward_timestep_embedding(
|
||||
|
||||
// ggml_compute_forward_argsort
|
||||
|
||||
template<enum ggml_sort_order order>
|
||||
struct argsort_cmp {
|
||||
const float * data;
|
||||
bool operator()(int32_t a, int32_t b) const {
|
||||
if constexpr (order == GGML_SORT_ORDER_ASC) {
|
||||
return data[a] < data[b];
|
||||
} else {
|
||||
return data[a] > data[b];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static void ggml_compute_forward_argsort_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
@@ -7823,25 +7931,23 @@ static void ggml_compute_forward_argsort_f32(
|
||||
ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
for (int64_t i = ith; i < nr; i += nth) {
|
||||
const float * src_data = (float *)((char *) src0->data + i*nb01);
|
||||
|
||||
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
||||
const float * src_data = (float *)((char *) src0->data + i*nb01);
|
||||
|
||||
for (int64_t j = 0; j < ne0; j++) {
|
||||
dst_data[j] = j;
|
||||
}
|
||||
|
||||
switch (order) {
|
||||
case GGML_SORT_ORDER_ASC:
|
||||
std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
|
||||
break;
|
||||
|
||||
case GGML_SORT_ORDER_DESC:
|
||||
std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
|
||||
break;
|
||||
|
||||
default:
|
||||
GGML_ABORT("invalid sort order");
|
||||
// C doesn't have a functional sort, so we do a bubble sort instead
|
||||
for (int64_t j = 0; j < ne0; j++) {
|
||||
for (int64_t k = j + 1; k < ne0; k++) {
|
||||
if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
|
||||
(order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
|
||||
int32_t tmp = dst_data[j];
|
||||
dst_data[j] = dst_data[k];
|
||||
dst_data[k] = tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -8664,7 +8770,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
// n_head
|
||||
for (int h = ih0; h < ih1; ++h) {
|
||||
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
||||
const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
|
||||
const float dt_soft_plus = ggml_softplus(dt[h]);
|
||||
const float dA = expf(dt_soft_plus * A[h]);
|
||||
const int g = h / (nh / ng); // repeat_interleave
|
||||
|
||||
@@ -8761,7 +8867,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
// n_head
|
||||
for (int h = ih0; h < ih1; ++h) {
|
||||
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
||||
const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
|
||||
const float dt_soft_plus = ggml_softplus(dt[h]);
|
||||
const int g = h / (nh / ng); // repeat_interleave
|
||||
|
||||
// dim
|
||||
@@ -9044,14 +9150,6 @@ void ggml_compute_forward_unary(
|
||||
{
|
||||
ggml_compute_forward_xielu(params, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
{
|
||||
ggml_compute_forward_expm1(params, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
{
|
||||
ggml_compute_forward_softplus(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -9648,75 +9746,6 @@ void ggml_compute_forward_gla(
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
const struct ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
|
||||
const struct ggml_tensor * src1 = dst->src[1]; // B (RHS)
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS;
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ne00 == ne01); // A must be square
|
||||
GGML_ASSERT(ne0 == ne10); // solution cols == B cols
|
||||
GGML_ASSERT(ne1 == ne11); // solution rows == B rows
|
||||
|
||||
GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
|
||||
GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int64_t k = ne10; // number of RHS columns
|
||||
const int64_t n = ne11; // A is n×n
|
||||
const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
|
||||
|
||||
// chunks per thread
|
||||
const int64_t dr = (nr + nth - 1)/nth;
|
||||
|
||||
// chunk range for this thread
|
||||
const int64_t ir0 = dr*ith;
|
||||
const int64_t ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
const float * A = (const float *) src0->data; // [n, n, B1, B2]
|
||||
const float * B = (const float *) src1->data; // [n, k, B1, B2]
|
||||
float * X = ( float *) dst->data; // [n, k, B1, B2]
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne02*k);
|
||||
const int64_t i02 = (ir - i03*ne02*k)/k;
|
||||
const int64_t i01 = (ir - i03*ne02*k - i02*k);
|
||||
|
||||
const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
|
||||
const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
|
||||
|
||||
float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
|
||||
|
||||
for (int64_t i00 = 0; i00 < n; ++i00) {
|
||||
float sum = 0.0f;
|
||||
for (int64_t t = 0; t < i00; ++t) {
|
||||
sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
|
||||
}
|
||||
|
||||
const float diag = A_batch[i00 * n + i00];
|
||||
GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
|
||||
X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_compute_forward_solve_tri_f32(params, dst);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_rwkv_wkv7
|
||||
|
||||
static void ggml_compute_forward_rwkv_wkv7_f32(
|
||||
|
||||
@@ -34,7 +34,6 @@ void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct
|
||||
void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_sum_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_cumsum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_mean(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_argmax(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_count_equal(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
@@ -52,6 +51,10 @@ void ggml_compute_forward_scale(const struct ggml_compute_params * params, struc
|
||||
void ggml_compute_forward_set(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_cpy(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_cont(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_reshape(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_view(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_permute(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
@@ -82,8 +85,6 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru
|
||||
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_flash_attn_back(
|
||||
const struct ggml_compute_params * params,
|
||||
@@ -99,7 +100,6 @@ void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params,
|
||||
void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
||||
@@ -1600,52 +1600,29 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
return false;
|
||||
}
|
||||
|
||||
void forward_mul_mat_one_chunk(ggml_compute_params * params,
|
||||
ggml_tensor * op,
|
||||
int64_t src0_start,
|
||||
int64_t src0_end,
|
||||
int64_t src1_start,
|
||||
int64_t src1_end) {
|
||||
void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end) {
|
||||
const ggml_tensor * src0 = op->src[0];
|
||||
const ggml_tensor * src1 = op->src[1];
|
||||
ggml_tensor * dst = op;
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const void * src1_wdata = params->wdata;
|
||||
const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
|
||||
|
||||
GGML_ASSERT(ne03 == 1 && ne13 == 1);
|
||||
GGML_ASSERT(ne12 % ne02 == 0);
|
||||
const int64_t r2 = ne12 / ne02;
|
||||
|
||||
const int64_t i12 = src1_start / ne1;
|
||||
const int64_t i11 = src1_start - i12 * ne1;
|
||||
|
||||
// Determine batch index
|
||||
const int64_t i02 = i12 / r2;
|
||||
|
||||
const int64_t i1 = i11;
|
||||
const int64_t i2 = i12;
|
||||
|
||||
const char * src0_ptr = (const char *) src0->data + i02 * nb02;
|
||||
const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride;
|
||||
char * dst_ptr = ((char *) dst->data + (i1 * nb1 + i2 * nb2));
|
||||
|
||||
const int64_t nrows = src1_end - src1_start;
|
||||
const int64_t ncols = src0_end - src0_start;
|
||||
|
||||
GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize);
|
||||
|
||||
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
|
||||
if (nrows > 3) {
|
||||
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,
|
||||
src0_ptr + src0_start * nb01, src1_ptr,
|
||||
nrows - (nrows % 4), ncols);
|
||||
if (ne11 > 3) {
|
||||
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
||||
(float *) ((char *) dst->data) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01,
|
||||
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
|
||||
}
|
||||
for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {
|
||||
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start,
|
||||
ne01, src0_ptr + src0_start * nb01,
|
||||
src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols);
|
||||
for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
|
||||
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
||||
(float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01,
|
||||
(const char *) src1_wdata + (src1_col_stride * iter), 1,
|
||||
src0_end - src0_start);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1670,12 +1647,6 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
// TODO: General batched mul mat for 4D tensors
|
||||
// Currently only supports 3D tensors
|
||||
GGML_ASSERT(ne03 == 1);
|
||||
GGML_ASSERT(ne13 == 1);
|
||||
GGML_ASSERT(ne3 == 1);
|
||||
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
|
||||
@@ -1683,64 +1654,47 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
|
||||
char * wdata = static_cast<char *>(params->wdata);
|
||||
const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
|
||||
const size_t nbw2 = nbw1 * ne11;
|
||||
|
||||
assert(params->wsize >= nbw2 * ne12);
|
||||
assert(params->wsize >= nbw1 * ne11);
|
||||
|
||||
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
|
||||
|
||||
// INFO: Quantization is done in planes to avoid extra complexity in chunking.
|
||||
// Flattening dimensions not multiple of INTER_SIZE would require extra handling depending on how
|
||||
// the planes are broadcast.
|
||||
for (int64_t i12 = 0; i12 < ne12; i12++) {
|
||||
char * data_ptr = (char *) src1->data + i12 * nb12;
|
||||
char * wdata_ptr = wdata + i12 * nbw2;
|
||||
int64_t i11_processed = 0;
|
||||
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
|
||||
ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
|
||||
}
|
||||
|
||||
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
|
||||
ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
|
||||
(void *) (wdata_ptr + i11 * nbw1), 4, ne10);
|
||||
}
|
||||
|
||||
const int64_t i11_processed = ne11 - ne11 % 4;
|
||||
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
|
||||
from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);
|
||||
}
|
||||
i11_processed = ne11 - ne11 % 4;
|
||||
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
|
||||
from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
|
||||
}
|
||||
|
||||
// disable for NUMA
|
||||
const bool disable_chunking = ggml_is_numa();
|
||||
|
||||
// 4x chunks per thread
|
||||
const int64_t nr0 = ggml_nrows(op->src[0]);
|
||||
|
||||
int nth_scaled = nth * 4;
|
||||
int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
|
||||
int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0;
|
||||
|
||||
// src1 is chunked only by full planes.
|
||||
// When we flatten we need to address dimensions not multiple of the q8 INTER_SIZE
|
||||
// to route them thorugh GEMV.
|
||||
// nchunk1 = ne12 also avoids messing the chunking for models with no 3d tensors
|
||||
// to avoid affecting their performance
|
||||
int64_t nchunk1 = ne12;
|
||||
int64_t nr = ggml_nrows(op->src[0]);
|
||||
int nth_scaled = nth * 4;
|
||||
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
||||
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
||||
|
||||
// Ensure minimum chunk size to avoid alignment issues with high thread counts
|
||||
// Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
|
||||
const int64_t min_chunk_size = NB_COLS;
|
||||
if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
|
||||
nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
|
||||
if (nchunk > 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) {
|
||||
nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
|
||||
}
|
||||
|
||||
if (nth == 1 || nchunk0 < nth || disable_chunking) {
|
||||
nchunk0 = nth;
|
||||
if (nth == 1 || nchunk < nth || disable_chunking) {
|
||||
nchunk = nth;
|
||||
}
|
||||
|
||||
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
|
||||
|
||||
// Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
|
||||
// This prevents creating too many tiny chunks that could overlap after alignment
|
||||
const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
|
||||
nchunk0 = MIN(nchunk0, max_nchunk);
|
||||
const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
|
||||
if (nchunk > max_nchunk) {
|
||||
nchunk = max_nchunk;
|
||||
}
|
||||
|
||||
if (ith == 0) {
|
||||
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
|
||||
@@ -1752,30 +1706,23 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
||||
int current_chunk = ith;
|
||||
|
||||
while (current_chunk < nchunk0 * nchunk1) {
|
||||
const int64_t ith0 = current_chunk % nchunk0;
|
||||
const int64_t ith1 = current_chunk / nchunk0;
|
||||
|
||||
int64_t src0_start = dr0 * ith0;
|
||||
int64_t src0_end = MIN(src0_start + dr0, nr0);
|
||||
|
||||
// full-plane range for src1
|
||||
int64_t src1_start = ith1 * ne11;
|
||||
int64_t src1_end = (ith1 + 1) * ne11;
|
||||
while (current_chunk < nchunk) {
|
||||
int64_t src0_start = (current_chunk * ne01) / nchunk;
|
||||
int64_t src0_end = ((current_chunk + 1) * ne01) / nchunk;
|
||||
|
||||
// Align boundaries to NB_COLS - round up to ensure all data is included
|
||||
// The chunk size limiting above ensures chunks are large enough to prevent overlaps
|
||||
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
|
||||
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
|
||||
src0_end = MIN(src0_end, ne01);
|
||||
|
||||
// Make sure current plane is the last one before exiting
|
||||
if (src0_start >= src0_end) {
|
||||
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
||||
continue;
|
||||
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
|
||||
if (src0_end > ne01) {
|
||||
src0_end = ne01;
|
||||
}
|
||||
|
||||
forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);
|
||||
if (src0_start >= src0_end) {
|
||||
break;
|
||||
}
|
||||
|
||||
forward_mul_mat_one_chunk(params, dst, src0_start, src0_end);
|
||||
|
||||
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
||||
}
|
||||
@@ -1961,11 +1908,6 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
return &q4_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
return &q4_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_Q2_K) {
|
||||
if (ggml_cpu_has_avx512()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
|
||||
@@ -160,18 +160,18 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
||||
#define GGML_F32xt svfloat32_t
|
||||
#define GGML_F32xt_ZERO svdup_n_f32(0.0f)
|
||||
#define GGML_F32xt_SET1(x) svdup_n_f32(x)
|
||||
#define GGML_F32xt_LOAD_IMPL(pg, a) svld1_f32(pg, a)
|
||||
#define GGML_F32xt_LOAD(a) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, a)
|
||||
#define GGML_F32xt_STORE_IMPL(pg, a, b) svst1_f32(pg, a, b)
|
||||
#define GGML_F32xt_STORE(a, b) GGML_F32xt_STORE_IMPL(DEFAULT_PG, a, b)
|
||||
#define GGML_F32xt_LOAD_IMPL(pg, a, ...) svld1_f32(pg, a)
|
||||
#define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
|
||||
#define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
|
||||
#define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
|
||||
#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
|
||||
#define GGML_F32xt_FMA(a, b, c) GGML_F32xt_FMA_IMPL(DEFAULT_PG, a, b, c)
|
||||
#define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
|
||||
#define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
|
||||
#define GGML_F32xt_ADD(a, b) GGML_F32xt_ADD_IMPL(DEFAULT_PG, a, b)
|
||||
#define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
|
||||
#define GGML_F32xt_MUL_IMPL(pg, a, b) svmul_f32_m(pg, a, b)
|
||||
#define GGML_F32xt_MUL(a, b) GGML_F32xt_MUL_IMPL(DEFAULT_PG, a, b)
|
||||
#define GGML_F32xt_MUL(...) GGML_F32xt_MUL_IMPL(DEFAULT_PG, __VA_ARGS__)
|
||||
#define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a)
|
||||
#define GGML_F32xt_REDUCE_ONE(a) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, a)
|
||||
#define GGML_F32xt_REDUCE_ONE(...) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, __VA_ARGS__)
|
||||
#define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
|
||||
{ \
|
||||
sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2); \
|
||||
@@ -183,8 +183,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
||||
sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5); \
|
||||
(res) = (ggml_float) GGML_F32xt_REDUCE_ONE(sum1); \
|
||||
}
|
||||
#define GGML_F32xt_REDUCE(res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
|
||||
GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8)
|
||||
#define GGML_F32xt_REDUCE(...) GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, __VA_ARGS__)
|
||||
|
||||
#define GGML_F32_VEC GGML_F32xt
|
||||
#define GGML_F32_VEC_ZERO GGML_F32xt_ZERO
|
||||
@@ -207,11 +206,11 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
||||
#define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec))
|
||||
|
||||
#define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a)
|
||||
#define GGML_F32Cxt_FMA(a, b, c) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, a, b, c)
|
||||
#define GGML_F32Cxt_FMA(...) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
||||
#define GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b)
|
||||
#define GGML_F32Cxt_ADD(a, b) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, a, b)
|
||||
#define GGML_F32Cxt_ADD(...) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
||||
#define GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b)
|
||||
#define GGML_F32Cxt_MUL(a, b) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, a, b)
|
||||
#define GGML_F32Cxt_MUL(...) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
||||
#define GGML_F32Cxt_REDUCE GGML_F16xt_REDUCE_MIXED
|
||||
|
||||
#define GGML_F16x_VEC GGML_F32Cxt
|
||||
@@ -225,7 +224,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
||||
#define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE
|
||||
|
||||
#define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a)
|
||||
#define GGML_F16xt_REDUCE_ONE(a) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, a)
|
||||
#define GGML_F16xt_REDUCE_ONE(...) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
||||
|
||||
#define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \
|
||||
{ \
|
||||
@@ -235,8 +234,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
||||
__fp16 sum_f16 = svaddv_f16(pg16, sum1); \
|
||||
(res) = (ggml_float) sum_f16; \
|
||||
}
|
||||
#define GGML_F16xt_REDUCE_MIXED(res, sum1, sum2, sum3, sum4) \
|
||||
GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, res, sum1, sum2, sum3, sum4)
|
||||
#define GGML_F16xt_REDUCE_MIXED(...) GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
||||
|
||||
// F16 NEON
|
||||
|
||||
|
||||
@@ -73,14 +73,6 @@ static inline float op_log(float x) {
|
||||
return logf(x);
|
||||
}
|
||||
|
||||
static inline float op_expm1(float x) {
|
||||
return expf(x) - 1.0f;
|
||||
}
|
||||
|
||||
static inline float op_softplus(float x) {
|
||||
return (x > 20.0f) ? x : logf(1.0f + expf(x));
|
||||
}
|
||||
|
||||
static inline float op_floor(float x) {
|
||||
return floorf(x);
|
||||
}
|
||||
@@ -298,14 +290,6 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor *
|
||||
unary_op<op_log>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_expm1(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_expm1>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_softplus(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_softplus>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_floor>(params, dst);
|
||||
}
|
||||
|
||||
@@ -22,8 +22,6 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
|
||||
void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_expm1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_softplus(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_floor(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
||||
@@ -360,13 +360,6 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
|
||||
for (; i + 3 < n; i += 4) {
|
||||
vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
|
||||
}
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
for (int vl; i < n; i += vl) {
|
||||
vl = __riscv_vsetvl_e32m2(n - i);
|
||||
vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
|
||||
vfloat32m2_t vy = ggml_v_silu_m2(vx, vl);
|
||||
__riscv_vse32_v_f32m2(&y[i], vy, vl);
|
||||
}
|
||||
#endif
|
||||
for (; i < n; ++i) {
|
||||
y[i] = ggml_silu_f32(x[i]);
|
||||
@@ -467,16 +460,6 @@ ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const floa
|
||||
val = vec_mul(val, val);
|
||||
sum += (ggml_float)vec_hsum_f32x4(val);
|
||||
}
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
|
||||
for (int vl; i < n; i += vl) {
|
||||
vl = __riscv_vsetvl_e32m2(n - i);
|
||||
vfloat32m2_t val = __riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], vl), mean, vl);
|
||||
__riscv_vse32_v_f32m2(&y[i], val, vl);
|
||||
val = __riscv_vfmul_vv_f32m2(val, val, vl);
|
||||
vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, vl);
|
||||
}
|
||||
sum = (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
|
||||
#endif
|
||||
for (; i < n; ++i) {
|
||||
float val = x[i] - mean;
|
||||
|
||||
@@ -698,61 +698,60 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
|
||||
}
|
||||
|
||||
inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
|
||||
#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
|
||||
const int sve_register_length = svcntb() * 8;
|
||||
const int ggml_f16_epr = sve_register_length / 16;
|
||||
const int ggml_f16_step = 2 * ggml_f16_epr;
|
||||
#if defined(GGML_SIMD)
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
const int sve_register_length = svcntb() * 8;
|
||||
const int ggml_f16_epr = sve_register_length / 16;
|
||||
const int ggml_f16_step = 2 * ggml_f16_epr;
|
||||
|
||||
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
|
||||
const int np = (n & ~(ggml_f16_step - 1));
|
||||
svfloat16_t ay1, ay2;
|
||||
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
|
||||
const int np = (n & ~(ggml_f16_step - 1));
|
||||
svfloat16_t ay1, ay2;
|
||||
|
||||
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||
ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0);
|
||||
ay1 = GGML_F16x_VEC_MUL(ay1, vx);
|
||||
GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0);
|
||||
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||
ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0);
|
||||
ay1 = GGML_F16x_VEC_MUL(ay1, vx);
|
||||
GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0);
|
||||
|
||||
ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1);
|
||||
ay2 = GGML_F16x_VEC_MUL(ay2, vx);
|
||||
GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1);
|
||||
}
|
||||
// leftovers
|
||||
// maximum number of leftover elements will be less that ggmlF_16x_epr. Apply predicated svmad on available elements only
|
||||
if (np < n) {
|
||||
svbool_t pg = svwhilelt_b16(np, n);
|
||||
svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));
|
||||
svfloat16_t out = svmul_f16_m(pg, hy, vx);
|
||||
svst1_f16(pg, (__fp16 *)(y + np), out);
|
||||
}
|
||||
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
|
||||
for (int i = 0, vl; i < n; i += vl) {
|
||||
vl = __riscv_vsetvl_e16m2(n - i);
|
||||
vfloat16m2_t vy = __riscv_vle16_v_f16m2((_Float16 *)&y[i], vl);
|
||||
vfloat32m4_t vy32 = __riscv_vfwcvt_f_f_v_f32m4(vy, vl);
|
||||
vy32 = __riscv_vfmul_vf_f32m4(vy32, v, vl);
|
||||
vy = __riscv_vfncvt_f_f_w_f16m2(vy32, vl);
|
||||
__riscv_vse16_v_f16m2((_Float16 *)&y[i], vy, vl);
|
||||
}
|
||||
#elif defined(GGML_SIMD)
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
|
||||
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
|
||||
|
||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
|
||||
|
||||
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
|
||||
ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1);
|
||||
ay2 = GGML_F16x_VEC_MUL(ay2, vx);
|
||||
GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1);
|
||||
}
|
||||
}
|
||||
// leftovers
|
||||
// maximum number of leftover elements will be less that ggmlF_16x_epr. Apply predicated svmad on available elements only
|
||||
if (np < n) {
|
||||
svbool_t pg = svwhilelt_b16(np, n);
|
||||
svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));
|
||||
svfloat16_t out = svmul_f16_m(pg, hy, vx);
|
||||
svst1_f16(pg, (__fp16 *)(y + np), out);
|
||||
}
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
// todo: RVV impl
|
||||
// scalar
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
|
||||
}
|
||||
#else
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
|
||||
}
|
||||
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
|
||||
|
||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
|
||||
|
||||
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
|
||||
}
|
||||
}
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
// scalar
|
||||
for (int i = 0; i < n; ++i) {
|
||||
@@ -1417,16 +1416,6 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
|
||||
#endif
|
||||
}
|
||||
|
||||
inline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (i == 0) {
|
||||
y[i] = x[i];
|
||||
} else {
|
||||
y[i] = y[i - 1] + x[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
|
||||
ggml_float sum = 0.0;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
|
||||
@@ -124,7 +124,6 @@ if (CUDAToolkit_FOUND)
|
||||
|
||||
if (GGML_CUDA_DEBUG)
|
||||
list(APPEND CUDA_FLAGS -lineinfo)
|
||||
add_compile_definitions(GGML_CUDA_DEBUG)
|
||||
endif()
|
||||
|
||||
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
|
||||
|
||||
@@ -224,10 +224,6 @@ static const char * cu_get_error_str(CUresult err) {
|
||||
#define AMD_MFMA_AVAILABLE
|
||||
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
|
||||
|
||||
#if defined(GGML_USE_HIP) && defined(RDNA4)
|
||||
#define AMD_WMMA_AVAILABLE
|
||||
#endif // defined(GGML_USE_HIP) && defined(RDNA4)
|
||||
|
||||
// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
|
||||
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#define VOLTA_MMA_AVAILABLE
|
||||
@@ -287,10 +283,6 @@ static bool amd_mfma_available(const int cc) {
|
||||
#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
|
||||
}
|
||||
|
||||
static bool amd_wmma_available(const int cc) {
|
||||
return GGML_CUDA_CC_IS_RDNA4(cc);
|
||||
}
|
||||
|
||||
static bool volta_mma_available(const int cc) {
|
||||
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
|
||||
}
|
||||
@@ -594,12 +586,6 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v,
|
||||
// If dst and src point at different address spaces then they are guaranteed to not be aliased.
|
||||
template <int nbytes, int alignment = 0>
|
||||
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
|
||||
static_assert(
|
||||
nbytes <= ggml_cuda_get_max_cpy_bytes() || alignment == 0,
|
||||
"You are misusing the alignment parameter for ggml_cuda_memcpy_1. "
|
||||
"The intent is for the parameter is only as a workaround if either one of the pointers is not properly aligned. "
|
||||
"If you use it to do more bytes per copy than ggml_cuda_max_cpy_bytes() the reads and writes may not be coalesced. "
|
||||
"Call ggml_cuda_memcpy_1 in a loop instead.");
|
||||
if constexpr (alignment != 0) {
|
||||
static_assert(nbytes % alignment == 0, "bad alignment");
|
||||
}
|
||||
|
||||
@@ -39,15 +39,6 @@ template<typename dst_t, typename src_t>
|
||||
return __float2bfloat16(float(x));
|
||||
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
|
||||
return __bfloat162float(x);
|
||||
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
|
||||
return __float22half2_rn(x);
|
||||
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
|
||||
// bypass compile error on cuda 12.0.1
|
||||
#ifdef GGML_USE_HIP
|
||||
return __float22bfloat162_rn(x);
|
||||
#else
|
||||
return {x.x, x.y};
|
||||
#endif // GGML_USE_HIP
|
||||
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
|
||||
return int32_t(x);
|
||||
} else {
|
||||
|
||||
@@ -212,6 +212,6 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static __device__ void cpy_1_scalar(const char * cxi, char * cdsti) {
|
||||
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
|
||||
*(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
|
||||
}
|
||||
|
||||
@@ -12,10 +12,10 @@ const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available
|
||||
const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows
|
||||
|
||||
template <cpy_kernel_t cpy_1>
|
||||
static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13) {
|
||||
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13) {
|
||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= ne) {
|
||||
@@ -40,7 +40,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne,
|
||||
static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13) {
|
||||
@@ -166,7 +166,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) {
|
||||
static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) {
|
||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= ne) {
|
||||
@@ -180,17 +180,17 @@ static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static void ggml_cpy_scalar_contiguous_cuda(
|
||||
static void ggml_cpy_flt_contiguous_cuda(
|
||||
const char * cx, char * cdst, const int64_t ne,
|
||||
cudaStream_t stream) {
|
||||
|
||||
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne);
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t, bool transposed = false>
|
||||
static void ggml_cpy_scalar_cuda(
|
||||
static void ggml_cpy_flt_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
@@ -198,7 +198,7 @@ static void ggml_cpy_scalar_cuda(
|
||||
if (transposed) {
|
||||
GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
|
||||
int ne00n, ne01n, ne02n;
|
||||
if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here
|
||||
if (nb00 < nb02) {
|
||||
ne00n = ne00;
|
||||
ne01n = ne01;
|
||||
ne02n = ne02;
|
||||
@@ -206,17 +206,19 @@ static void ggml_cpy_scalar_cuda(
|
||||
ne00n = ne00;
|
||||
ne01n = ne01*ne02;
|
||||
ne02n = 1;
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
|
||||
(ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
|
||||
(ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
|
||||
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
|
||||
cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
|
||||
cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
|
||||
(cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
} else {
|
||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
}
|
||||
@@ -384,8 +386,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
char * src1_ddc = (char *) src1->data;
|
||||
|
||||
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
|
||||
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
|
||||
src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);
|
||||
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && src0->ne[3] == 1;
|
||||
|
||||
if (src0->type == src1->type && contiguous_srcs) {
|
||||
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
|
||||
@@ -399,132 +400,94 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
if (can_be_transposed) {
|
||||
ggml_cpy_scalar_cuda<float, float, true>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else {
|
||||
ggml_cpy_scalar_cuda<float, float>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_scalar_contiguous_cuda<float, nv_bfloat16>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_scalar_cuda<float, nv_bfloat16>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_scalar_contiguous_cuda<float, half>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_flt_contiguous_cuda<float, half> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_scalar_cuda<float, half>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||
ggml_cpy_f32_q8_0_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q8_0_f32_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||
ggml_cpy_f32_q4_0_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q4_0_f32_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||
ggml_cpy_f32_q4_1_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q4_1_f32_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||
ggml_cpy_f32_q5_0_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_0_f32_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||
ggml_cpy_f32_iq4_nl_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||
ggml_cpy_f32_q5_1_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_1_f32_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
if (can_be_transposed) {
|
||||
ggml_cpy_scalar_cuda<half, half, true>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else {
|
||||
ggml_cpy_scalar_cuda<half, half>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_scalar_contiguous_cuda<half, nv_bfloat16>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_scalar_cuda<half, nv_bfloat16>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_scalar_contiguous_cuda<half, float>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_flt_contiguous_cuda<half, float> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_scalar_cuda<half, float>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
||||
if (can_be_transposed) {
|
||||
ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16, true>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else {
|
||||
ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, half>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_scalar_cuda<nv_bfloat16, half>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, float>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_scalar_cuda<nv_bfloat16, float>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
|
||||
if (can_be_transposed) {
|
||||
ggml_cpy_scalar_cuda<int32_t, int32_t, true>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else {
|
||||
ggml_cpy_scalar_cuda<int32_t, int32_t>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_scalar_contiguous_cuda<float, int32_t>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_flt_contiguous_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_scalar_cuda<float, int32_t>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_scalar_contiguous_cuda<int32_t, float>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_flt_contiguous_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_scalar_cuda<int32_t, float>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
#include "ggml-cuda/mmq.cuh"
|
||||
#include "ggml-cuda/mmvf.cuh"
|
||||
#include "ggml-cuda/mmvq.cuh"
|
||||
#include "ggml-cuda/moe-expert-reduce.cuh"
|
||||
#include "ggml-cuda/norm.cuh"
|
||||
#include "ggml-cuda/opt-step-adamw.cuh"
|
||||
#include "ggml-cuda/opt-step-sgd.cuh"
|
||||
@@ -2527,12 +2528,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
ggml_cuda_op_trunc(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
ggml_cuda_op_expm1(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
ggml_cuda_op_softplus(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -2998,40 +2993,6 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
||||
}
|
||||
#endif
|
||||
|
||||
static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
|
||||
const ggml_tensor * view,
|
||||
const ggml_tensor * set_rows) {
|
||||
|
||||
if (rope->op != GGML_OP_ROPE || view->op != GGML_OP_VIEW || set_rows->op != GGML_OP_SET_ROWS) {
|
||||
return false;
|
||||
}
|
||||
// ne3 not tested
|
||||
if (rope->src[0]->ne[3] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (set_rows->src[1]->type != GGML_TYPE_I64) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// The view should flatten two dims of rope into one dim
|
||||
if (!ggml_is_contiguous(view) || view->ne[0] != rope->ne[0] * rope->ne[1]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Only norm/neox shaders have the fusion code
|
||||
const int mode = ((const int32_t *) rope->op_params)[2];
|
||||
if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
|
||||
#ifndef NDEBUG
|
||||
const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
|
||||
@@ -3107,16 +3068,6 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
}
|
||||
}
|
||||
|
||||
if (ops.size() == 3 && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
|
||||
const ggml_tensor * rope = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * view = cgraph->nodes[node_idx + 1];
|
||||
const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];
|
||||
|
||||
if (ggml_cuda_should_fuse_rope_set_rows(rope, view, set_rows)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
|
||||
return false;
|
||||
}
|
||||
@@ -3201,6 +3152,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
|
||||
#ifdef GGML_CUDA_DEBUG
|
||||
const int nodes_fused = i - prev_i - 1;
|
||||
prev_i = i;
|
||||
@@ -3246,13 +3199,29 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
|
||||
ggml_tensor * rope = cgraph->nodes[i];
|
||||
ggml_tensor * set_rows = cgraph->nodes[i + 2];
|
||||
if (node->op == GGML_OP_MUL) {
|
||||
int current_node = i + 1;
|
||||
int num_views = 0;
|
||||
int num_adds = 0;
|
||||
while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_VIEW) {
|
||||
num_views++;
|
||||
current_node++;
|
||||
}
|
||||
|
||||
ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
|
||||
i += 2;
|
||||
continue;
|
||||
while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_ADD &&
|
||||
num_adds < num_views - 1) {
|
||||
num_adds++;
|
||||
current_node++;
|
||||
}
|
||||
|
||||
if (num_adds == num_views - 1 && num_views > 0) {
|
||||
ggml_tensor * dst_node = cgraph->nodes[current_node - 1];
|
||||
if (ggml_cuda_should_use_moe_expert_reduce(cgraph, i, current_node)) {
|
||||
ggml_cuda_op_moe_expert_reduce(*cuda_ctx, node->src[0], node->src[1], dst_node);
|
||||
i += num_views + num_adds;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_ADD) {
|
||||
@@ -3333,13 +3302,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
continue;
|
||||
}
|
||||
|
||||
// we don't support repeating adds
|
||||
if (bias_op == GGML_OP_ADD &&
|
||||
(!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) ||
|
||||
!ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const ggml_tensor * src0 = up_n->src[0];
|
||||
const ggml_tensor * src1 = up_n->src[1];
|
||||
const ggml_tensor * ids = up_n->src[2];
|
||||
@@ -3449,10 +3411,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
continue;
|
||||
}
|
||||
|
||||
if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.x_bias = bias_tensor;
|
||||
|
||||
@@ -3748,110 +3706,10 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
|
||||
return ctx->description.c_str();
|
||||
}
|
||||
|
||||
#if defined(__linux__)
|
||||
// Helper function to get available memory from /proc/meminfo for UMA systems
|
||||
static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_kb, long * free_swap_kb) {
|
||||
FILE * meminfo_file = nullptr;
|
||||
// 2KB buffer for reading /proc/meminfo since it does not report size info, should be enough
|
||||
const size_t BUFFER_SIZE = 2048;
|
||||
auto file_buffer = std::make_unique<char[]>(BUFFER_SIZE);
|
||||
size_t bytes_read = 0;
|
||||
long huge_tlb_total_pages = -1;
|
||||
long huge_tlb_free_pages = -1;
|
||||
long huge_tlb_page_size = -1;
|
||||
|
||||
if (available_memory_kb == nullptr || free_swap_kb == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
meminfo_file = fopen("/proc/meminfo", "r");
|
||||
if (meminfo_file == nullptr) {
|
||||
GGML_LOG_ERROR("%s: failed to open /proc/meminfo\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Read file into buffer
|
||||
bytes_read = fread(file_buffer.get(), 1, BUFFER_SIZE - 1, meminfo_file);
|
||||
fclose(meminfo_file);
|
||||
|
||||
if (bytes_read == 0) {
|
||||
GGML_LOG_ERROR("%s: failed to read from /proc/meminfo\n", __func__);
|
||||
return false;
|
||||
}
|
||||
file_buffer[bytes_read] = '\0';
|
||||
|
||||
*available_memory_kb = -1;
|
||||
*free_swap_kb = -1;
|
||||
|
||||
// Parse the file buffer line by line
|
||||
char * line = file_buffer.get();
|
||||
char * line_next;
|
||||
while (line < file_buffer.get() + bytes_read) {
|
||||
// Find the end of the current line
|
||||
line_next = strchr(line, '\n');
|
||||
if (line_next != nullptr) {
|
||||
*line_next = '\0';
|
||||
line_next++;
|
||||
} else {
|
||||
line_next = file_buffer.get() + bytes_read;
|
||||
}
|
||||
|
||||
long value;
|
||||
if (sscanf(line, "MemAvailable: %ld kB", &value) == 1) {
|
||||
*available_memory_kb = value;
|
||||
} else if (sscanf(line, "SwapFree: %ld kB", &value) == 1) {
|
||||
*free_swap_kb = value;
|
||||
} else if (sscanf(line, "HugePages_Total: %ld", &value) == 1) {
|
||||
huge_tlb_total_pages = value;
|
||||
} else if (sscanf(line, "HugePages_Free: %ld", &value) == 1) {
|
||||
huge_tlb_free_pages = value;
|
||||
} else if (sscanf(line, "Hugepagesize: %ld kB", &value) == 1) {
|
||||
huge_tlb_page_size = value;
|
||||
}
|
||||
|
||||
line = line_next;
|
||||
}
|
||||
|
||||
if (huge_tlb_total_pages != 0 && huge_tlb_total_pages != -1) {
|
||||
*available_memory_kb = huge_tlb_free_pages * huge_tlb_page_size;
|
||||
|
||||
// Hugetlbfs pages are not swappable.
|
||||
*free_swap_kb = 0;
|
||||
}
|
||||
|
||||
GGML_LOG_DEBUG("%s: final available_memory_kb: %ld\n", __func__, *available_memory_kb);
|
||||
return true;
|
||||
}
|
||||
#endif // defined(__linux__)
|
||||
|
||||
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||
ggml_cuda_set_device(ctx->device);
|
||||
CUDA_CHECK(cudaMemGetInfo(free, total));
|
||||
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/17368
|
||||
#if defined(__linux__)
|
||||
// Check if this is a UMA (Unified Memory Architecture) system
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device));
|
||||
|
||||
// Check if UMA is explicitly enabled via environment variable
|
||||
bool uma_env = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr;
|
||||
bool is_uma = prop.unifiedAddressing > 0 || uma_env;
|
||||
|
||||
if (is_uma) {
|
||||
// For UMA systems (like DGX Spark), use system memory info
|
||||
long available_memory_kb = 0;
|
||||
long free_swap_kb = 0;
|
||||
|
||||
if (ggml_backend_cuda_get_available_uma_memory(&available_memory_kb, &free_swap_kb) && available_memory_kb > 0) {
|
||||
*free = (size_t)available_memory_kb * 1024;
|
||||
} else {
|
||||
GGML_LOG_ERROR("%s: /proc/meminfo reading failed, using cudaMemGetInfo\n", __func__);
|
||||
}
|
||||
}
|
||||
#endif // defined(__linux__)
|
||||
|
||||
}
|
||||
|
||||
static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
|
||||
@@ -3939,8 +3797,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
@@ -4115,9 +3971,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_I32) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -74,33 +74,6 @@ namespace ggml_cuda_mma {
|
||||
static constexpr int J = J_;
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
#if defined(RDNA4)
|
||||
static constexpr int ne = I * J / 32;
|
||||
T x[ne] = {0};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 16 && J == 16) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 16 && J == 16) {
|
||||
return 8 * (threadIdx.x / 16) + l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 16 && J == 16) {
|
||||
return threadIdx.x % 16;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr int ne = I * J / 64;
|
||||
T x[ne] = {0};
|
||||
|
||||
@@ -146,7 +119,6 @@ namespace ggml_cuda_mma {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#endif // defined(RDNA4)
|
||||
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
static constexpr int ne = I * J / 32;
|
||||
T x[ne] = {0};
|
||||
@@ -264,32 +236,6 @@ namespace ggml_cuda_mma {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
static constexpr int ne = I * J / 32;
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 16 && J == 8) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return threadIdx.x % 16;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return 4 * (threadIdx.x / 16) + l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
@@ -339,34 +285,6 @@ namespace ggml_cuda_mma {
|
||||
struct tile<I_, J_, nv_bfloat162> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
static constexpr int ne = I * J / 32;
|
||||
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 16 && J == 8) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return threadIdx.x % 16;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return 4 * (threadIdx.x / 16) + l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
@@ -402,7 +320,6 @@ namespace ggml_cuda_mma {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
};
|
||||
|
||||
template <int I, int J>
|
||||
@@ -436,8 +353,6 @@ namespace ggml_cuda_mma {
|
||||
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
|
||||
xi[0] = xs[0];
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int l = 0; l < t.ne; ++l) {
|
||||
@@ -724,34 +639,12 @@ namespace ggml_cuda_mma {
|
||||
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
|
||||
using floatx8_t = __attribute__((ext_vector_type(8))) float;
|
||||
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
|
||||
const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
|
||||
const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
|
||||
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // TURING_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
|
||||
using floatx8_t = __attribute__((ext_vector_type(8))) float;
|
||||
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
|
||||
const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
|
||||
const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
|
||||
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // AMPERE_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
|
||||
#if defined(AMD_MFMA_AVAILABLE)
|
||||
|
||||
@@ -129,13 +129,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||
if (src0_ne[0] % (warp_size * (4/ts)) != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src0_nb[0] != ts) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:
|
||||
for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
|
||||
for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
|
||||
if (src0_nb[i] % (2*ts) != 0) {
|
||||
return false;
|
||||
}
|
||||
@@ -151,7 +145,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
|
||||
if (src1_ncols > 16) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -160,9 +154,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||
case GGML_TYPE_F32:
|
||||
return ampere_mma_available(cc);
|
||||
case GGML_TYPE_F16:
|
||||
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
|
||||
return volta_mma_available(cc) || turing_mma_available(cc);
|
||||
case GGML_TYPE_BF16:
|
||||
return ampere_mma_available(cc) || amd_wmma_available(cc);
|
||||
return ampere_mma_available(cc);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
#include "mma.cuh"
|
||||
#include "common.cuh"
|
||||
#include "convert.cuh"
|
||||
|
||||
using namespace ggml_cuda_mma;
|
||||
|
||||
@@ -28,35 +27,20 @@ static __global__ void mul_mat_f(
|
||||
const int stride_col_id, const int stride_row_id,
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
|
||||
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
|
||||
constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
|
||||
constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<tile_B_I, 8, T> tile_B;
|
||||
typedef tile<16, tile_C_J, float> tile_C;
|
||||
|
||||
constexpr bool a_supported = tile_A::supported();
|
||||
constexpr bool b_supported = tile_B::supported();
|
||||
constexpr bool c_supported = tile_C::supported();
|
||||
constexpr bool supported = a_supported && b_supported && c_supported;
|
||||
#else
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
||||
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
||||
constexpr bool supported = I_16_supported || I_32_supported;
|
||||
|
||||
if (!I_16_supported && !I_32_supported) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
|
||||
|
||||
typedef tile<I_preferred, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<I_preferred, 8, float> tile_C;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
if constexpr (!supported) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
@@ -177,11 +161,11 @@ static __global__ void mul_mat_f(
|
||||
|
||||
if constexpr (!has_ids) {
|
||||
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
} else {
|
||||
const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
|
||||
float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -255,7 +239,7 @@ static __global__ void mul_mat_f(
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
}
|
||||
|
||||
//This kernel is for larger batch sizes of mul_mat_id
|
||||
@@ -269,35 +253,20 @@ static __global__ void mul_mat_f_ids(
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const uint3 sis1_fd, const uint3 nch_fd) {
|
||||
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
|
||||
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
|
||||
constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
|
||||
constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<tile_B_I, 8, T> tile_B;
|
||||
typedef tile<16, tile_C_J, float> tile_C;
|
||||
|
||||
constexpr bool a_supported = tile_A::supported();
|
||||
constexpr bool b_supported = tile_B::supported();
|
||||
constexpr bool c_supported = tile_C::supported();
|
||||
constexpr bool supported = a_supported && b_supported && c_supported;
|
||||
#else
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
||||
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
||||
constexpr bool supported = I_16_supported || I_32_supported;
|
||||
|
||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
|
||||
if (!I_16_supported && !I_32_supported) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster.
|
||||
|
||||
typedef tile<I_preferred, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<I_preferred, 8, float> tile_C;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
if constexpr (!supported) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
@@ -439,7 +408,7 @@ static __global__ void mul_mat_f_ids(
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
const float2 tmp = vals_buf[curr_buf][j0];
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
}
|
||||
|
||||
if (itB + 1 < ntB) {
|
||||
@@ -523,7 +492,7 @@ static __global__ void mul_mat_f_ids(
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
}
|
||||
|
||||
template<typename T, int cols_per_block, int nwarps>
|
||||
@@ -585,8 +554,7 @@ void mul_mat_f_cuda(
|
||||
cudaStream_t stream, const mmf_ids_data * ids_data) {
|
||||
typedef tile<16, 8, T> tile_A_16;
|
||||
typedef tile<32, 8, T> tile_A_32;
|
||||
typedef tile<16, 8, T> tile_B_16;
|
||||
typedef tile< 8, 8, T> tile_B_8;
|
||||
typedef tile< 8, 8, T> tile_B;
|
||||
|
||||
GGML_ASSERT(ncols_x % 2 == 0);
|
||||
GGML_ASSERT(stride_row % 2 == 0);
|
||||
@@ -613,8 +581,7 @@ void mul_mat_f_cuda(
|
||||
|
||||
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
|
||||
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
|
||||
const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
|
||||
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
|
||||
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
|
||||
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
|
||||
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
|
||||
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
|
||||
|
||||
@@ -3494,7 +3494,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
||||
const int col_diff = col_high - col_low;
|
||||
|
||||
for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
|
||||
ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
|
||||
ids_dst_shared[j] = ids_dst[col_low + j];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
||||
@@ -720,19 +720,12 @@ bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0
|
||||
if (src0_ne[0] % 2 != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t ts = ggml_type_size(type);
|
||||
if (src0_nb[0] != ts) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:
|
||||
for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
|
||||
for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
|
||||
if (src0_nb[i] % (2*ts) != 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
switch (type) {
|
||||
case GGML_TYPE_F32:
|
||||
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
|
||||
|
||||
168
ggml/src/ggml-cuda/moe-expert-reduce.cu
Normal file
168
ggml/src/ggml-cuda/moe-expert-reduce.cu
Normal file
@@ -0,0 +1,168 @@
|
||||
#include "moe-expert-reduce.cuh"
|
||||
|
||||
// This kernel is a fusion of the expert weight reduce, common in MoE models
|
||||
|
||||
template <int n_expert_used_template>
|
||||
__global__ void moe_expert_reduce_cuda(const float * __restrict__ experts,
|
||||
const float * __restrict__ weights,
|
||||
float * __restrict__ dst,
|
||||
const int n_expert_used,
|
||||
const int n_cols) {
|
||||
const int row = blockIdx.x;
|
||||
const int col = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
if (col >= n_cols) {
|
||||
return;
|
||||
}
|
||||
|
||||
experts += row * n_cols * n_expert_used;
|
||||
weights += row * n_expert_used;
|
||||
dst += row * n_cols;
|
||||
|
||||
float acc = 0.f;
|
||||
if constexpr (n_expert_used_template == 0) {
|
||||
for (int expert = 0; expert < n_expert_used; ++expert) {
|
||||
ggml_cuda_mad(acc, experts[col], weights[expert]);
|
||||
experts += n_cols;
|
||||
}
|
||||
dst[col] = acc;
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < n_expert_used_template; ++i) {
|
||||
ggml_cuda_mad(acc, experts[col], weights[i]);
|
||||
experts += n_cols;
|
||||
}
|
||||
dst[col] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
static void launch_moe_expert_reduce(ggml_backend_cuda_context & ctx,
|
||||
const float * experts,
|
||||
const float * weights,
|
||||
float * dst,
|
||||
const int n_expert_used,
|
||||
const int n_cols,
|
||||
const int n_rows) {
|
||||
const int block_size = 32;
|
||||
|
||||
const int n_blocks_x = n_rows;
|
||||
const int n_blocks_y = (n_cols + block_size - 1) / block_size;
|
||||
|
||||
dim3 block_dims(block_size);
|
||||
dim3 grid_dims(n_blocks_x, n_blocks_y);
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
switch (n_expert_used) {
|
||||
case 1:
|
||||
moe_expert_reduce_cuda<1>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 2:
|
||||
moe_expert_reduce_cuda<2>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 4:
|
||||
moe_expert_reduce_cuda<4>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 6:
|
||||
moe_expert_reduce_cuda<6>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 8:
|
||||
moe_expert_reduce_cuda<8>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 16:
|
||||
moe_expert_reduce_cuda<16>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 32:
|
||||
moe_expert_reduce_cuda<32>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 64:
|
||||
moe_expert_reduce_cuda<64>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 128:
|
||||
moe_expert_reduce_cuda<128>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
default:
|
||||
moe_expert_reduce_cuda<0>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index) {
|
||||
const ggml_tensor * mul = cgraph->nodes[start_index];
|
||||
|
||||
if (mul->op != GGML_OP_MUL || !ggml_is_contiguous(mul->src[0]) || !ggml_is_contiguous(mul->src[1])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int current_node = start_index + 1;
|
||||
size_t current_offset = 0;
|
||||
|
||||
std::vector<const ggml_tensor *> view_nodes;
|
||||
//check if all are views of the expert in increasing order
|
||||
while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_VIEW) {
|
||||
const ggml_tensor * node = cgraph->nodes[current_node];
|
||||
if (node->view_src != mul) {
|
||||
return false;
|
||||
}
|
||||
if (node->view_offs < current_offset) {
|
||||
return false;
|
||||
}
|
||||
current_offset = node->view_offs;
|
||||
current_node++;
|
||||
view_nodes.push_back(node);
|
||||
}
|
||||
|
||||
//check if all the adds are in increasing order
|
||||
const ggml_tensor * prev_add_src = view_nodes.empty() ? nullptr : view_nodes[0];
|
||||
int num_adds = 0;
|
||||
int num_views = view_nodes.size();
|
||||
while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_ADD) {
|
||||
const ggml_tensor * add_node = cgraph->nodes[current_node];
|
||||
|
||||
bool is_first_op_ok = num_views > num_adds ? add_node->src[0] == prev_add_src : false;
|
||||
bool is_second_op_ok = num_views > num_adds ? add_node->src[1] == view_nodes[num_adds + 1] : false;
|
||||
|
||||
if (!is_first_op_ok || !is_second_op_ok) {
|
||||
return false;
|
||||
}
|
||||
prev_add_src = add_node;
|
||||
|
||||
num_adds++;
|
||||
current_node++;
|
||||
}
|
||||
|
||||
if (num_views != num_adds + 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * experts,
|
||||
const ggml_tensor * weights,
|
||||
ggml_tensor * dst) {
|
||||
const int n_rows = experts->ne[2];
|
||||
const int n_expert_used = experts->ne[1];
|
||||
const int n_cols = experts->ne[0];
|
||||
|
||||
GGML_ASSERT(experts->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(weights->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(experts));
|
||||
GGML_ASSERT(ggml_is_contiguous(weights));
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const float * experts_d = (const float *) experts->data;
|
||||
const float * weights_d = (const float *) weights->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
launch_moe_expert_reduce(ctx, experts_d, weights_d, dst_d, n_expert_used, n_cols, n_rows);
|
||||
}
|
||||
11
ggml/src/ggml-cuda/moe-expert-reduce.cuh
Normal file
11
ggml/src/ggml-cuda/moe-expert-reduce.cuh
Normal file
@@ -0,0 +1,11 @@
|
||||
#include "common.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <initializer_list>
|
||||
|
||||
void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * experts,
|
||||
const ggml_tensor * weights,
|
||||
ggml_tensor * dst);
|
||||
|
||||
bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index);
|
||||
@@ -1,6 +1,3 @@
|
||||
#include "convert.cuh"
|
||||
#include "ggml-cuda/common.cuh"
|
||||
#include "ggml.h"
|
||||
#include "rope.cuh"
|
||||
|
||||
struct rope_corr_dims {
|
||||
@@ -40,23 +37,11 @@ static __device__ void rope_yarn(
|
||||
}
|
||||
}
|
||||
|
||||
template <bool forward, bool has_ff, typename T, typename D>
|
||||
static __global__ void rope_norm(const T * x,
|
||||
D * dst,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int n_dims,
|
||||
const int32_t * pos,
|
||||
const float freq_scale,
|
||||
const float ext_factor,
|
||||
const float attn_factor,
|
||||
const rope_corr_dims corr_dims,
|
||||
const float theta_scale,
|
||||
const float * freq_factors,
|
||||
const int64_t * row_indices,
|
||||
const int set_rows_stride) {
|
||||
template<bool forward, bool has_ff, typename T>
|
||||
static __global__ void rope_norm(
|
||||
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
|
||||
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
|
||||
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
|
||||
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
if (i0 >= ne0) {
|
||||
@@ -68,27 +53,13 @@ static __global__ void rope_norm(const T * x,
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
|
||||
int idst = row_dst * ne0 + i0;
|
||||
const int idst = row_dst*ne0 + i0;
|
||||
const int ix = channel_x*s2 + row_x*s1 + i0;
|
||||
|
||||
// Fusion optimization: ROPE + VIEW + SET_ROWS.
|
||||
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
|
||||
if (set_rows_stride != 0) {
|
||||
idst = row_x * ne0 + i0;
|
||||
idst += row_indices[channel_x] * set_rows_stride;
|
||||
}
|
||||
|
||||
const auto & store_coaelsced = [&](float x0, float x1) {
|
||||
if constexpr (std::is_same_v<float, D>) {
|
||||
float2 v = make_float2(x0, x1);
|
||||
ggml_cuda_memcpy_1<8>(dst + idst, &v);
|
||||
} else if constexpr (std::is_same_v<half, D>) {
|
||||
half2 v = make_half2(x0, x1);
|
||||
ggml_cuda_memcpy_1<4>(dst + idst, &v);
|
||||
}
|
||||
};
|
||||
if (i0 >= n_dims) {
|
||||
store_coaelsced(x[ix + 0], x[ix + 1]);
|
||||
dst[idst + 0] = x[ix + 0];
|
||||
dst[idst + 1] = x[ix + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -104,26 +75,15 @@ static __global__ void rope_norm(const T * x,
|
||||
const float x0 = x[ix + 0];
|
||||
const float x1 = x[ix + 1];
|
||||
|
||||
store_coaelsced(x0 * cos_theta - x1 * sin_theta, x0 * sin_theta + x1 * cos_theta);
|
||||
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
|
||||
dst[idst + 1] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
|
||||
template <bool forward, bool has_ff, typename T, typename D>
|
||||
static __global__ void rope_neox(const T * x,
|
||||
D * dst,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int n_dims,
|
||||
const int32_t * pos,
|
||||
const float freq_scale,
|
||||
const float ext_factor,
|
||||
const float attn_factor,
|
||||
const rope_corr_dims corr_dims,
|
||||
const float theta_scale,
|
||||
const float * freq_factors,
|
||||
const int64_t * row_indices,
|
||||
const int set_rows_stride) {
|
||||
template<bool forward, bool has_ff, typename T>
|
||||
static __global__ void rope_neox(
|
||||
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
|
||||
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
|
||||
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
|
||||
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
if (i0 >= ne0) {
|
||||
@@ -135,19 +95,12 @@ static __global__ void rope_neox(const T * x,
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
|
||||
int idst = row_dst * ne0 + i0 / 2;
|
||||
const int idst = row_dst*ne0 + i0/2;
|
||||
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
||||
|
||||
// Fusion optimization: ROPE + VIEW + SET_ROWS.
|
||||
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
|
||||
if (set_rows_stride != 0) {
|
||||
idst = row_x * ne0 + i0 / 2;
|
||||
idst += row_indices[channel_x] * set_rows_stride;
|
||||
}
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
dst[idst + i0 / 2 + 0] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 0]);
|
||||
dst[idst + i0 / 2 + 1] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 1]);
|
||||
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
|
||||
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
@@ -164,8 +117,8 @@ static __global__ void rope_neox(const T * x,
|
||||
const float x0 = x[ix + 0];
|
||||
const float x1 = x[ix + n_dims/2];
|
||||
|
||||
dst[idst + 0] = ggml_cuda_cast<D>(x0 * cos_theta - x1 * sin_theta);
|
||||
dst[idst + n_dims / 2] = ggml_cuda_cast<D>(x0 * sin_theta + x1 * cos_theta);
|
||||
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
|
||||
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
|
||||
template<bool forward, bool has_ff, typename T>
|
||||
@@ -285,25 +238,11 @@ static __global__ void rope_vision(
|
||||
dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
|
||||
template <bool forward, typename T, typename D>
|
||||
static void rope_norm_cuda(const T * x,
|
||||
D * dst,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int n_dims,
|
||||
const int nr,
|
||||
const int32_t * pos,
|
||||
const float freq_scale,
|
||||
const float freq_base,
|
||||
const float ext_factor,
|
||||
const float attn_factor,
|
||||
const rope_corr_dims corr_dims,
|
||||
const float * freq_factors,
|
||||
const int64_t * row_indices,
|
||||
const int set_rows_stride,
|
||||
cudaStream_t stream) {
|
||||
template<bool forward, typename T>
|
||||
static void rope_norm_cuda(
|
||||
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
|
||||
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
|
||||
const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
@@ -313,34 +252,20 @@ static void rope_norm_cuda(const T * x,
|
||||
|
||||
if (freq_factors == nullptr) {
|
||||
rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
|
||||
freq_factors, row_indices, set_rows_stride);
|
||||
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors);
|
||||
} else {
|
||||
rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
|
||||
freq_factors, row_indices, set_rows_stride);
|
||||
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool forward, typename T, typename D>
|
||||
static void rope_neox_cuda(const T * x,
|
||||
D * dst,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int n_dims,
|
||||
const int nr,
|
||||
const int32_t * pos,
|
||||
const float freq_scale,
|
||||
const float freq_base,
|
||||
const float ext_factor,
|
||||
const float attn_factor,
|
||||
const rope_corr_dims corr_dims,
|
||||
const float * freq_factors,
|
||||
const int64_t * row_indices,
|
||||
const int set_rows_stride,
|
||||
cudaStream_t stream) {
|
||||
template<bool forward, typename T>
|
||||
static void rope_neox_cuda(
|
||||
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
|
||||
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
|
||||
const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
@@ -349,13 +274,13 @@ static void rope_neox_cuda(const T * x,
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
|
||||
if (freq_factors == nullptr) {
|
||||
rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
|
||||
freq_factors, row_indices, set_rows_stride);
|
||||
rope_neox<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors);
|
||||
} else {
|
||||
rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
|
||||
freq_factors, row_indices, set_rows_stride);
|
||||
rope_neox<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -408,9 +333,7 @@ static void rope_vision_cuda(
|
||||
}
|
||||
|
||||
template <bool forward>
|
||||
void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
|
||||
ggml_tensor * dst,
|
||||
const ggml_tensor * set_rows = nullptr) {
|
||||
void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const ggml_tensor * src2 = dst->src[2];
|
||||
@@ -418,25 +341,12 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
const float * src1_d = (const float *)src1->data;
|
||||
|
||||
void * dst_d = dst->data;
|
||||
const int64_t * row_indices = nullptr;
|
||||
ggml_type dst_type = dst->type;
|
||||
int set_rows_stride = 0;
|
||||
|
||||
if (set_rows != nullptr) {
|
||||
GGML_ASSERT(forward);
|
||||
dst_d = set_rows->data;
|
||||
row_indices = (const int64_t *) set_rows->src[1]->data;
|
||||
dst_type = set_rows->type;
|
||||
set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
|
||||
}
|
||||
float * dst_d = (float *)dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
// When not fused, src0 and dst types must match
|
||||
// When fused (ROPE+VIEW+SET_ROWS), src0 may be F32 and dst may be F16
|
||||
GGML_ASSERT(src0->type == dst->type || (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));
|
||||
GGML_ASSERT(src0->type == dst->type);
|
||||
|
||||
const int64_t ne00 = src0->ne[0]; // head dims
|
||||
const int64_t ne01 = src0->ne[1]; // num heads
|
||||
@@ -494,18 +404,14 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
|
||||
|
||||
// compute
|
||||
if (is_neox) {
|
||||
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
|
||||
rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
|
||||
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, row_indices, set_rows_stride, stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
|
||||
rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
|
||||
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, row_indices, set_rows_stride, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
|
||||
rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
|
||||
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, row_indices, set_rows_stride, stream);
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_neox_cuda<forward>(
|
||||
(const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_neox_cuda<forward>(
|
||||
(const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
@@ -534,18 +440,14 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
} else {
|
||||
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
|
||||
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
|
||||
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, row_indices, set_rows_stride, stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
|
||||
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
|
||||
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, row_indices, set_rows_stride, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
|
||||
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
|
||||
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, row_indices, set_rows_stride, stream);
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_norm_cuda<forward>(
|
||||
(const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_norm_cuda<forward>(
|
||||
(const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
@@ -559,7 +461,3 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_rope_impl<false>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * rope, ggml_tensor * set_rows) {
|
||||
ggml_cuda_op_rope_impl<true>(ctx, rope, set_rows);
|
||||
}
|
||||
|
||||
@@ -5,5 +5,3 @@
|
||||
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows);
|
||||
|
||||
@@ -81,14 +81,6 @@ static __device__ __forceinline__ float op_log(float x) {
|
||||
return logf(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_expm1(float x) {
|
||||
return expm1f(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_softplus(float x) {
|
||||
return (x > 20.0f) ? x : logf(1.0f + expf(x));
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_elu(float x) {
|
||||
return (x > 0.f) ? x : expm1f(x);
|
||||
}
|
||||
@@ -241,14 +233,6 @@ void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_trunc>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_expm1>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_softplus>(ctx, dst);
|
||||
}
|
||||
/* gated ops */
|
||||
|
||||
template <float (*op)(float), typename T>
|
||||
|
||||
@@ -61,10 +61,6 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user