Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2026-02-05 13:53:23 +02:00)

Compare commits

117 Commits
| SHA1 |
|---|
| 803dac2e48 |
| 459c0c2c1a |
| be79d9fdd9 |
| f432d8d83e |
| 4067f07fc5 |
| 4b8560ab56 |
| 0dd58b6877 |
| 69ffd89163 |
| 246c0d9c79 |
| 3edd87cd05 |
| c0b45097c3 |
| 38dbdf4c05 |
| 368560a1e3 |
| 4ca088b036 |
| 703f9e32c4 |
| ad6bd9083b |
| 2b6b55a59f |
| e58174cecb |
| b213fce89b |
| e00f3fd8ff |
| f2f28380ea |
| 62c3b645c5 |
| d304f459d8 |
| 0320ac5264 |
| a7a98e0fff |
| 8f8f2274ee |
| c959b676be |
| cd08fc3ecc |
| cb5bb6cc05 |
| a91d035b90 |
| 745cbcf2fe |
| 1cbd80f8cf |
| 85286f3548 |
| d5fabe3682 |
| 8ff206097c |
| 77475530b8 |
| 3913f8730e |
| 76888d202e |
| f1fbffb5c0 |
| 51abc96bdc |
| 07808ebb07 |
| 6d758839ff |
| 3d4053f77f |
| dc381aa9a6 |
| 10d197409b |
| b907255f4b |
| 28c39da7c6 |
| 106220562a |
| a68f31edd7 |
| b8e09f08b9 |
| 6c019cb04e |
| 9dcd200d57 |
| 0fa154e350 |
| 261e6a20ff |
| a0e13dcbe5 |
| a14bd35014 |
| 918b26f197 |
| 9ecb884346 |
| d1c6f11f47 |
| 6380d6a3e7 |
| aa0c461efe |
| b9c9c9f789 |
| 50f4281a6f |
| 55758b00ca |
| f161463a54 |
| 84d7b2fca1 |
| 40be51152d |
| 4bf5549269 |
| f4e664f838 |
| f088b6a84f |
| 304ac5693d |
| 6c88ad8fa7 |
| 704d90c987 |
| 360d6533db |
| 0e6ff0046f |
| df082f5630 |
| 24a6734daf |
| 2b3efea9a4 |
| c0389dba43 |
| 00681dfc16 |
| 4f658855fa |
| 6ab397e12b |
| 9de447d94e |
| 0f0a3c2851 |
| 33daece86b |
| e7b6d83b52 |
| 2cfef4d117 |
| 09e72a037c |
| 10d8b2b6b0 |
| 28b5f190ef |
| 86587da03b |
| ff02caf9ee |
| ae355f6f71 |
| 4f63cd705c |
| 17bc5a815f |
| ed54e32558 |
| a972faebed |
| 550cf726e1 |
| c252ce67c4 |
| 70cd37dbbe |
| acc1b008cf |
| 7057faf64b |
| fe1c92cd7b |
| e68aa10d8f |
| 0a16bf52e6 |
| 88021565f0 |
| 56920f5665 |
| b0d52998b9 |
| f28d4f4ac9 |
| 9fcb29f22f |
| 5ef22d281d |
| 233d773d02 |
| a885dcff11 |
| 663027fd54 |
| cf0e3ba150 |
| d413dca003 |
| 85ca66a746 |
@@ -22,6 +22,13 @@ AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: Inline
AllowShortLoopsOnASingleLine: false
AlwaysBreakBeforeMultilineStrings: true
# Treat CUDA keywords/attributes as "attribute macros" and avoid breaking lines inside them
AttributeMacros:
  - __host__
  - __device__
  - __global__
  - __forceinline__
  - __launch_bounds__
BinPackArguments: true
BinPackParameters: false # OnePerLine
BitFieldColonSpacing: Both
@@ -4,7 +4,7 @@ ARG UBUNTU_VERSION=24.04
ARG ROCM_VERSION=6.4
ARG AMDGPU_VERSION=6.4

# Target the CUDA build image
# Target the ROCm build image
ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete

### Build image
@@ -15,16 +15,13 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build
# This is mostly tied to rocBLAS supported archs.
# gfx803, gfx900, gfx1032, gfx1101, gfx1102,not officialy supported
# gfx906 is deprecated
#check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.2.4/reference/system-requirements.html
#check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.1/reference/system-requirements.html

ARG ROCM_DOCKER_ARCH='gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102'
#ARG ROCM_DOCKER_ARCH=gfx1100
ARG ROCM_DOCKER_ARCH='gfx803;gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1010;gfx1030;gfx1032;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx1151'
#ARG ROCM_DOCKER_ARCH='gfx1151'

# Set nvcc architectured
# Set ROCm architectures
ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH}
# Enable ROCm
# ENV CC=/opt/rocm/llvm/bin/clang
# ENV CXX=/opt/rocm/llvm/bin/clang++

RUN apt-get update \
&& apt-get install -y \
@@ -39,8 +36,16 @@ WORKDIR /app

COPY . .

RUN git clone https://github.com/rocm/rocwmma --branch develop --depth 1

RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=$ROCM_DOCKER_ARCH -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DCMAKE_BUILD_TYPE=Release -DLLAMA_BUILD_TESTS=OFF \
cmake -S . -B build \
-DGGML_HIP=ON \
-DGGML_HIP_ROCWMMA_FATTN=ON \
-DCMAKE_HIP_FLAGS="-I$(pwd)/rocwmma/library/include/" \
-DAMDGPU_TARGETS="$ROCM_DOCKER_ARCH" \
-DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON \
-DCMAKE_BUILD_TYPE=Release -DLLAMA_BUILD_TESTS=OFF \
&& cmake --build build --config Release -j$(nproc)

RUN mkdir -p /app/lib \
@@ -52,3 +52,11 @@ insert_final_newline = unset
[vendor/miniaudio/miniaudio.h]
trim_trailing_whitespace = unset
insert_final_newline = unset

[tools/server/webui/**]
indent_style = unset
indent_size = unset
end_of_line = unset
charset = unset
trim_trailing_whitespace = unset
insert_final_newline = unset
.github/workflows/build.yml (vendored, 70 changed lines)

@@ -56,7 +56,7 @@ env:

jobs:
macOS-latest-cmake-arm64:
runs-on: macos-14
runs-on: macos-latest

steps:
- name: Clone
@@ -88,6 +88,7 @@ jobs:
-DGGML_METAL_SHADER_DEBUG=ON \
-DGGML_RPC=ON
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
leaks -atExit -- ./build/bin/test-thread-safety -hf ggml-org/gemma-3-270m-qat-GGUF -ngl 99 -p "$(printf 'hello %.0s' {1..128})" -n 16 -c 512 -ub 32 -np 2 -t 2 -lv 1

- name: Test
id: cmake_test
@@ -126,7 +127,8 @@ jobs:
-DCMAKE_BUILD_RPATH="@loader_path" \
-DLLAMA_FATAL_WARNINGS=ON \
-DGGML_METAL=OFF \
-DGGML_RPC=ON
-DGGML_RPC=ON \
-DCMAKE_OSX_DEPLOYMENT_TARGET=13.3
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)

- name: Test
@@ -136,7 +138,7 @@ jobs:
ctest -L main --verbose --timeout 900

macOS-latest-cmake-arm64-webgpu:
runs-on: macos-14
runs-on: macos-latest

steps:
- name: Clone
@@ -709,6 +711,7 @@ jobs:

macOS-latest-swift:
runs-on: macos-latest
needs: ios-xcode-build

strategy:
matrix:
@@ -725,6 +728,12 @@ jobs:
key: macOS-latest-swift
evict-old-files: 1d

- name: Download xcframework artifact
uses: actions/download-artifact@v4
with:
name: llama-xcframework
path: build-apple/llama.xcframework/

- name: Dependencies
id: depends
continue-on-error: true
@@ -746,11 +755,6 @@ jobs:
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64"
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)

- name: xcodebuild for swift package
id: xcodebuild
run: |
./build-xcframework.sh

windows-msys2:
runs-on: windows-2025

@@ -1050,9 +1054,13 @@ jobs:
run: examples/sycl/win-build-sycl.bat

windows-latest-cmake-hip:
if: ${{ github.event.inputs.create_release != 'true' }}
runs-on: windows-2022

env:
# The ROCm version must correspond to the version used in the HIP SDK.
ROCM_VERSION: "6.4.2"
HIPSDK_INSTALLER_VERSION: "25.Q3"

steps:
- name: Clone
id: checkout
@@ -1061,23 +1069,46 @@ jobs:
- name: Clone rocWMMA repository
id: clone_rocwmma
run: |
git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1
git clone https://github.com/rocm/rocwmma --branch rocm-${{ env.ROCM_VERSION }} --depth 1

- name: Install
- name: Cache ROCm Installation
id: cache-rocm
uses: actions/cache@v4
with:
path: C:\Program Files\AMD\ROCm
key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }}

- name: Install ROCm
if: steps.cache-rocm.outputs.cache-hit != 'true'
id: depends
run: |
$ErrorActionPreference = "Stop"
write-host "Downloading AMD HIP SDK Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ env.HIPSDK_INSTALLER_VERSION }}-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP SDK"
$proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru
$proc.WaitForExit(600000)
$completed = $proc.WaitForExit(600000)
if (-not $completed) {
Write-Error "ROCm installation timed out after 10 minutes. Killing the process"
$proc.Kill()
exit 1
}
if ($proc.ExitCode -ne 0) {
Write-Error "ROCm installation failed with exit code $($proc.ExitCode)"
exit 1
}
write-host "Completed AMD HIP SDK installation"

- name: Verify ROCm
id: verify
run: |
& 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
# Find and test ROCm installation
$clangPath = Get-ChildItem 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | Select-Object -First 1
if (-not $clangPath) {
Write-Error "ROCm installation not found"
exit 1
}
& $clangPath.FullName --version

- name: Install ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -1141,8 +1172,17 @@ jobs:
run: |
./build-xcframework.sh

- name: Upload xcframework artifact
uses: actions/upload-artifact@v4
with:
name: llama-xcframework
path: build-apple/llama.xcframework/
retention-days: 1

- name: Build Xcode project
run: xcodebuild -project examples/llama.swiftui/llama.swiftui.xcodeproj -scheme llama.swiftui -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build
run: |
xcodebuild -downloadPlatform iOS
xcodebuild -project examples/llama.swiftui/llama.swiftui.xcodeproj -scheme llama.swiftui -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build

android-build:
runs-on: ubuntu-latest
.github/workflows/release.yml (vendored, 46 changed lines)

@@ -108,7 +108,8 @@ jobs:
-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
-DLLAMA_FATAL_WARNINGS=ON \
-DGGML_METAL=OFF \
-DGGML_RPC=ON
-DGGML_RPC=ON \
-DCMAKE_OSX_DEPLOYMENT_TARGET=13.3
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)

- name: Determine tag name
@@ -528,11 +529,14 @@ jobs:
windows-hip:
runs-on: windows-2022

env:
HIPSDK_INSTALLER_VERSION: "25.Q3"

strategy:
matrix:
include:
- name: "radeon"
gpu_targets: "gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"
gpu_targets: "gfx1151;gfx1200;gfx1201;gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"

steps:
- name: Clone
@@ -542,29 +546,52 @@ jobs:
- name: Clone rocWMMA repository
id: clone_rocwmma
run: |
git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1
git clone https://github.com/rocm/rocwmma --branch develop --depth 1

- name: Cache ROCm Installation
id: cache-rocm
uses: actions/cache@v4
with:
path: C:\Program Files\AMD\ROCm
key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }}

- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: windows-latest-cmake-hip-${{ matrix.name }}-x64
key: windows-latest-cmake-hip-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ matrix.name }}-x64
evict-old-files: 1d

- name: Install
- name: Install ROCm
if: steps.cache-rocm.outputs.cache-hit != 'true'
id: depends
run: |
$ErrorActionPreference = "Stop"
write-host "Downloading AMD HIP SDK Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ env.HIPSDK_INSTALLER_VERSION }}-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP SDK"
$proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru
$proc.WaitForExit(600000)
$completed = $proc.WaitForExit(600000)
if (-not $completed) {
Write-Error "ROCm installation timed out after 10 minutes. Killing the process"
$proc.Kill()
exit 1
}
if ($proc.ExitCode -ne 0) {
Write-Error "ROCm installation failed with exit code $($proc.ExitCode)"
exit 1
}
write-host "Completed AMD HIP SDK installation"

- name: Verify ROCm
id: verify
run: |
& 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
# Find and test ROCm installation
$clangPath = Get-ChildItem 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | Select-Object -First 1
if (-not $clangPath) {
Write-Error "ROCm installation not found"
exit 1
}
& $clangPath.FullName --version

- name: Build
id: cmake_build
@@ -585,9 +612,12 @@ jobs:
-DLLAMA_CURL=OFF
cmake --build build --target ggml-hip -j ${env:NUMBER_OF_PROCESSORS}
md "build\bin\rocblas\library\"
md "build\bin\hipblaslt\library"
cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\"
cp "${env:HIP_PATH}\bin\hipblaslt.dll" "build\bin\"
cp "${env:HIP_PATH}\bin\rocblas.dll" "build\bin\"
cp "${env:HIP_PATH}\bin\rocblas\library\*" "build\bin\rocblas\library\"
cp "${env:HIP_PATH}\bin\hipblaslt\library\*" "build\bin\hipblaslt\library\"

- name: Pack artifacts
id: pack_artifacts
.github/workflows/server.yml (vendored, 229 changed lines)

@@ -76,51 +76,206 @@ jobs:
run: |
pip install -r tools/server/tests/requirements.txt

# Setup nodejs (to be used for verifying bundled index.html)
- uses: actions/setup-node@v4
webui-setup:
name: WebUI Setup
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
node-version: '22.11.0'
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}

- name: WebUI - Install dependencies
id: webui_lint
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: "22"
cache: "npm"
cache-dependency-path: "tools/server/webui/package-lock.json"

- name: Cache node_modules
uses: actions/cache@v4
id: cache-node-modules
with:
path: tools/server/webui/node_modules
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
restore-keys: |
${{ runner.os }}-node-modules-

- name: Install dependencies
if: steps.cache-node-modules.outputs.cache-hit != 'true'
run: npm ci
working-directory: tools/server/webui

webui-check:
needs: webui-setup
name: WebUI Check
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}

- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: "22"

- name: Restore node_modules cache
uses: actions/cache@v4
with:
path: tools/server/webui/node_modules
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
restore-keys: |
${{ runner.os }}-node-modules-

- name: Run type checking
run: npm run check
working-directory: tools/server/webui

- name: Run linting
run: npm run lint
working-directory: tools/server/webui

webui-build:
needs: webui-check
name: WebUI Build
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}

- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: "22"

- name: Restore node_modules cache
uses: actions/cache@v4
with:
path: tools/server/webui/node_modules
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
restore-keys: |
${{ runner.os }}-node-modules-

- name: Build application
run: npm run build
working-directory: tools/server/webui

webui-tests:
needs: webui-build
name: Run WebUI tests
permissions:
contents: read

runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: "22"

- name: Restore node_modules cache
uses: actions/cache@v4
with:
path: tools/server/webui/node_modules
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
restore-keys: |
${{ runner.os }}-node-modules-

- name: Install Playwright browsers
run: npx playwright install --with-deps
working-directory: tools/server/webui

- name: Build Storybook
run: npm run build-storybook
working-directory: tools/server/webui

- name: Run Client tests
run: npm run test:client
working-directory: tools/server/webui

- name: Run Server tests
run: npm run test:server
working-directory: tools/server/webui

- name: Run UI tests
run: npm run test:ui
working-directory: tools/server/webui

- name: Run E2E tests
run: npm run test:e2e
working-directory: tools/server/webui

server-build:
needs: [webui-tests]
runs-on: ubuntu-latest

strategy:
matrix:
sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken
build_type: [RelWithDebInfo]
include:
- build_type: Release
sanitizer: ""
fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken

steps:
- name: Dependencies
id: depends
run: |
cd tools/server/webui
npm ci
sudo apt-get update
sudo apt-get -y install \
build-essential \
xxd \
git \
cmake \
curl \
wget \
language-pack-en \
libcurl4-openssl-dev

- name: WebUI - Check code format
id: webui_format
- name: Clone
id: checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}

- name: Python setup
id: setup_python
uses: actions/setup-python@v5
with:
python-version: '3.11'

- name: Tests dependencies
id: test_dependencies
run: |
git config --global --add safe.directory $(realpath .)
cd tools/server/webui
git status
pip install -r tools/server/tests/requirements.txt

npm run format
git status
modified_files="$(git status -s)"
echo "Modified files: ${modified_files}"
if [ -n "${modified_files}" ]; then
echo "Files do not follow coding style. To fix: npm run format"
echo "${modified_files}"
exit 1
fi
- name: Setup Node.js for WebUI
uses: actions/setup-node@v4
with:
node-version: "22"
cache: "npm"
cache-dependency-path: "tools/server/webui/package-lock.json"

- name: Verify bundled index.html
id: verify_server_index_html
run: |
git config --global --add safe.directory $(realpath .)
cd tools/server/webui
git status
- name: Install WebUI dependencies
run: npm ci
working-directory: tools/server/webui

npm run build
git status
modified_files="$(git status -s)"
echo "Modified files: ${modified_files}"
if [ -n "${modified_files}" ]; then
echo "Repository is dirty or server/webui is not built as expected"
echo "Hint: You may need to follow Web UI build guide in server/README.md"
echo "${modified_files}"
exit 1
fi
- name: Build WebUI
run: npm run build
working-directory: tools/server/webui

- name: Build (no OpenMP)
id: cmake_build_no_openmp
.gitignore (vendored, 4 changed lines)

@@ -148,3 +148,7 @@ poetry.toml
/run-vim.sh
/run-chat.sh
.ccache/

# Code Workspace
*.code-workspace
.windsurf/rules/css-architecture.md (new file, 7 lines)

@@ -0,0 +1,7 @@
---
trigger: manual
---

#### Tailwind & CSS

- We are using Tailwind v4, which uses oklch colors, so we now refer to the CSS vars directly, without wrapping them in any color function like `hsla/hsl`, `rgba`, etc.
.windsurf/rules/sveltekit-architecture.md (new file, 48 lines)

@@ -0,0 +1,48 @@
---
trigger: manual
---

# Coding rules

## Svelte & SvelteKit

### Services vs Stores Separation Pattern

#### `lib/services/` - Pure Business Logic

- **Purpose**: Stateless business logic and external communication
- **Contains**:
  - API calls to external services (ApiService)
  - Pure business logic functions (ChatService, etc.)
- **Rules**:
  - NO Svelte runes ($state, $derived, $effect)
  - NO reactive state management
  - Pure functions and classes only
  - Can import types but not stores
  - Focus on "how" - implementation details

#### `lib/stores/` - Reactive State Management

- **Purpose**: Svelte-specific reactive state with runes
- **Contains**:
  - Reactive state classes with $state, $derived, $effect
  - Database operations (DatabaseStore)
  - UI-focused state management
  - Store orchestration logic
- **Rules**:
  - USE Svelte runes for reactivity
  - Import and use services for business logic
  - NO direct database operations
  - NO direct API calls (use services)
  - Focus on "what" - reactive state for UI

#### Enforcement

- Services should be testable without Svelte
- Stores should leverage Svelte's reactivity system
- Clear separation: services handle data, stores handle state
- Services can be reused across multiple stores

#### Misc

- Always use `let` for $derived state variables
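The services/stores split described by these rules is easier to see in code. Below is a minimal TypeScript sketch, not taken from the actual webui sources: `ChatService`, `ChatStore`, the `/completion` endpoint and the `$lib/services/chat` import path are illustrative assumptions; only the pattern itself (stateless service, rune-based store that delegates to it) mirrors the rules above.

```ts
// lib/services/chat.ts - pure business logic, no Svelte runes (names are illustrative)
export class ChatService {
  constructor(private baseUrl: string) {}

  // stateless: takes input, talks to the server, returns data
  async complete(prompt: string): Promise<string> {
    const res = await fetch(`${this.baseUrl}/completion`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ prompt }),
    });
    if (!res.ok) throw new Error(`HTTP ${res.status}`);
    return (await res.json()).content as string;
  }
}

// lib/stores/chat.svelte.ts - reactive state with runes, delegates work to the service
import { ChatService } from '$lib/services/chat';

export class ChatStore {
  messages = $state<string[]>([]);             // UI-facing reactive state
  busy = $state(false);
  count = $derived(this.messages.length);      // derived state for the UI

  constructor(private service: ChatService) {}

  async send(prompt: string) {
    this.busy = true;
    try {
      // no direct fetch here: the store orchestrates, the service communicates
      this.messages.push(await this.service.complete(prompt));
    } finally {
      this.busy = false;
    }
  }
}
```

A component would construct the store once, for example `const chat = new ChatStore(new ChatService('http://localhost:8080'))`, and bind to `chat.messages`, keeping all fetch logic out of the reactive layer.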
.windsurf/rules/tests.md (new file, 9 lines)

@@ -0,0 +1,9 @@
---
trigger: manual
---

# Automated Tests

## General rules

- NEVER include any test code in the production code - we should always keep it in separate dedicated files
.windsurf/rules/typescript-architecture.md (new file, 7 lines)

@@ -0,0 +1,7 @@
---
trigger: manual
---

## TypeScript

- Add JSDocs for functions
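As a small illustration of the JSDoc rule above (the `formatTokenCount` helper is hypothetical, not part of the webui codebase):

```ts
/**
 * Formats a token count for display in the UI.
 *
 * @param count - Raw number of tokens
 * @param approximate - Prefix the result with "~" when the count is an estimate
 * @returns Human-readable string, e.g. "~1,234 tokens"
 */
export function formatTokenCount(count: number, approximate = false): string {
  const formatted = count.toLocaleString('en-US');
  return `${approximate ? '~' : ''}${formatted} tokens`;
}
```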
@@ -58,6 +58,12 @@ if (MSVC)
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/bigobj>")
endif()

if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
set(LLAMA_TOOLS_INSTALL_DEFAULT OFF)
else()
set(LLAMA_TOOLS_INSTALL_DEFAULT ${LLAMA_STANDALONE})
endif()

#
# option list
#
@@ -82,6 +88,7 @@ option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_DEFAULT})

# 3rd party libs
option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON)
@@ -16,6 +16,9 @@
- Use the following format for the squashed commit title: `<module> : <commit title> (#<issue_number>)`. For example: `utils : fix typo in utils.py (#1234)`
- Optionally pick a `<module>` from here: https://github.com/ggml-org/llama.cpp/wiki/Modules
- Consider adding yourself to [CODEOWNERS](CODEOWNERS)
- Let authors, who are also collaborators, merge their own PRs
- When merging a PR by a contributor, make sure you have a good understanding of the changes
- Be mindful of maintenance: most of the work going into a feature happens after the PR is merged. If the PR author is not committed to contribute long-term, someone else needs to take responsibility (you)

# Coding guidelines
ci/run.sh (25 changed lines)

@@ -45,7 +45,7 @@ SRC=`pwd`
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON"

if [ ! -z ${GG_BUILD_METAL} ]; then
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON -DGGML_METAL_USE_BF16=ON"
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON"
fi

if [ ! -z ${GG_BUILD_CUDA} ]; then
@@ -270,7 +270,9 @@ function gg_run_ctest_with_model_debug {
local model; model=$(gg_get_model)
cd build-ci-debug
set -e

(LLAMACPP_TEST_MODELFILE="$model" time ctest --output-on-failure -L model) 2>&1 | tee -a $OUT/${ci}-ctest.log

set +e
cd ..
}
@@ -281,7 +283,15 @@ function gg_run_ctest_with_model_release {
local model; model=$(gg_get_model)
cd build-ci-release
set -e

(LLAMACPP_TEST_MODELFILE="$model" time ctest --output-on-failure -L model) 2>&1 | tee -a $OUT/${ci}-ctest.log

# test memory leaks
#if [[ ! -z ${GG_BUILD_METAL} ]]; then
#    # TODO: this hangs for some reason ...
#    (time leaks -quiet -atExit -- ./bin/test-thread-safety -m $model --parallel 2 -t 2 -p "hello") 2>&1 | tee -a $OUT/${ci}-leaks.log
#fi

set +e
cd ..
}
@@ -860,10 +870,7 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then
fi

ret=0
if [ -z ${GG_BUILD_SYCL} ]; then
# SYCL build breaks with debug build flags
test $ret -eq 0 && gg_run ctest_debug
fi
test $ret -eq 0 && gg_run ctest_debug
test $ret -eq 0 && gg_run ctest_release

if [ -z ${GG_BUILD_LOW_PERF} ]; then
@@ -871,9 +878,7 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then
test $ret -eq 0 && gg_run rerank_tiny

if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then
if [ -z ${GG_BUILD_SYCL} ]; then
test $ret -eq 0 && gg_run test_scripts_debug
fi
test $ret -eq 0 && gg_run test_scripts_debug
test $ret -eq 0 && gg_run test_scripts_release
fi

@@ -884,9 +889,7 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then
test $ret -eq 0 && gg_run pythia_2_8b
#test $ret -eq 0 && gg_run open_llama_7b_v2
fi
if [ -z ${GG_BUILD_SYCL} ]; then
test $ret -eq 0 && gg_run ctest_with_model_debug
fi
test $ret -eq 0 && gg_run ctest_with_model_debug
test $ret -eq 0 && gg_run ctest_with_model_release
fi
fi
common/arg.cpp (644 changed lines)

@@ -57,12 +57,32 @@ static std::string read_file(const std::string & fname) {
}

static void write_file(const std::string & fname, const std::string & content) {
std::ofstream file(fname);
const std::string fname_tmp = fname + ".tmp";
std::ofstream file(fname_tmp);
if (!file) {
throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
}
file << content;
file.close();

try {
file << content;
file.close();

// Makes write atomic
if (rename(fname_tmp.c_str(), fname.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, fname_tmp.c_str(), fname.c_str());
// If rename fails, try to delete the temporary file
if (remove(fname_tmp.c_str()) != 0) {
LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
}
}
} catch (...) {
// If anything fails, try to delete the temporary file
if (remove(fname_tmp.c_str()) != 0) {
LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
}

throw std::runtime_error(string_format("error: failed to write file '%s'\n", fname.c_str()));
}
}

common_arg & common_arg::set_examples(std::initializer_list<enum llama_example> examples) {
@@ -217,250 +237,294 @@ struct curl_slist_ptr {
}
};

#define CURL_MAX_RETRY 3
#define CURL_RETRY_DELAY_SECONDS 2

static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds, const char * method_name) {
int remaining_attempts = max_attempts;

while (remaining_attempts > 0) {
LOG_INF("%s: %s %s (attempt %d of %d)...\n", __func__ , method_name, url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);

CURLcode res = curl_easy_perform(curl);
if (res == CURLE_OK) {
return true;
}

int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000;
LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);

remaining_attempts--;
if (remaining_attempts == 0) break;
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
static CURLcode common_curl_perf(CURL * curl) {
CURLcode res = curl_easy_perform(curl);
if (res != CURLE_OK) {
LOG_ERR("%s: curl_easy_perform() failed\n", __func__);
}

LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);

return false;
return res;
}

// download one single file from remote URL to local path
static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline) {
// Check if the file already exists locally
auto file_exists = std::filesystem::exists(path);

// If the file exists, check its JSON metadata companion file.
std::string metadata_path = path + ".json";
nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead
// Send a HEAD request to retrieve the etag and last-modified headers
struct common_load_model_from_url_headers {
std::string etag;
std::string last_modified;
std::string accept_ranges;
};

if (file_exists) {
if (offline) {
LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
return true; // skip verification/downloading
struct FILE_deleter {
void operator()(FILE * f) const { fclose(f); }
};

static size_t common_header_callback(char * buffer, size_t, size_t n_items, void * userdata) {
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
static std::regex header_regex("([^:]+): (.*)\r\n");
static std::regex etag_regex("ETag", std::regex_constants::icase);
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
static std::regex accept_ranges_regex("Accept-Ranges", std::regex_constants::icase);
std::string header(buffer, n_items);
std::smatch match;
if (std::regex_match(header, match, header_regex)) {
const std::string & key = match[1];
const std::string & value = match[2];
if (std::regex_match(key, match, etag_regex)) {
headers->etag = value;
} else if (std::regex_match(key, match, last_modified_regex)) {
headers->last_modified = value;
} else if (std::regex_match(key, match, accept_ranges_regex)) {
headers->accept_ranges = value;
}
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
std::ifstream metadata_in(metadata_path);
if (metadata_in.good()) {
try {
metadata_in >> metadata;
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
if (metadata.contains("etag") && metadata.at("etag").is_string()) {
etag = metadata.at("etag");
}
if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
last_modified = metadata.at("lastModified");
}
} catch (const nlohmann::json::exception & e) {
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
}
}
// if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again)
} else {
if (offline) {
LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
return false;
}
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
}

// Send a HEAD request to retrieve the etag and last-modified headers
struct common_load_model_from_url_headers {
std::string etag;
std::string last_modified;
};
return n_items;
}

common_load_model_from_url_headers headers;
bool head_request_ok = false;
bool should_download = !file_exists; // by default, we should download if the file does not exist
static size_t common_write_callback(void * data, size_t size, size_t nmemb, void * fd) {
return std::fwrite(data, size, nmemb, static_cast<FILE *>(fd));
}

// Initialize libcurl
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
// helper function to hide password in URL
static std::string llama_download_hide_password_in_url(const std::string & url) {
// Use regex to match and replace the user[:password]@ pattern in URLs
// Pattern: scheme://[user[:password]@]host[...]
static const std::regex url_regex(R"(^(?:[A-Za-z][A-Za-z0-9+.-]://)(?:[^/@]+@)?.$)");
std::smatch match;

if (std::regex_match(url, match, url_regex)) {
// match[1] = scheme (e.g., "https://")
// match[2] = user[:password]@ part
// match[3] = rest of URL (host and path)
return match[1].str() + "********@" + match[3].str();
}

return url; // No credentials found or malformed URL
}

static void common_curl_easy_setopt_head(CURL * curl, const std::string & url) {
// Set the URL, allow to follow http redirection
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);

# if defined(_WIN32)
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
// operating system. Currently implemented under MS-Windows.
curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
# endif

curl_easy_setopt(curl, CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L); // hide head request progress
curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, common_header_callback);
}

static void common_curl_easy_setopt_get(CURL * curl) {
curl_easy_setopt(curl, CURLOPT_NOBODY, 0L);
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, common_write_callback);

// display download progress
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
}

static bool common_pull_file(CURL * curl, const std::string & path_temporary) {
if (std::filesystem::exists(path_temporary)) {
const std::string partial_size = std::to_string(std::filesystem::file_size(path_temporary));
LOG_INF("%s: server supports range requests, resuming download from byte %s\n", __func__, partial_size.c_str());
const std::string range_str = partial_size + "-";
curl_easy_setopt(curl, CURLOPT_RANGE, range_str.c_str());
}

// Always open file in append mode could be resuming
std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "ab"));
if (!outfile) {
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_temporary.c_str());
return false;
}

common_curl_easy_setopt_get(curl);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile.get());

return common_curl_perf(curl) == CURLE_OK;
}

static bool common_download_head(CURL * curl,
curl_slist_ptr & http_headers,
const std::string & url,
const std::string & bearer_token) {
if (!curl) {
LOG_ERR("%s: error initializing libcurl\n", __func__);
return false;
}

// Set the URL, allow to follow http redirection
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);

http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
// Check if hf-token or bearer-token was specified
if (!bearer_token.empty()) {
std::string auth_header = "Authorization: Bearer " + bearer_token;
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
}
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);

#if defined(_WIN32)
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
// operating system. Currently implemented under MS-Windows.
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif

typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;

static std::regex header_regex("([^:]+): (.*)\r\n");
static std::regex etag_regex("ETag", std::regex_constants::icase);
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);

std::string header(buffer, n_items);
std::smatch match;
if (std::regex_match(header, match, header_regex)) {
const std::string & key = match[1];
const std::string & value = match[2];
if (std::regex_match(key, match, etag_regex)) {
headers->etag = value;
} else if (std::regex_match(key, match, last_modified_regex)) {
headers->last_modified = value;
}
}
return n_items;
};

curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);

// we only allow retrying once for HEAD requests
// this is for the use case of using running offline (no internet), retrying can be annoying
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 1, 0, "HEAD");
if (!was_perform_successful) {
head_request_ok = false;
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
}

long http_code = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code == 200) {
head_request_ok = true;
} else {
LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
head_request_ok = false;
}
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, http_headers.ptr);
common_curl_easy_setopt_head(curl, url);
return common_curl_perf(curl) == CURLE_OK;
}

// if head_request_ok is false, we don't have the etag or last-modified headers
// we leave should_download as-is, which is true if the file does not exist
if (head_request_ok) {
// check if ETag or Last-Modified headers are different
// if it is, we need to download the file again
if (!etag.empty() && etag != headers.etag) {
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
should_download = true;
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
should_download = true;
}
}
// download one single file from remote URL to local path
static bool common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
bool offline) {
// If the file exists, check its JSON metadata companion file.
std::string metadata_path = path + ".json";
static const int max_attempts = 3;
static const int retry_delay_seconds = 2;
for (int i = 0; i < max_attempts; ++i) {
nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead
std::string etag;
std::string last_modified;

if (should_download) {
std::string path_temporary = path + ".downloadInProgress";
// Check if the file already exists locally
const auto file_exists = std::filesystem::exists(path);
if (file_exists) {
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
if (offline) {
LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
return true; // skip verification/downloading
}
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
std::ifstream metadata_in(metadata_path);
if (metadata_in.good()) {
try {
metadata_in >> metadata;
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
metadata.dump().c_str());
if (metadata.contains("etag") && metadata.at("etag").is_string()) {
etag = metadata.at("etag");
}
if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
last_modified = metadata.at("lastModified");
}
} catch (const nlohmann::json::exception & e) {
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
}
}
// if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again)
} else {
if (offline) {
LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
return false;
}
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
}

// Set the output file
bool head_request_ok = false;
bool should_download = !file_exists; // by default, we should download if the file does not exist

struct FILE_deleter {
void operator()(FILE * f) const {
fclose(f);
}
};

std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
if (!outfile) {
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
return false;
}

typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
return fwrite(data, size, nmemb, (FILE *)fd);
};
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());

// display download progress
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);

// helper function to hide password in URL
auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
std::size_t protocol_pos = url.find("://");
if (protocol_pos == std::string::npos) {
return url; // Malformed URL
}

std::size_t at_pos = url.find('@', protocol_pos + 3);
if (at_pos == std::string::npos) {
return url; // No password in URL
}

return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
};

// start the download
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS, "GET");
// Initialize libcurl
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
common_load_model_from_url_headers headers;
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
curl_slist_ptr http_headers;
const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token);
if (!was_perform_successful) {
return false;
head_request_ok = false;
}

long http_code = 0;
curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code < 200 || http_code >= 400) {
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
return false;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code == 200) {
head_request_ok = true;
} else {
LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
head_request_ok = false;
}

// Causes file to be closed explicitly here before we rename it.
outfile.reset();

// Write the updated JSON metadata file.
metadata.update({
{"url", url},
{"etag", headers.etag},
{"lastModified", headers.last_modified}
});
write_file(metadata_path, metadata.dump(4));
LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());

if (rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return false;
// if head_request_ok is false, we don't have the etag or last-modified headers
// we leave should_download as-is, which is true if the file does not exist
bool should_download_from_scratch = false;
if (head_request_ok) {
// check if ETag or Last-Modified headers are different
// if it is, we need to download the file again
if (!etag.empty() && etag != headers.etag) {
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(),
headers.etag.c_str());
should_download = true;
should_download_from_scratch = true;
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__,
last_modified.c_str(), headers.last_modified.c_str());
should_download = true;
should_download_from_scratch = true;
}
}
} else {
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());

const bool accept_ranges_supported = !headers.accept_ranges.empty() && headers.accept_ranges != "none";
if (should_download) {
if (file_exists &&
!accept_ranges_supported) { // Resumable downloads not supported, delete and start again.
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return false;
}
}

const std::string path_temporary = path + ".downloadInProgress";
if (should_download_from_scratch) {
if (std::filesystem::exists(path_temporary)) {
if (remove(path_temporary.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
return false;
}
}

if (std::filesystem::exists(path)) {
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return false;
}
}
}

// Write the updated JSON metadata file.
metadata.update({
{ "url", url },
{ "etag", headers.etag },
{ "lastModified", headers.last_modified }
});
write_file(metadata_path, metadata.dump(4));
LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());

// start the download
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
__func__, llama_download_hide_password_in_url(url).c_str(), path_temporary.c_str(),
headers.etag.c_str(), headers.last_modified.c_str());
const bool was_pull_successful = common_pull_file(curl.get(), path_temporary);
if (!was_pull_successful) {
if (i + 1 < max_attempts) {
const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
} else {
LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
}

continue;
}

long http_code = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code < 200 || http_code >= 400) {
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
return false;
}

if (rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return false;
}
} else {
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
}

break;
}

return true;
@@ -745,6 +809,124 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &

#endif // LLAMA_USE_CURL

//
// Docker registry functions
//

static std::string common_docker_get_token(const std::string & repo) {
std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";

common_remote_params params;
auto res = common_remote_get_content(url, params);

if (res.first != 200) {
throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
}

std::string response_str(res.second.begin(), res.second.end());
nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);

if (!response.contains("token")) {
throw std::runtime_error("Docker registry token response missing 'token' field");
}

return response["token"].get<std::string>();
}

static std::string common_docker_resolve_model(const std::string & docker) {
// Parse ai/smollm2:135M-Q4_0
size_t colon_pos = docker.find(':');
std::string repo, tag;
if (colon_pos != std::string::npos) {
repo = docker.substr(0, colon_pos);
tag = docker.substr(colon_pos + 1);
} else {
repo = docker;
tag = "latest";
}

// ai/ is the default
size_t slash_pos = docker.find('/');
if (slash_pos == std::string::npos) {
repo.insert(0, "ai/");
}

LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
try {
// --- helper: digest validation ---
auto validate_oci_digest = [](const std::string & digest) -> std::string {
// Expected: algo:hex ; start with sha256 (64 hex chars)
// You can extend this map if supporting other algorithms in future.
static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
std::smatch m;
if (!std::regex_match(digest, m, re)) {
throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
}
// normalize hex to lowercase
std::string normalized = digest;
std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){
return std::tolower(c);
});
return normalized;
};

std::string token = common_docker_get_token(repo); // Get authentication token

// Get manifest
const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
std::string manifest_url = url_prefix + "/manifests/" + tag;
common_remote_params manifest_params;
manifest_params.headers.push_back("Authorization: Bearer " + token);
manifest_params.headers.push_back(
"Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
if (manifest_res.first != 200) {
throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
}

std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
std::string gguf_digest; // Find the GGUF layer
if (manifest.contains("layers")) {
for (const auto & layer : manifest["layers"]) {
if (layer.contains("mediaType")) {
std::string media_type = layer["mediaType"].get<std::string>();
if (media_type == "application/vnd.docker.ai.gguf.v3" ||
media_type.find("gguf") != std::string::npos) {
gguf_digest = layer["digest"].get<std::string>();
break;
}
}
}
}

if (gguf_digest.empty()) {
throw std::runtime_error("No GGUF layer found in Docker manifest");
}

// Validate & normalize digest
gguf_digest = validate_oci_digest(gguf_digest);
LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());

// Prepare local filename
std::string model_filename = repo;
std::replace(model_filename.begin(), model_filename.end(), '/', '_');
model_filename += "_" + tag + ".gguf";
std::string local_path = fs_get_cache_file(model_filename);

const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
if (!common_download_file_single(blob_url, local_path, token, false)) {
throw std::runtime_error("Failed to download Docker Model");
}

LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
return local_path;
} catch (const std::exception & e) {
LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
throw;
}
}
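For reference, here is a minimal standalone sketch of the repo/tag/filename mapping that common_docker_resolve_model() performs above. The helper names are hypothetical and exist only to make the resolution rules concrete; they are not part of the patch.

```cpp
#include <algorithm>
#include <string>
#include <utility>

// "smollm2:135M-Q4_0" -> repo "ai/smollm2", tag "135M-Q4_0",
// cached as "ai_smollm2_135M-Q4_0.gguf"; the tag defaults to "latest"
// and the repo defaults to the "ai/" namespace when no "/" is given.
static std::pair<std::string, std::string> split_docker_ref(const std::string & ref) {
    const size_t colon = ref.find(':');
    std::string repo = colon == std::string::npos ? ref : ref.substr(0, colon);
    std::string tag  = colon == std::string::npos ? "latest" : ref.substr(colon + 1);
    if (ref.find('/') == std::string::npos) {
        repo.insert(0, "ai/");
    }
    return { repo, tag };
}

static std::string docker_cache_filename(std::string repo, const std::string & tag) {
    std::replace(repo.begin(), repo.end(), '/', '_');
    return repo + "_" + tag + ".gguf";
}
```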
//
// utils
//
@@ -795,7 +977,9 @@ static handle_model_result common_params_handle_model(
|
||||
handle_model_result result;
|
||||
// handle pre-fill default model path and url based on hf_repo and hf_file
|
||||
{
|
||||
if (!model.hf_repo.empty()) {
|
||||
if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths
|
||||
model.path = common_docker_resolve_model(model.docker_repo);
|
||||
} else if (!model.hf_repo.empty()) {
|
||||
// short-hand to avoid specifying --hf-file -> default it to --model
|
||||
if (model.hf_file.empty()) {
|
||||
if (model.path.empty()) {
|
||||
@@ -1184,7 +1368,7 @@ static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & val
|
||||
} else {
|
||||
for (const auto & device : dev_names) {
|
||||
auto * dev = ggml_backend_dev_by_name(device.c_str());
|
||||
if (!dev || ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) {
|
||||
if (!dev || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
|
||||
throw std::invalid_argument(string_format("invalid device: %s", device.c_str()));
|
||||
}
|
||||
devices.push_back(dev);
|
||||
@@ -1194,7 +1378,7 @@ static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & val
|
||||
return devices;
|
||||
}
|
||||
|
||||
static void add_rpc_devices(std::string servers) {
|
||||
static void add_rpc_devices(const std::string & servers) {
|
||||
auto rpc_servers = string_split<std::string>(servers, ',');
|
||||
if (rpc_servers.empty()) {
|
||||
throw std::invalid_argument("no RPC servers specified");
|
||||
@@ -1584,7 +1768,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.system_prompt = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN}));
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_DIFFUSION}));
|
||||
add_opt(common_arg(
|
||||
{"--no-perf"},
|
||||
string_format("disable internal libllama performance timings (default: %s)", params.no_perf ? "true" : "false"),
|
||||
@@ -2396,24 +2580,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"--list-devices"},
|
||||
"print list of available devices and exit",
|
||||
[](common_params &) {
|
||||
std::vector<ggml_backend_dev_t> rpc_devices;
|
||||
std::vector<ggml_backend_dev_t> all_devices;
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
auto * dev = ggml_backend_dev_get(i);
|
||||
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
|
||||
ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
|
||||
if (ggml_backend_reg_name(reg) == std::string("RPC")) {
|
||||
rpc_devices.push_back(dev);
|
||||
} else {
|
||||
all_devices.push_back(dev);
|
||||
}
|
||||
if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) {
|
||||
devices.push_back(dev);
|
||||
}
|
||||
}
|
||||
// insert RPC devices in front
|
||||
all_devices.insert(all_devices.begin(), rpc_devices.begin(), rpc_devices.end());
|
||||
printf("Available devices:\n");
|
||||
for (size_t i = 0; i < all_devices.size(); ++i) {
|
||||
auto * dev = all_devices[i];
|
||||
for (auto * dev : devices) {
|
||||
size_t free, total;
|
||||
ggml_backend_dev_memory(dev, &free, &total);
|
||||
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
|
||||
@@ -2437,7 +2612,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"--cpu-moe", "-cmoe"},
|
||||
"keep all Mixture of Experts (MoE) weights in the CPU",
|
||||
[](common_params & params) {
|
||||
params.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()});
|
||||
params.tensor_buft_overrides.push_back(llm_ffn_exps_cpu_override());
|
||||
}
|
||||
).set_env("LLAMA_ARG_CPU_MOE"));
|
||||
add_opt(common_arg(
|
||||
@@ -2450,7 +2625,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
for (int i = 0; i < value; ++i) {
|
||||
// keep strings alive and avoid leaking memory by storing them in a static vector
|
||||
static std::list<std::string> buft_overrides;
|
||||
buft_overrides.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i));
|
||||
buft_overrides.push_back(llm_ffn_exps_block_regex(i));
|
||||
params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), ggml_backend_cpu_buffer_type()});
|
||||
}
|
||||
}
|
||||
@@ -2459,7 +2634,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"--cpu-moe-draft", "-cmoed"},
|
||||
"keep all Mixture of Experts (MoE) weights in the CPU for the draft model",
|
||||
[](common_params & params) {
|
||||
params.speculative.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()});
|
||||
params.speculative.tensor_buft_overrides.push_back(llm_ffn_exps_cpu_override());
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CPU_MOE_DRAFT"));
|
||||
add_opt(common_arg(
|
||||
@@ -2471,7 +2646,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
for (int i = 0; i < value; ++i) {
|
||||
static std::list<std::string> buft_overrides_draft;
|
||||
buft_overrides_draft.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i));
|
||||
buft_overrides_draft.push_back(llm_ffn_exps_block_regex(i));
|
||||
params.speculative.tensor_buft_overrides.push_back({buft_overrides_draft.back().c_str(), ggml_backend_cpu_buffer_type()});
|
||||
}
|
||||
}
|
||||
@@ -2636,6 +2811,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.model.url = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_MODEL_URL"));
|
||||
add_opt(common_arg(
|
||||
{ "-dr", "--docker-repo" }, "[<repo>/]<model>[:quant]",
|
||||
"Docker Hub model repository. repo is optional, default to ai/. quant is optional, default to :latest.\n"
|
||||
"example: gemma3\n"
|
||||
"(default: unused)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.model.docker_repo = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_DOCKER_REPO"));
|
||||
add_opt(common_arg(
|
||||
{"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
|
||||
"Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"
|
||||
|
||||
181 common/chat.cpp
@@ -631,6 +631,7 @@ const char * common_chat_format_name(common_chat_format format) {
|
||||
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: return "DeepSeek V3.1";
|
||||
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
|
||||
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
|
||||
@@ -698,11 +699,13 @@ static void parse_json_tool_calls(
|
||||
size_t from = std::string::npos;
|
||||
auto first = true;
|
||||
while (true) {
|
||||
auto start_pos = builder.pos();
|
||||
auto res = function_regex_start_only && first
|
||||
? builder.try_consume_regex(*function_regex_start_only)
|
||||
: function_regex
|
||||
? builder.try_find_regex(*function_regex, from)
|
||||
: std::nullopt;
|
||||
|
||||
if (res) {
|
||||
std::string name;
|
||||
if (get_function_name) {
|
||||
@@ -737,6 +740,8 @@ static void parse_json_tool_calls(
|
||||
return;
|
||||
}
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
} else {
|
||||
builder.move_to(start_pos);
|
||||
}
|
||||
break;
|
||||
}
|
||||
@@ -1388,6 +1393,71 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_deepseek_v3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
// Pass thinking context for DeepSeek V3.1 template
|
||||
json additional_context = {
|
||||
{"thinking", inputs.enable_thinking},
|
||||
};
|
||||
|
||||
auto prompt = apply(tmpl, inputs,
|
||||
/* messages_override= */ inputs.messages,
|
||||
/* tools_override= */ std::nullopt,
|
||||
additional_context);
|
||||
data.prompt = prompt;
|
||||
data.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
if (string_ends_with(data.prompt, "<think>")) {
|
||||
if (!inputs.enable_thinking) {
|
||||
data.prompt += "</think>";
|
||||
} else {
|
||||
data.thinking_forced_open = true;
|
||||
}
|
||||
}
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
tool_rules.push_back(builder.add_rule(name + "-call",
|
||||
"( \"<|tool▁call▁begin|>\" )? \"" + name + "<|tool▁sep|>"
|
||||
"\" " + builder.add_schema(name + "-args", parameters) + " "
|
||||
"\"<|tool▁call▁end|>\""));
|
||||
});
|
||||
// Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
|
||||
// so we accept common variants (then it's all constrained)
|
||||
builder.add_rule("root",
|
||||
std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
|
||||
"( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" | \"<|tool▁calls|>\" ) "
|
||||
"(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
|
||||
"\"<|tool▁calls▁end|>\""
|
||||
" space");
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
|
||||
// If thinking_forced_open, then we capture the </think> tag in the grammar,
|
||||
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
|
||||
std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") +
|
||||
"(<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)[\\s\\S]*"
|
||||
});
|
||||
data.preserved_tokens = {
|
||||
"<think>",
|
||||
"</think>",
|
||||
"<|tool▁calls▁begin|>",
|
||||
"<|tool▁call▁begin|>",
|
||||
"<|tool▁sep|>",
|
||||
"<|tool▁call▁end|>",
|
||||
"<|tool▁calls▁end|>",
|
||||
};
|
||||
});
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
@@ -1409,6 +1479,66 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
|
||||
tool_calls_end);
|
||||
}
|
||||
|
||||
static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) {
|
||||
static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)");
|
||||
|
||||
static const common_regex close_regex("(?:[\\s]*)?<|tool▁call▁end|>");
|
||||
static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)");
|
||||
static const common_regex tool_calls_end("<|tool▁calls▁end|>");
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
LOG_DBG("%s: not parse_tool_calls\n", __func__);
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
LOG_DBG("%s: parse_tool_calls\n", __func__);
|
||||
|
||||
parse_json_tool_calls(
|
||||
builder,
|
||||
/* block_open= */ tool_calls_begin,
|
||||
/* function_regex_start_only= */ std::nullopt,
|
||||
function_regex,
|
||||
close_regex,
|
||||
tool_calls_end);
|
||||
}
|
||||
|
||||
static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
|
||||
// DeepSeek V3.1 outputs reasoning content between "<think>" and "</think>" tags, followed by regular content
|
||||
// First try to parse using the standard reasoning parsing method
|
||||
LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str());
|
||||
|
||||
auto start_pos = builder.pos();
|
||||
auto found_end_think = builder.try_find_literal("</think>");
|
||||
builder.move_to(start_pos);
|
||||
|
||||
if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) {
|
||||
LOG_DBG("%s: no end_think, not partial, adding content\n", __func__);
|
||||
common_chat_parse_deepseek_v3_1_content(builder);
|
||||
} else if (builder.try_parse_reasoning("<think>", "</think>")) {
|
||||
// If reasoning was parsed successfully, the remaining content is regular content
|
||||
LOG_DBG("%s: parsed reasoning, adding content\n", __func__);
|
||||
// </think><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|>
|
||||
common_chat_parse_deepseek_v3_1_content(builder);
|
||||
} else {
|
||||
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) {
|
||||
LOG_DBG("%s: reasoning_format none, adding content\n", __func__);
|
||||
common_chat_parse_deepseek_v3_1_content(builder);
|
||||
return;
|
||||
}
|
||||
// If no reasoning tags found, check if we should treat everything as reasoning
|
||||
if (builder.syntax().thinking_forced_open) {
|
||||
// If thinking is forced open but no tags found, treat everything as reasoning
|
||||
LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__);
|
||||
builder.add_reasoning_content(builder.consume_rest());
|
||||
} else {
|
||||
LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__);
|
||||
// <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|>
|
||||
common_chat_parse_deepseek_v3_1_content(builder);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
auto prompt = apply(tmpl, inputs);
|
||||
@@ -1611,10 +1741,12 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
|
||||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ json(), json {
|
||||
const std::optional<json> tools_override = json();
|
||||
const std::optional<json> additional_context = json {
|
||||
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
|
||||
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||
});
|
||||
};
|
||||
data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, tools_override, additional_context);
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
@@ -2100,15 +2232,28 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
|
||||
|
||||
static void common_chat_parse_granite(common_chat_msg_parser & builder) {
|
||||
// Parse thinking tags
|
||||
static const common_regex start_think_regex(regex_escape("<think>"));
|
||||
static const common_regex end_think_regex(regex_escape("</think>"));
|
||||
// Granite models output partial tokens such as "<" and "<think".
|
||||
// By leveraging try_consume_regex()/try_find_regex() throwing
|
||||
// common_chat_msg_partial_exception for these partial tokens,
|
||||
// processing is interrupted and the tokens are not passed to add_content().
|
||||
if (auto res = builder.try_consume_regex(start_think_regex)) {
|
||||
// Restore position for try_parse_reasoning()
|
||||
builder.move_to(res->groups[0].begin);
|
||||
builder.try_find_regex(end_think_regex, std::string::npos, false);
|
||||
// Restore position for try_parse_reasoning()
|
||||
builder.move_to(res->groups[0].begin);
|
||||
}
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
|
||||
// Parse response tags using regex
|
||||
static const common_regex response_regex("<response>([\\s\\S]*?)</response>");
|
||||
if (auto res = builder.try_find_regex(response_regex)) {
|
||||
// Extract the content between the tags (capture group 1)
|
||||
auto content = builder.str(res->groups[1]);
|
||||
builder.add_content(content);
|
||||
builder.move_to(res->groups[0].end);
|
||||
// Parse response tags
|
||||
static const common_regex start_response_regex(regex_escape("<response>"));
|
||||
static const common_regex end_response_regex(regex_escape("</response>"));
|
||||
// Granite models output partial tokens such as "<" and "<response".
|
||||
// Same hack as reasoning parsing.
|
||||
if (builder.try_consume_regex(start_response_regex)) {
|
||||
builder.try_find_regex(end_response_regex);
|
||||
}
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
@@ -2122,13 +2267,10 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) {
|
||||
builder.move_to(res->groups[0].end);
|
||||
|
||||
// Expect JSON array of tool calls
|
||||
auto tool_calls_data = builder.consume_json();
|
||||
if (tool_calls_data.json.is_array()) {
|
||||
if (!builder.add_tool_calls(tool_calls_data.json)) {
|
||||
builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
|
||||
if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) {
|
||||
if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
} else {
|
||||
builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
|
||||
}
|
||||
} else {
|
||||
builder.add_content(builder.consume_rest());
|
||||
@@ -2365,6 +2507,12 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
}
|
||||
}
|
||||
|
||||
// DeepSeek V3.1: detect based on specific patterns in the template
|
||||
if (src.find("message['prefix'] is defined and message['prefix'] and thinking") != std::string::npos &&
|
||||
params.json_schema.is_null()) {
|
||||
return common_chat_params_init_deepseek_v3_1(tmpl, params);
|
||||
}
|
||||
|
||||
// DeepSeek R1: use handler in all cases except json schema (thinking / tools).
|
||||
if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) {
|
||||
return common_chat_params_init_deepseek_r1(tmpl, params);
|
||||
@@ -2537,6 +2685,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
|
||||
common_chat_parse_deepseek_r1(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1:
|
||||
common_chat_parse_deepseek_v3_1(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
|
||||
common_chat_parse_functionary_v3_2(builder);
|
||||
break;
|
||||
|
||||
@@ -107,6 +107,7 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO,
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B,
|
||||
COMMON_CHAT_FORMAT_GRANITE,
|
||||
|
||||
@@ -193,10 +193,11 @@ struct common_params_sampling {
};

struct common_params_model {
std::string path = ""; // model local path // NOLINT
std::string url = ""; // model url to download // NOLINT
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
std::string path = ""; // model local path // NOLINT
std::string url = ""; // model url to download // NOLINT
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
std::string docker_repo = ""; // Docker repo // NOLINT
};

struct common_params_speculative {
@@ -287,9 +288,9 @@ struct common_params {
float rope_freq_base = 0.0f; // RoPE base frequency
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
float yarn_beta_fast = 32.0f; // YaRN low correction dim
float yarn_beta_slow = 1.0f; // YaRN high correction dim
float yarn_attn_factor = -1.0f; // YaRN magnitude scaling factor
float yarn_beta_fast = -1.0f; // YaRN low correction dim
float yarn_beta_slow = -1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length

// offload params
@@ -452,7 +453,7 @@ struct common_params {

std::string slot_save_path;

float slot_prompt_similarity = 0.5f;
float slot_prompt_similarity = 0.1f;

// batched-bench params
bool is_pp_shared = false;
@@ -733,6 +734,20 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";

}

//
// MoE utils
//

const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_exps";

static std::string llm_ffn_exps_block_regex(int idx) {
return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX);
}

static llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
return { LLM_FFN_EXPS_REGEX, ggml_backend_cpu_buffer_type() };
}

//
// training utils
//
@@ -257,12 +257,13 @@ std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
|
||||
};
|
||||
|
||||
static bool is_reserved_name(const std::string & name) {
|
||||
static std::unordered_set<std::string> RESERVED_NAMES;
|
||||
if (RESERVED_NAMES.empty()) {
|
||||
RESERVED_NAMES.insert("root");
|
||||
for (const auto &p : PRIMITIVE_RULES) RESERVED_NAMES.insert(p.first);
|
||||
for (const auto &p : STRING_FORMAT_RULES) RESERVED_NAMES.insert(p.first);
|
||||
}
|
||||
static const std::unordered_set<std::string> RESERVED_NAMES = [] {
|
||||
std::unordered_set<std::string> s;
|
||||
s.insert("root");
|
||||
for (const auto & p : PRIMITIVE_RULES) s.insert(p.first);
|
||||
for (const auto & p : STRING_FORMAT_RULES) s.insert(p.first);
|
||||
return s;
|
||||
}();
|
||||
return RESERVED_NAMES.find(name) != RESERVED_NAMES.end();
|
||||
}
|
||||
|
||||
@@ -843,9 +844,10 @@ public:
|
||||
_build_object_rule(
|
||||
properties, required, name,
|
||||
schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
|
||||
} else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
|
||||
} else if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) {
|
||||
std::unordered_set<std::string> required;
|
||||
std::vector<std::pair<std::string, json>> properties;
|
||||
std::map<std::string, size_t> enum_values;
|
||||
std::string hybrid_name = name;
|
||||
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
|
||||
if (comp_schema.contains("$ref")) {
|
||||
@@ -857,6 +859,14 @@ public:
|
||||
required.insert(prop.key());
|
||||
}
|
||||
}
|
||||
} else if (comp_schema.contains("enum")) {
|
||||
for (const auto & v : comp_schema["enum"]) {
|
||||
const auto rule = _generate_constant_rule(v);
|
||||
if (enum_values.find(rule) == enum_values.end()) {
|
||||
enum_values[rule] = 0;
|
||||
}
|
||||
enum_values[rule] += 1;
|
||||
}
|
||||
} else {
|
||||
// todo warning
|
||||
}
|
||||
@@ -870,6 +880,17 @@ public:
|
||||
add_component(t, true);
|
||||
}
|
||||
}
|
||||
if (!enum_values.empty()) {
|
||||
std::vector<std::string> enum_intersection;
|
||||
for (const auto & p : enum_values) {
|
||||
if (p.second == schema["allOf"].size()) {
|
||||
enum_intersection.push_back(p.first);
|
||||
}
|
||||
}
|
||||
if (!enum_intersection.empty()) {
|
||||
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space");
|
||||
}
|
||||
}
|
||||
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
|
||||
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
|
||||
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
|
||||
|
||||
@@ -735,6 +735,9 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
|
||||
# ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
|
||||
res = "qwen2"
|
||||
if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
|
||||
# ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
|
||||
res = "grok-2"
|
||||
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
|
||||
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
|
||||
res = "llama-bpe"
|
||||
@@ -885,6 +888,9 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
|
||||
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
|
||||
res = "mellum"
|
||||
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
|
||||
# ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
|
||||
res = "llada-moe"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
@@ -2387,7 +2393,10 @@ class SmolVLMModel(MmprojModel):
|
||||
return [] # skip other tensors
|
||||
|
||||
|
||||
@ModelBase.register("Llama4ForConditionalGeneration")
|
||||
@ModelBase.register(
|
||||
"Llama4ForConditionalGeneration",
|
||||
"Llama4ForCausalLM",
|
||||
)
|
||||
class Llama4Model(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA4
|
||||
undo_permute = False
|
||||
@@ -2405,6 +2414,10 @@ class Llama4Model(LlamaModel):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_interleave_moe_layer_step(self.hparams["interleave_moe_layer_step"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"])
|
||||
if "layer_types" in self.hparams:
|
||||
if all(lt == "full_attention" for lt in self.hparams["layer_types"]):
|
||||
# all layers are full attention (for MobileLLM), disable swa
|
||||
self.gguf_writer.add_sliding_window(0)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
if name.startswith("language_model."):
|
||||
@@ -2682,12 +2695,20 @@ class BitnetModel(TextModel):
|
||||
yield (new_name, data_torch)
|
||||
|
||||
|
||||
@ModelBase.register("GrokForCausalLM")
|
||||
@ModelBase.register("GrokForCausalLM", "Grok1ForCausalLM")
|
||||
class GrokModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.GROK
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_sentencepiece()
|
||||
if (self.dir_model / 'tokenizer.model').is_file():
|
||||
self._set_vocab_sentencepiece()
|
||||
return
|
||||
|
||||
if not (self.dir_model / 'tokenizer.json').is_file() or not (self.dir_model / 'chat_template.jinja').is_file():
|
||||
logger.error('Error: Missing vocab and chat template, download files from https://huggingface.co/alvarobartt/grok-2-tokenizer')
|
||||
sys.exit(1)
|
||||
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -2695,11 +2716,46 @@ class GrokModel(TextModel):
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
self.gguf_writer.add_attn_logit_softcapping(self.hparams.get("attn_logit_softcapping", 30.0))
|
||||
self.gguf_writer.add_router_logit_softcapping(self.hparams.get("router_logit_softcapping", 30.0))
|
||||
if (final_logit_softcap := self.hparams.get("final_logit_softcapping")):
|
||||
self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)
|
||||
|
||||
if (rope_dim := self.hparams.get("head_dim")) is None:
|
||||
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
|
||||
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
|
||||
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
|
||||
|
||||
# Treat "original" as "yarn", seems to have been a mistake
|
||||
if self.hparams.get("rope_type") in ("yarn", "original"):
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(self.hparams["scaling_factor"])
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["original_max_position_embeddings"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_ext_factor(self.hparams["extrapolation_factor"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_attn_factor(self.hparams["attn_factor"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_fast(self.hparams["beta_fast"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_slow(self.hparams["beta_slow"])
|
||||
|
||||
if temp_len := self.hparams.get("attn_temperature_len"):
|
||||
self.gguf_writer.add_attn_temperature_length(temp_len)
|
||||
|
||||
self.gguf_writer.add_attn_output_scale(self.hparams.get("attn_output_multiplier", rope_dim**-0.5))
|
||||
self.gguf_writer.add_embedding_scale(self.hparams["embedding_multiplier_scale"])
|
||||
self.gguf_writer.add_logit_scale(self.hparams["output_multiplier_scale"])
|
||||
|
||||
_experts: list[dict[str, list[Tensor]]] | None = None
|
||||
_cur_expert = ""
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
is_expert = ".moe." in name or ".block_sparse_moe.experts." in name
|
||||
|
||||
if not is_expert:
|
||||
tensors.append((self.map_tensor_name(name), data_torch))
|
||||
|
||||
# process the experts separately
|
||||
if name.find(".moe.") != -1:
|
||||
if is_expert or self._cur_expert:
|
||||
n_experts = self.hparams["num_local_experts"]
|
||||
|
||||
assert bid is not None
|
||||
@@ -2707,32 +2763,41 @@ class GrokModel(TextModel):
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
# merge the experts into a single 3d tensor
|
||||
for wid in ["linear", "linear_1", "linear_v"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight"
|
||||
datas.append(self._experts[bid][ename])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
|
||||
merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight"
|
||||
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
|
||||
tensors.append((new_name, data_torch))
|
||||
return tensors
|
||||
else:
|
||||
# concatenate split tensors
|
||||
if name in self._experts[bid]:
|
||||
self._cur_expert = name
|
||||
self._experts[bid][name].append(data_torch)
|
||||
return []
|
||||
elif is_expert:
|
||||
self._cur_expert = name
|
||||
self._experts[bid][name] = [data_torch]
|
||||
return []
|
||||
else:
|
||||
self._cur_expert = ""
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
for bid in range(self.block_count):
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
# merge the experts into a single 3d tensor
|
||||
for wid in [("linear", "w1", 0), ("linear_1", "w2", 1), ("linear_v", "w3", 0)]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid[0]}.weight"
|
||||
if ename not in self._experts[bid]:
|
||||
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid[1]}.weight"
|
||||
tensor_list = self._experts[bid][ename]
|
||||
datas.append(torch.cat(tensor_list, dim=wid[2]) if len(tensor_list) > 1 else tensor_list[0])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
|
||||
merged_name = f"transformer.decoder_layer.{bid}.moe.{wid[0]}.weight"
|
||||
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
|
||||
yield (new_name, data_torch)
|
||||
|
||||
yield from tensors
|
||||
|
||||
|
||||
@ModelBase.register("DbrxForCausalLM")
|
||||
@@ -5128,6 +5193,20 @@ class EmbeddingGemma(Gemma3Model):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# Override the sliding window size as it gets adjusted by the Gemma3TextConfig
|
||||
# constructor. We want to use the value from the original model's config.json.
|
||||
# ref: https://github.com/huggingface/transformers/pull/40700
|
||||
with open(self.dir_model / "config.json", "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
orig_sliding_window = config.get("sliding_window")
|
||||
if orig_sliding_window is None:
|
||||
raise ValueError("sliding_window not found in model config - this is required for the model")
|
||||
|
||||
logger.info(f"Using original sliding_window from config: {orig_sliding_window} "
|
||||
f"instead of {self.hparams['sliding_window']}")
|
||||
self.gguf_writer.add_sliding_window(orig_sliding_window)
|
||||
|
||||
self._try_set_pooling_type()
|
||||
|
||||
|
||||
@@ -5937,9 +6016,34 @@ class SeedOssModel(TextModel):
|
||||
|
||||
|
||||
@ModelBase.register("Olmo2ForCausalLM")
|
||||
@ModelBase.register("Olmo3ForCausalLM")
|
||||
class Olmo2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.OLMO2
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
rope_scaling = self.hparams.get("rope_scaling") or {}
|
||||
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
|
||||
self.gguf_writer.add_rope_scaling_attn_factors(rope_scaling["attention_factor"])
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
|
||||
|
||||
if "sliding_window" in self.hparams:
|
||||
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
|
||||
|
||||
sliding_window_pattern = []
|
||||
if "layer_types" in self.hparams:
|
||||
sliding_window_pattern = [t == "sliding_attention" for t in self.hparams["layer_types"]]
|
||||
else:
|
||||
# Olmo2 does not use sliding window attention.
|
||||
# Olmo3 defaults to using sliding window for all layers except every 4th.
|
||||
for i in range(self.hparams["num_hidden_layers"]):
|
||||
sliding_window_pattern.append((i + 1) % 4 != 0)
|
||||
|
||||
self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
|
||||
|
||||
|
||||
@ModelBase.register("OlmoeForCausalLM")
|
||||
class OlmoeModel(TextModel):
|
||||
@@ -6687,6 +6791,8 @@ class T5Model(TextModel):
|
||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
|
||||
self.gguf_writer.add_block_count(self.hparams["num_layers"])
|
||||
if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None:
|
||||
self.gguf_writer.add_decoder_block_count(dec_n_layer)
|
||||
self.gguf_writer.add_head_count(self.hparams["num_heads"])
|
||||
self.gguf_writer.add_key_length(self.hparams["d_kv"])
|
||||
self.gguf_writer.add_value_length(self.hparams["d_kv"])
|
||||
@@ -8168,6 +8274,76 @@ class HunYuanMoEModel(TextModel):
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("LLaDAMoEModel", "LLaDAMoEModelLM")
|
||||
class LLaDAMoEModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLADA_MOE
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if (n_experts := self.hparams.get("num_experts")) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
|
||||
if (expert_intermediate_size := self.hparams.get("expert_intermediate_size")) is not None:
|
||||
self.gguf_writer.add_expert_feed_forward_length(expert_intermediate_size)
|
||||
|
||||
# number of experts used per token (top-k)
|
||||
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
|
||||
self.gguf_writer.add_expert_used_count(n_experts_used)
|
||||
|
||||
self.gguf_writer.add_mask_token_id(156895)
|
||||
self.gguf_writer.add_causal_attention(False)
|
||||
self.gguf_writer.add_diffusion_shift_logits(False)
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
# Copied from: Qwen2MoeModel
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# process the experts separately
|
||||
if name.find("experts") != -1:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["down_proj", "gate_proj", "up_proj"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
|
||||
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
|
||||
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
|
||||
tensors.append((new_name, data_torch))
|
||||
return tensors
|
||||
else:
|
||||
return []
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
# Copied from: Qwen2MoeModel
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
|
||||
if self._experts is not None:
|
||||
# flatten `list[dict[str, Tensor]]` into `list[str]`
|
||||
experts = [k for d in self._experts for k in d.keys()]
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("HunYuanDenseV1ForCausalLM")
|
||||
class HunYuanModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE
|
||||
|
||||
@@ -139,6 +139,7 @@ models = [
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
{"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", },
]

# some models are known to be broken upstream, so we will skip them as exceptions
@@ -158,6 +159,7 @@ pre_computed_hashes = [
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
{"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
]
@@ -314,3 +314,11 @@ Converting the matmul weight format from ND to NZ to improve performance. Enable
### GGML_CANN_ACL_GRAPH

Operators are executed using ACL graph execution, rather than in op-by-op (eager) mode. Enabled by default.

### GGML_CANN_GRAPH_CACHE_CAPACITY

Maximum number of compiled CANN graphs kept in the LRU cache, default is 12. When the number of cached graphs exceeds this capacity, the least recently used graph will be evicted.

### GGML_CANN_PREFILL_USE_GRAPH

Enable ACL graph execution during the prefill stage, default is false. This option is only effective when FA is enabled.
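The three settings above are plain environment overrides. As a rough illustration of how a backend might consume the cache-capacity value (the variable name and default come from the docs above; the helper itself is an assumption, not the actual ggml-cann implementation):

```cpp
#include <cstdlib>

// Illustrative only: read GGML_CANN_GRAPH_CACHE_CAPACITY and fall back to the
// documented default of 12 when the variable is unset or not a positive number.
static int cann_graph_cache_capacity() {
    const char * env = std::getenv("GGML_CANN_GRAPH_CACHE_CAPACITY");
    if (env == nullptr) {
        return 12;
    }
    char * end = nullptr;
    const long v = std::strtol(env, &end, 10);
    return (end != env && v > 0) ? (int) v : 12;
}
```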
@@ -241,8 +241,8 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
| | VX/VXE/VXE2 | zDNN | Spyre |
|------------|-------------|------|-------|
| FP32 | ✅ | ✅ | ❓ |
| FP16 | ✅ | ❓ | ❓ |
| BF16 | 🚫 | ❓ | ❓ |
| FP16 | ✅ | ✅ | ❓ |
| BF16 | 🚫 | ✅ | ❓ |
| Q4_0 | ✅ | ❓ | ❓ |
| Q4_1 | ✅ | ❓ | ❓ |
| MXFP4 | 🚫 | ❓ | ❓ |
@@ -272,4 +272,4 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
- 🚫 - acceleration unavailable, will still run using scalar implementation
- ❓ - acceleration unknown, please contribute if you can test it yourself

Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Sep 6, 2025.
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Sep 7, 2025.
@@ -18,6 +18,7 @@ Legend:
|
||||
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ADD_ID | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
@@ -26,6 +27,7 @@ Legend:
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CONV_2D | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ |
|
||||
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| CONV_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
@@ -49,9 +51,11 @@ Legend:
|
||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
| IM2COL_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
@@ -61,7 +65,9 @@ Legend:
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
@@ -98,6 +104,7 @@ Legend:
|
||||
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||
|
||||
11114 docs/ops/zDNN.csv (file diff suppressed because it is too large)
@@ -510,19 +510,27 @@ static void diffusion_generate(llama_context * ctx,
n_generated = params.max_length;
}

static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
static std::string format_input_text(const std::string & prompt, const std::string & system_prompt, bool use_chat_template, llama_model * model) {
if (!use_chat_template) {
return prompt;
}

auto chat_templates = common_chat_templates_init(model, "");

common_chat_templates_inputs inputs;
common_chat_msg user_msg;
user_msg.role = "user";
user_msg.content = prompt;
inputs.add_generation_prompt = true;
common_chat_msg system_msg;

if (!system_prompt.empty()) {
system_msg.role = "system";
system_msg.content = system_prompt;
inputs.messages.push_back(system_msg);
}

common_chat_msg user_msg;
user_msg.role = "user";
user_msg.content = prompt;

inputs.messages.push_back(user_msg);
inputs.add_generation_prompt = true;

auto result = common_chat_templates_apply(chat_templates.get(), inputs);

@@ -579,7 +587,8 @@ int main(int argc, char ** argv) {
llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads);

const llama_vocab * vocab = llama_model_get_vocab(model);
std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);

std::string formatted_prompt = format_input_text(params.prompt, params.system_prompt, params.enable_chat_template, model);

std::vector<llama_token> input_tokens = common_tokenize(vocab,
formatted_prompt,
@@ -596,6 +605,7 @@ int main(int argc, char ** argv) {
}

llama_token mask_token_id = llama_vocab_mask(vocab);

GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);

bool visual_mode = params.diffusion.visual_mode;
@@ -28,6 +28,15 @@ static std::string ggml_ne_string(const ggml_tensor * t) {
return str;
}

static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
union {
float f;
uint32_t i;
} u;
u.i = (uint32_t)h.bits << 16;
return u.f;
}

static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
float v;
@@ -43,6 +52,8 @@ static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t *
v = (float) *(int16_t *) &data[i];
} else if (type == GGML_TYPE_I8) {
v = (float) *(int8_t *) &data[i];
} else if (type == GGML_TYPE_BF16) {
v = ggml_compute_bf16_to_fp32(*(ggml_bf16_t *) &data[i]);
} else {
GGML_ABORT("fatal error");
}
@@ -586,9 +586,10 @@ class SchemaConverter:
|
||||
properties = list(schema.get('properties', {}).items())
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
|
||||
|
||||
elif schema_type in (None, 'object') and 'allOf' in schema:
|
||||
elif schema_type in (None, 'object', 'string') and 'allOf' in schema:
|
||||
required = set()
|
||||
properties = []
|
||||
enum_sets = []
|
||||
hybrid_name = name
|
||||
def add_component(comp_schema, is_required):
|
||||
if (ref := comp_schema.get('$ref')) is not None:
|
||||
@@ -600,6 +601,9 @@ class SchemaConverter:
|
||||
if is_required:
|
||||
required.add(prop_name)
|
||||
|
||||
if 'enum' in comp_schema:
|
||||
enum_sets.append(set(comp_schema['enum']))
|
||||
|
||||
for t in schema['allOf']:
|
||||
if 'anyOf' in t:
|
||||
for tt in t['anyOf']:
|
||||
@@ -607,6 +611,15 @@ class SchemaConverter:
|
||||
else:
|
||||
add_component(t, is_required=True)
|
||||
|
||||
if enum_sets:
|
||||
enum_intersection = enum_sets[0]
|
||||
for s in enum_sets[1:]:
|
||||
enum_intersection &= s
|
||||
|
||||
if enum_intersection:
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
|
||||
|
||||
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
|
||||
|
||||
@@ -1,5 +1,6 @@
--extra-index-url https://download.pytorch.org/whl/cpu
torch~=2.6.0
torchvision~=0.21.0
transformers~=4.55.0
huggingface-hub~=0.34.0
torch
torchvision
transformers
huggingface-hub
accelerate
@@ -9,15 +9,134 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
|
||||
### If you want to dump RoPE activations, apply this monkey patch to the model
|
||||
### class from Transformers that you are running (replace apertus.modeling_apertus
|
||||
### with the proper package and class for your model
|
||||
### === START ROPE DEBUG ===
|
||||
# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
|
||||
|
||||
parser = argparse.ArgumentParser(description='Process model with specified path')
|
||||
parser.add_argument('--model-path', '-m', help='Path to the model')
|
||||
# orig_rope = apply_rotary_pos_emb
|
||||
# torch.set_printoptions(threshold=float('inf'))
|
||||
# torch.set_printoptions(precision=6, sci_mode=False)
|
||||
|
||||
# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
# # log inputs
|
||||
# summarize(q, "RoPE.q_in")
|
||||
# summarize(k, "RoPE.k_in")
|
||||
|
||||
# # call original
|
||||
# q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
|
||||
|
||||
# # log outputs
|
||||
# summarize(q_out, "RoPE.q_out")
|
||||
# summarize(k_out, "RoPE.k_out")
|
||||
|
||||
# return q_out, k_out
|
||||
|
||||
# # Patch it
|
||||
# import transformers.models.apertus.modeling_apertus as apertus_mod # noqa: E402
|
||||
# apertus_mod.apply_rotary_pos_emb = debug_rope
|
||||
### == END ROPE DEBUG ===
|
||||
|
||||
|
||||
def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
|
||||
"""
|
||||
Print a tensor in llama.cpp debug style.
|
||||
|
||||
Supports:
|
||||
- 2D tensors (seq, hidden)
|
||||
- 3D tensors (batch, seq, hidden)
|
||||
- 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
|
||||
|
||||
Shows first and last max_vals of each vector per sequence position.
|
||||
"""
|
||||
t = tensor.detach().to(torch.float32).cpu()
|
||||
|
||||
# Determine dimensions
|
||||
if t.ndim == 3:
|
||||
_, s, _ = t.shape
|
||||
elif t.ndim == 2:
|
||||
_, s = 1, t.shape[0]
|
||||
t = t.unsqueeze(0)
|
||||
elif t.ndim == 4:
|
||||
_, s, _, _ = t.shape
|
||||
else:
|
||||
print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
|
||||
return
|
||||
|
||||
ten_shape = t.shape
|
||||
|
||||
print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
|
||||
print(" [")
|
||||
print(" [")
|
||||
|
||||
# Determine indices for first and last sequences
|
||||
first_indices = list(range(min(s, max_seq)))
|
||||
last_indices = list(range(max(0, s - max_seq), s))
|
||||
|
||||
# Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
|
||||
has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
|
||||
|
||||
# Combine indices
|
||||
if has_overlap:
|
||||
# If there's overlap, just use the combined unique indices
|
||||
indices = sorted(list(set(first_indices + last_indices)))
|
||||
separator_index = None
|
||||
else:
|
||||
# If no overlap, we'll add a separator between first and last sequences
|
||||
indices = first_indices + last_indices
|
||||
separator_index = len(first_indices)
|
||||
|
||||
for i, si in enumerate(indices):
|
||||
# Add separator if needed
|
||||
if separator_index is not None and i == separator_index:
|
||||
print(" ...")
|
||||
|
||||
# Extract appropriate slice
|
||||
vec = t[0, si]
|
||||
if vec.ndim == 2: # 4D case: flatten heads × dim_per_head
|
||||
flat = vec.flatten().tolist()
|
||||
else: # 2D or 3D case
|
||||
flat = vec.tolist()
|
||||
|
||||
# First and last slices
|
||||
first = flat[:max_vals]
|
||||
last = flat[-max_vals:] if len(flat) >= max_vals else flat
|
||||
first_str = ", ".join(f"{v:12.4f}" for v in first)
|
||||
last_str = ", ".join(f"{v:12.4f}" for v in last)
|
||||
|
||||
print(f" [{first_str}, ..., {last_str}]")
|
||||
|
||||
print(" ],")
|
||||
print(" ]")
|
||||
print(f" sum = {t.sum().item():.6f}\n")
|
||||
|
||||
|
||||
def debug_hook(name):
|
||||
def fn(_m, input, output):
|
||||
if isinstance(input, torch.Tensor):
|
||||
summarize(input, name + "_in")
|
||||
elif isinstance(input, (tuple, list)) and isinstance(input[0], torch.Tensor):
|
||||
summarize(input[0], name + "_in")
|
||||
if isinstance(output, torch.Tensor):
|
||||
summarize(output, name + "_out")
|
||||
elif isinstance(output, (tuple, list)) and isinstance(output[0], torch.Tensor):
|
||||
summarize(output[0], name + "_out")
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")

parser = argparse.ArgumentParser(description="Process model with specified path")
parser.add_argument("--model-path", "-m", help="Path to the model")
args = parser.parse_args()

model_path = os.environ.get('MODEL_PATH', args.model_path)
model_path = os.environ.get("MODEL_PATH", args.model_path)
if model_path is None:
    parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable")
    parser.error(
        "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
    )

config = AutoConfig.from_pretrained(model_path)

@@ -34,18 +153,30 @@ config = AutoConfig.from_pretrained(model_path)

if unreleased_model_name:
    model_name_lower = unreleased_model_name.lower()
    unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
    unreleased_module_path = (
        f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
    )
    class_name = f"{unreleased_model_name}ForCausalLM"
    print(f"Importing unreleased model module: {unreleased_module_path}")

    try:
        model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
        model = model_class.from_pretrained(model_path)  # Note: from_pretrained, not fromPretrained
        model_class = getattr(
            importlib.import_module(unreleased_module_path), class_name
        )
        model = model_class.from_pretrained(
            model_path
        )  # Note: from_pretrained, not fromPretrained
    except (ImportError, AttributeError) as e:
        print(f"Failed to import or load model: {e}")
        exit(1)
else:
    model = AutoModelForCausalLM.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="auto", offload_folder="offload"
    )

for name, module in model.named_modules():
    if len(list(module.children())) == 0:  # only leaf modules
        module.register_forward_hook(debug_hook(name))

model_name = os.path.basename(model_path)
# Printing the Model class to allow for easier debugging. This can be useful

@@ -145,6 +145,20 @@ int main(int argc, char ** argv) {
|
||||
|
||||
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
|
||||
|
||||
if (llama_model_has_encoder(model)) {
|
||||
if (llama_encode(ctx, batch)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
|
||||
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
|
||||
decoder_start_token_id = llama_vocab_bos(vocab);
|
||||
}
|
||||
|
||||
batch = llama_batch_get_one(&decoder_start_token_id, 1);
|
||||
}
|
||||
|
||||
// main loop
|
||||
|
||||
const auto t_main_start = ggml_time_us();
|
||||
|
||||
@@ -190,7 +190,6 @@ option(GGML_WEBGPU "ggml: use WebGPU"
|
||||
option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF)
|
||||
option(GGML_ZDNN "ggml: use zDNN" OFF)
|
||||
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
|
||||
option(GGML_METAL_USE_BF16 "ggml: use bfloat if available" OFF)
|
||||
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
|
||||
option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF)
|
||||
option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" ${GGML_METAL})
|
||||
|
||||
@@ -132,6 +132,8 @@ extern "C" {
|
||||
GGML_BACKEND_DEVICE_TYPE_CPU,
|
||||
// GPU device using dedicated memory
|
||||
GGML_BACKEND_DEVICE_TYPE_GPU,
|
||||
// integrated GPU device using host memory
|
||||
GGML_BACKEND_DEVICE_TYPE_IGPU,
|
||||
// accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX)
|
||||
GGML_BACKEND_DEVICE_TYPE_ACCEL
|
||||
};
|
||||
@@ -150,11 +152,21 @@ extern "C" {
|
||||
|
||||
// all the device properties
|
||||
struct ggml_backend_dev_props {
|
||||
// device name
|
||||
const char * name;
|
||||
// device description
|
||||
const char * description;
|
||||
// device free memory in bytes
|
||||
size_t memory_free;
|
||||
// device total memory in bytes
|
||||
size_t memory_total;
|
||||
// device type
|
||||
enum ggml_backend_dev_type type;
|
||||
// device id
|
||||
// for PCI devices, this should be the PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:01:00.0")
|
||||
// if the id is unknown, this should be NULL
|
||||
const char * device_id;
|
||||
// device capabilities
|
||||
struct ggml_backend_dev_caps caps;
|
||||
};
|
||||
|
||||
|
||||
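As a rough illustration of how client code might consume the extended ggml_backend_dev_props above, here is a minimal sketch using the registry helpers (ggml_backend_dev_count / ggml_backend_dev_get / ggml_backend_dev_get_props); note the new device_id field may be NULL when the id is unknown:

#include <stdio.h>
#include "ggml-backend.h"

static void list_devices(void) {
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        struct ggml_backend_dev_props props;
        ggml_backend_dev_get_props(dev, &props);
        // device_id is the PCI bus id ("domain:bus:device.function") when known, otherwise NULL
        printf("%s (%s): %zu/%zu bytes free, id=%s\n",
               props.name, props.description,
               props.memory_free, props.memory_total,
               props.device_id ? props.device_id : "unknown");
    }
}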
@@ -134,6 +134,7 @@ extern "C" {
|
||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
|
||||
|
||||
GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
|
||||
GGML_BACKEND_API void ggml_cpu_fp32_to_i32 (const float *, int32_t *, int64_t);
|
||||
GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
|
||||
GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
|
||||
GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);
|
||||
|
||||
@@ -39,18 +39,13 @@ extern "C" {
|
||||
// user-code should use only these functions
|
||||
//
|
||||
|
||||
// TODO: remove in the future
|
||||
GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void);
|
||||
|
||||
GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend);
|
||||
|
||||
GGML_DEPRECATED(
|
||||
GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
|
||||
"obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713");
|
||||
|
||||
GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
|
||||
|
||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
|
||||
|
||||
// helper to check if the device supports a specific family
|
||||
// ideally, the user code should be doing these checks
|
||||
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
||||
|
||||
@@ -7,8 +7,6 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
GGML_BACKEND_API ggml_backend_t ggml_backend_zdnn_init(void);
|
||||
|
||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_zdnn_reg(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
@@ -284,19 +284,19 @@ __host__ __device__ constexpr inline void ggml_unused_vars_impl(Args&&...) noexc
|
||||
// GGML_TENSOR_LOCALS(size_t, nb1, src1, nb);
|
||||
//
|
||||
#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \
|
||||
const type prefix##0 = (pointer)->array[0]; \
|
||||
const type prefix##0 = (pointer) ? (pointer)->array[0] : 0; \
|
||||
GGML_UNUSED(prefix##0);
|
||||
#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \
|
||||
GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \
|
||||
const type prefix##1 = (pointer)->array[1]; \
|
||||
const type prefix##1 = (pointer) ? (pointer)->array[1] : 0; \
|
||||
GGML_UNUSED(prefix##1);
|
||||
#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \
|
||||
GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \
|
||||
const type prefix##2 = (pointer)->array[2]; \
|
||||
const type prefix##2 = (pointer) ? (pointer)->array[2] : 0; \
|
||||
GGML_UNUSED(prefix##2);
|
||||
#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \
|
||||
GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \
|
||||
const type prefix##3 = (pointer)->array[3]; \
|
||||
const type prefix##3 = (pointer) ? (pointer)->array[3] : 0; \
|
||||
GGML_UNUSED(prefix##3);
|
||||
|
||||
#define GGML_TENSOR_UNARY_OP_LOCALS \
|
||||
@@ -1404,6 +1404,7 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
|
||||
// note: casting from f32 to i32 will discard the fractional part
|
||||
GGML_API struct ggml_tensor * ggml_cast(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
@@ -1528,7 +1529,11 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// supports 3D: a->ne[2] == b->ne[1]
|
||||
// supports 4D a:
|
||||
// a [n_embd, ne1, ne2, ne3]
|
||||
// b I32 [n_rows, ne2, ne3, 1]
|
||||
//
|
||||
// return [n_embd, n_rows, ne2, ne3]
|
||||
GGML_API struct ggml_tensor * ggml_get_rows(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a, // data
|
||||
|
||||
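A shape sketch of the 4D ggml_get_rows case described above, with hypothetical sizes and an assumed initialized ggml_context ctx:

// a: [n_embd=128, ne1=10, ne2=4, ne3=2]   (data)
// b: [n_rows=3,   ne2=4,  ne3=2, 1]       (I32 row indices)
struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 10, 4, 2);
struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_I32,   3,  4, 2, 1);
struct ggml_tensor * r = ggml_get_rows(ctx, a, b);   // r: [128, 3, 4, 2]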
@@ -114,6 +114,9 @@ message(STATUS "GGML_SYSTEM_ARCH: ${GGML_SYSTEM_ARCH}")
|
||||
|
||||
if (NOT MSVC)
|
||||
if (GGML_STATIC)
|
||||
if (UNIX AND NOT APPLE)
|
||||
set(CMAKE_FIND_LIBRARY_SUFFIXES ".a;.so")
|
||||
endif()
|
||||
add_link_options(-static)
|
||||
if (MINGW)
|
||||
add_link_options(-static-libgcc -static-libstdc++)
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define GGML_BACKEND_API_VERSION 1
|
||||
#define GGML_BACKEND_API_VERSION 2
|
||||
|
||||
//
|
||||
// Backend buffer type
|
||||
@@ -114,6 +114,9 @@ extern "C" {
|
||||
void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event);
|
||||
// wait for an event on a different stream
|
||||
void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event);
|
||||
|
||||
// (optional) sort/optimize the nodes in the graph
|
||||
void (*graph_optimize) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
||||
};
|
||||
|
||||
struct ggml_backend {
|
||||
|
||||
@@ -400,9 +400,8 @@ ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const
|
||||
|
||||
ggml_backend_t ggml_backend_init_best(void) {
|
||||
ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
|
||||
if (!dev) {
|
||||
dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
}
|
||||
dev = dev ? dev : ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU);
|
||||
dev = dev ? dev : ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (!dev) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
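A minimal usage sketch of the revised selection order (GPU, then integrated GPU, then CPU); the surrounding program setup is assumed:

#include <stdio.h>
#include "ggml-backend.h"

int main(void) {
    ggml_backend_t backend = ggml_backend_init_best(); // GPU -> IGPU -> CPU
    if (!backend) {
        fprintf(stderr, "no usable backend found\n");
        return 1;
    }
    printf("selected backend: %s\n", ggml_backend_name(backend));
    ggml_backend_free(backend);
    return 0;
}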
@@ -463,6 +463,13 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
|
||||
backend->iface.event_wait(backend, event);
|
||||
}
|
||||
|
||||
static void ggml_backend_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
||||
GGML_ASSERT(backend);
|
||||
if (backend->iface.graph_optimize != NULL) {
|
||||
backend->iface.graph_optimize(backend, cgraph);
|
||||
}
|
||||
}
|
||||
|
||||
// Backend device
|
||||
|
||||
const char * ggml_backend_dev_name(ggml_backend_dev_t device) {
|
||||
@@ -1298,6 +1305,10 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra
|
||||
struct ggml_backend_sched_split * split = &sched->splits[i];
|
||||
split->graph = ggml_graph_view(graph, split->i_start, split->i_end);
|
||||
|
||||
// Optimize this split of the graph. This needs to happen before we make graph_copy,
|
||||
// so they are in sync.
|
||||
ggml_backend_graph_optimize(sched->backends[split->backend_id], &split->graph);
|
||||
|
||||
// add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
|
||||
for (int j = 0; j < split->n_inputs; j++) {
|
||||
assert(graph_copy->size > (graph_copy->n_nodes + 1));
|
||||
|
||||
@@ -270,6 +270,7 @@ static struct ggml_backend_i blas_backend_i = {
|
||||
/* .graph_compute = */ ggml_backend_blas_graph_compute,
|
||||
/* .event_record = */ NULL,
|
||||
/* .event_wait = */ NULL,
|
||||
/* .graph_optimize = */ NULL,
|
||||
};
|
||||
|
||||
static ggml_guid_t ggml_backend_blas_guid(void) {
|
||||
|
||||
@@ -2268,8 +2268,6 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
|
||||
* stream, and persistent buffers for rope init/cache.
|
||||
* @param dst The destination ggml_tensor whose computation
|
||||
* depends on the RoPE values (usually Qcur/Kcur).
|
||||
* @param sin_tensor_buffer Pre-allocated buffer for storing repeated sin values.
|
||||
* @param cos_tensor_buffer Pre-allocated buffer for storing repeated cos values.
|
||||
* @param theta_scale Scalar exponent base for computing theta scale values.
|
||||
* @param freq_scale Frequency scaling factor, applied to theta scale.
|
||||
* @param attn_factor Attention scaling factor, applied to sin/cos.
|
||||
@@ -2277,17 +2275,23 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
|
||||
* (dim expansion vs repeat_interleave).
|
||||
*/
|
||||
static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
void* sin_tensor_buffer, void* cos_tensor_buffer,
|
||||
float* corr_dims, float ext_factor,
|
||||
float theta_scale, float freq_scale,
|
||||
float attn_factor, bool is_neox) {
|
||||
// init sin/cos cache; the cache uses a different repeat method depending on
// @param.is_neox
|
||||
|
||||
ggml_tensor* src0 = dst->src[0]; // input
|
||||
ggml_tensor* src1 = dst->src[1]; // position
|
||||
ggml_tensor* src2 = dst->src[2]; // freq_factors
|
||||
|
||||
if(src2 == nullptr && ctx.rope_cache.cached
|
||||
&& ctx.rope_cache.ext_factor == ext_factor
|
||||
&& ctx.rope_cache.theta_scale == theta_scale
|
||||
&& ctx.rope_cache.freq_scale == freq_scale
|
||||
&& ctx.rope_cache.attn_factor == attn_factor
|
||||
&& ctx.rope_cache.is_neox == is_neox) {
|
||||
// use cache.
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t theta_scale_length = src0->ne[0] / 2;
|
||||
int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
|
||||
size_t theta_scale_nb[] = {sizeof(float), sizeof(float), sizeof(float),
|
||||
@@ -2316,8 +2320,6 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
ctx.rope_cache.freq_scale != freq_scale) {
|
||||
|
||||
ctx.rope_cache.theta_scale_length = theta_scale_length;
|
||||
ctx.rope_cache.theta_scale = theta_scale;
|
||||
ctx.rope_cache.freq_scale = freq_scale;
|
||||
|
||||
if (ctx.rope_cache.theta_scale_cache != nullptr) {
|
||||
ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
|
||||
@@ -2342,7 +2344,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
// return MIN(1, MAX(0, y)) - 1;
|
||||
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
|
||||
void* yarn_ramp_buffer = yarn_ramp_allocator.get();
|
||||
acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float_t),
|
||||
acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float),
|
||||
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
|
||||
float zero_value = 0, one_value = 1;
|
||||
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
|
||||
@@ -2411,6 +2413,20 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
|
||||
}
|
||||
|
||||
// init sin_repeat && cos_repeat, only to accelerate first layer on each device
|
||||
if (position_length > ctx.rope_cache.position_length) {
|
||||
ctx.rope_cache.position_length = position_length;
|
||||
if (ctx.rope_cache.sin_cache != nullptr) {
|
||||
ACL_CHECK(aclrtFree(ctx.rope_cache.sin_cache));
|
||||
}
|
||||
if (ctx.rope_cache.cos_cache != nullptr) {
|
||||
ACL_CHECK(aclrtFree(ctx.rope_cache.cos_cache));
|
||||
}
|
||||
int64_t repeat_theta_length = theta_scale_length * position_length * 2;
|
||||
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.sin_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
}
|
||||
|
||||
// position
|
||||
aclTensor* acl_position_tensor = ggml_cann_create_tensor(
|
||||
src1->data, ggml_cann_type_mapping(src1->type),
|
||||
@@ -2462,10 +2478,10 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_sin_repeat_tensor =
|
||||
ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float),
|
||||
ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float),
|
||||
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
|
||||
aclTensor* acl_cos_repeat_tensor =
|
||||
ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float),
|
||||
ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
|
||||
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
|
||||
|
||||
// repeat
|
||||
@@ -2483,6 +2499,14 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
num_repeats, output_size);
|
||||
}
|
||||
|
||||
// Other layers use cache except first layer.
|
||||
ctx.rope_cache.cached = true;
|
||||
ctx.rope_cache.ext_factor = ext_factor;
|
||||
ctx.rope_cache.theta_scale = theta_scale;
|
||||
ctx.rope_cache.freq_scale = freq_scale;
|
||||
ctx.rope_cache.attn_factor = attn_factor;
|
||||
ctx.rope_cache.is_neox = is_neox;
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
|
||||
acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor,
|
||||
acl_cos_repeat_tensor);
|
||||
@@ -2504,10 +2528,7 @@ aclnnStatus aclnnRotaryPositionEmbedding(void* workspace,
|
||||
#endif
|
||||
|
||||
void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
// TODO: use ascendc
|
||||
// Only tested with the LLAMA model.
|
||||
ggml_tensor* src0 = dst->src[0]; // input
|
||||
ggml_tensor* src1 = dst->src[1];
|
||||
|
||||
// param
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
@@ -2538,15 +2559,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
|
||||
// sin/cos tensor length.
|
||||
int64_t repeat_theta_length = src0->ne[0] * src1->ne[0];
|
||||
ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
|
||||
ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
|
||||
void *sin_tensor_buffer = sin_tensor_allocator.get();
|
||||
void *cos_tensor_buffer = cos_tensor_allocator.get();
|
||||
|
||||
// init ctx.rope_cos/rope_sin cache
|
||||
aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer, corr_dims, ext_factor,
|
||||
aclnn_cache_init(ctx, dst, corr_dims, ext_factor,
|
||||
theta_scale, freq_scale, attn_factor, is_neox);
|
||||
|
||||
int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
|
||||
@@ -2556,10 +2570,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_sin_reshape_tensor =
|
||||
ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float),
|
||||
ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float),
|
||||
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
|
||||
aclTensor* acl_cos_reshape_tensor =
|
||||
ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float),
|
||||
ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
|
||||
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
|
||||
|
||||
aclTensor* acl_src = ggml_cann_create_tensor(src0);
|
||||
|
||||
@@ -38,6 +38,7 @@
|
||||
#include <unistd.h>
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <list>
|
||||
|
||||
#include "../include/ggml-cann.h"
|
||||
#include "../include/ggml.h"
|
||||
@@ -106,6 +107,7 @@ int32_t ggml_cann_get_device();
|
||||
|
||||
std::optional<std::string> get_env(const std::string& name);
|
||||
bool parse_bool(const std::string& value);
|
||||
int parse_integer(const std::string& value);
|
||||
|
||||
/**
|
||||
* @brief Abstract base class for memory pools used by CANN.
|
||||
@@ -350,7 +352,7 @@ struct ggml_graph_node_properties {
|
||||
struct ggml_cann_graph {
|
||||
~ggml_cann_graph() {
|
||||
if (graph != nullptr) {
|
||||
aclmdlRIDestroy(graph);
|
||||
ACL_CHECK(aclmdlRIDestroy(graph));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -358,6 +360,64 @@ struct ggml_cann_graph {
|
||||
|
||||
std::vector<ggml_graph_node_properties> ggml_graph_properties;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief LRU cache for managing ggml_cann_graph objects.
|
||||
*
|
||||
* This class maintains a list of shared_ptr to ggml_cann_graph objects
|
||||
* and enforces a maximum capacity. It provides methods to push new graphs,
|
||||
* move existing graphs to the front (most recently used), and clear the cache.
|
||||
*/
|
||||
struct ggml_cann_graph_lru_cache {
|
||||
size_t capacity; /**< Maximum number of graphs in the cache. */
|
||||
|
||||
std::list<ggml_cann_graph*> cache_list; /**< List storing cached graphs as raw pointers. */
|
||||
|
||||
ggml_cann_graph_lru_cache() {
|
||||
capacity = parse_integer(get_env("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12"));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Push a new graph to the front of the cache.
|
||||
* If the cache exceeds capacity, the least recently used graph is deleted.
|
||||
* @param new_node Pointer to the new ggml_cann_graph to cache.
|
||||
* Ownership is transferred to the cache (cache will delete it).
|
||||
*/
|
||||
void push(ggml_cann_graph* new_node) {
|
||||
if (cache_list.size() >= capacity) {
|
||||
ggml_cann_graph* old = cache_list.back();
|
||||
cache_list.pop_back();
|
||||
delete old; // free the old graph
|
||||
}
|
||||
cache_list.push_front(new_node);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Move an existing graph to the front of the cache.
|
||||
* @param node Pointer to the ggml_cann_graph to move.
|
||||
*/
|
||||
void move_to_front(ggml_cann_graph* node) {
|
||||
cache_list.remove(node);
|
||||
cache_list.push_front(node);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Clear all graphs from the cache (also frees memory).
|
||||
*/
|
||||
void clear() {
|
||||
for (auto ptr : cache_list) {
|
||||
delete ptr;
|
||||
}
|
||||
cache_list.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Destructor that clears the cache and frees all cached graphs.
|
||||
*/
|
||||
~ggml_cann_graph_lru_cache() {
|
||||
clear();
|
||||
}
|
||||
};
|
||||
#endif // USE_ACL_GRAPH
|
||||
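A condensed, illustrative sketch of how this cache is consulted per compute call, using the helpers defined later in this diff (is_matched_graph, add_lru_matched_graph_node_properties); capture and launch details are omitted:

// simplified control flow, assuming the helpers defined further down in this diff are visible
static void run_with_graph_cache(ggml_backend_cann_context * ctx, ggml_cgraph * cgraph) {
    bool update_required = !is_matched_graph(ctx, cgraph);     // hit: the match was moved to the front
    if (update_required) {
        add_lru_matched_graph_node_properties(ctx, cgraph);    // miss: push a new graph, evicting the LRU tail if full
    }
    ggml_cann_graph * matched = ctx->graph_lru_cache.cache_list.front();
    (void) matched; // capture into matched->graph when update_required, then launch matched->graph
}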
|
||||
struct ggml_cann_rope_cache {
|
||||
@@ -365,12 +425,27 @@ struct ggml_cann_rope_cache {
|
||||
if(theta_scale_cache != nullptr) {
|
||||
ACL_CHECK(aclrtFree(theta_scale_cache));
|
||||
}
|
||||
if(sin_cache != nullptr) {
|
||||
ACL_CHECK(aclrtFree(sin_cache));
|
||||
}
|
||||
if(cos_cache != nullptr) {
|
||||
ACL_CHECK(aclrtFree(cos_cache));
|
||||
}
|
||||
}
|
||||
|
||||
void* theta_scale_cache = nullptr;
|
||||
int64_t theta_scale_length = 0;
|
||||
// sin/cos cache, used only to accelerate first layer on each device
|
||||
void* sin_cache = nullptr;
|
||||
void* cos_cache = nullptr;
|
||||
int64_t position_length = 0;
|
||||
// Properties to check before reusing the sincos cache
|
||||
bool cached = false;
|
||||
float ext_factor = 0.0f;
|
||||
float theta_scale = 0.0f;
|
||||
float freq_scale = 0.0f;
|
||||
float attn_factor = 0.0f;
|
||||
bool is_neox = false;
|
||||
};
|
||||
|
||||
struct ggml_cann_tensor_cache {
|
||||
@@ -394,7 +469,7 @@ struct ggml_backend_cann_context {
|
||||
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
|
||||
#ifdef USE_ACL_GRAPH
|
||||
/// Cached CANN ACL graph used for executing the current ggml computation graph.
|
||||
std::unique_ptr<ggml_cann_graph> cann_graph;
|
||||
ggml_cann_graph_lru_cache graph_lru_cache;
|
||||
bool acl_graph_mode = true;
|
||||
#endif
|
||||
cann_task_queue task_queue;
|
||||
@@ -451,7 +526,10 @@ struct ggml_backend_cann_context {
|
||||
*/
|
||||
aclrtStream stream(int stream) {
|
||||
if (streams[stream] == nullptr) {
|
||||
ggml_cann_set_device(device);
|
||||
// If the device is not set here, destroying the stream later may cause a mismatch
|
||||
// between the thread contexts where the stream was created and destroyed.
|
||||
// However, I printed the device_id, thread_id, and stream, and they are all consistent.
|
||||
ACL_CHECK(aclrtSetDevice(device));
|
||||
ACL_CHECK(aclrtCreateStream(&streams[stream]));
|
||||
}
|
||||
return streams[stream];
|
||||
|
||||
@@ -75,13 +75,12 @@
|
||||
* @param device The device ID to set.
|
||||
*/
|
||||
void ggml_cann_set_device(const int32_t device) {
|
||||
// TODO: uncomment these lines after empty context has fixed.
|
||||
// int current_device;
|
||||
// ACL_CHECK(aclrtGetDevice(¤t_device));
|
||||
int current_device = -1;
|
||||
aclrtGetDevice(¤t_device);
|
||||
|
||||
// if (device == current_device) {
|
||||
// return;
|
||||
// }
|
||||
if (device == current_device) {
|
||||
return;
|
||||
}
|
||||
ACL_CHECK(aclrtSetDevice(device));
|
||||
}
|
||||
|
||||
@@ -116,6 +115,24 @@ bool parse_bool(const std::string& value) {
|
||||
return valid_values.find(value) != valid_values.end();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Parse a string as an integer, returning 0 if invalid.
|
||||
*
|
||||
* This function attempts to convert the input string `value` to an `int`.
|
||||
* If the string is not a valid integer or is out of the `int` range,
|
||||
* it returns 0.
|
||||
*
|
||||
* @param value The string to parse.
|
||||
* @return The parsed integer, or 0 if conversion fails.
|
||||
*/
|
||||
int parse_integer(const std::string& value) {
|
||||
try {
|
||||
return std::stoi(value);
|
||||
} catch (...) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Initialize the CANN device information.
|
||||
*
|
||||
@@ -2092,16 +2109,17 @@ static bool ggml_backend_cann_cpy_tensor_async(
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE,
|
||||
cann_ctx_src->stream()));
|
||||
|
||||
// record event on src stream after the copy
|
||||
if (!cann_ctx_src->copy_event) {
|
||||
ACL_CHECK(aclrtCreateEventWithFlag(&cann_ctx_src->copy_event, ACL_EVENT_SYNC));
|
||||
}
|
||||
ACL_CHECK(aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream()));
|
||||
// TODO: this event is not effective with acl graph mode, change to use aclrtSynchronizeStream
|
||||
// if (!cann_ctx_src->copy_event) {
|
||||
// ACL_CHECK(aclrtCreateEventWithFlag(&cann_ctx_src->copy_event, ACL_EVENT_SYNC));
|
||||
// }
|
||||
// ACL_CHECK(aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream()));
|
||||
|
||||
// wait on dst stream for the copy to complete
|
||||
ggml_cann_set_device(cann_ctx_dst->device);
|
||||
ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(), cann_ctx_src->copy_event));
|
||||
// // wait on dst stream for the copy to complete
|
||||
// ggml_cann_set_device(cann_ctx_dst->device);
|
||||
// ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(), cann_ctx_src->copy_event));
|
||||
ACL_CHECK(aclrtSynchronizeStream(cann_ctx_src->stream()));
|
||||
} else {
|
||||
// src and dst are on the same backend
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
|
||||
@@ -2130,30 +2148,52 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
|
||||
|
||||
#ifdef USE_ACL_GRAPH
|
||||
/**
|
||||
* @brief Populate the internal CANN graph node properties from the ggml computation graph.
|
||||
* @brief Add a new CANN graph to the LRU cache by populating node properties from the ggml graph.
|
||||
*
|
||||
* This function copies all node attributes (operation type, dimensions, strides, input sources,
|
||||
* and operation parameters) into the cached CANN graph structure for later reuse or comparison.
|
||||
* This function creates a new ggml_cann_graph object and fills its node properties
|
||||
* (operation type, dimensions, strides, input sources, and operation parameters)
|
||||
* based on the current ggml computation graph.
|
||||
*
|
||||
* @param cann_ctx The CANN backend context.
|
||||
* @param cgraph The ggml computational graph.
|
||||
* Each node in the ggml graph is mapped to a property entry in the new CANN graph:
|
||||
* - node address
|
||||
* - operation type
|
||||
* - shape (ne) and strides (nb)
|
||||
* - source tensor addresses
|
||||
* - operation parameters
|
||||
*
|
||||
* After initialization, the new graph is pushed into the LRU cache owned by the
|
||||
* CANN backend context. The cache takes ownership of the graph and manages its
|
||||
* lifetime (including deletion upon eviction).
|
||||
*
|
||||
* @param cann_ctx The CANN backend context containing the graph cache.
|
||||
* @param cgraph The current ggml computation graph.
|
||||
*/
|
||||
static void set_ggml_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
|
||||
for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
|
||||
ggml_tensor * node = cgraph->nodes[node_idx];
|
||||
cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_address = node->data;
|
||||
cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_op = node->op;
|
||||
static void add_lru_matched_graph_node_properties(
|
||||
ggml_backend_cann_context * cann_ctx,
|
||||
ggml_cgraph * cgraph) {
|
||||
// Create a new ggml_cann_graph object on the heap (its lifetime is managed by the cache).
|
||||
ggml_cann_graph * new_graph = new ggml_cann_graph();
|
||||
new_graph->ggml_graph_properties.resize(cgraph->n_nodes);
|
||||
|
||||
for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
|
||||
cann_ctx->cann_graph->ggml_graph_properties[node_idx].ne[dim] = node->ne[dim];
|
||||
cann_ctx->cann_graph->ggml_graph_properties[node_idx].nb[dim] = node->nb[dim];
|
||||
for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) {
|
||||
ggml_tensor * node = cgraph->nodes[node_idx];
|
||||
auto & prop = new_graph->ggml_graph_properties[node_idx];
|
||||
|
||||
prop.node_address = node->data;
|
||||
prop.node_op = node->op;
|
||||
|
||||
std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne);
|
||||
std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
|
||||
|
||||
for (int src = 0; src < GGML_MAX_SRC; ++src) {
|
||||
prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr;
|
||||
}
|
||||
for (int src = 0; src < GGML_MAX_SRC; src++) {
|
||||
cann_ctx->cann_graph->ggml_graph_properties[node_idx].src_address[src] =
|
||||
node->src[src] ? node->src[src]->data : nullptr;
|
||||
}
|
||||
memcpy(cann_ctx->cann_graph->ggml_graph_properties[node_idx].op_params, node->op_params, GGML_MAX_OP_PARAMS);
|
||||
|
||||
memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
|
||||
}
|
||||
|
||||
// Insert into the LRU cache (cache takes ownership and will delete it when evicted).
|
||||
cann_ctx->graph_lru_cache.push(new_graph);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -2198,30 +2238,45 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Determine if the CANN graph needs to be rebuilt due to graph changes.
|
||||
* @brief Check whether there is a cached CANN graph that matches the current ggml graph.
|
||||
*
|
||||
* This checks whether the number or properties of ggml graph nodes have changed
|
||||
* compared to the last captured CANN graph. If so, the CANN graph must be re-captured.
|
||||
* This function iterates through the cached CANN graphs stored in the LRU cache and
|
||||
* compares them against the given ggml computation graph. A match requires that the
|
||||
* number of nodes is the same and that each node’s properties (operation type,
|
||||
* dimensions, strides, inputs, and operation parameters) are identical.
|
||||
*
|
||||
* @param cann_ctx The CANN backend context.
|
||||
* If a matching graph is found, it is promoted to the front of the LRU cache and the
|
||||
* function returns true. Otherwise, the function returns false, indicating that a new
|
||||
* CANN graph needs to be captured.
|
||||
*
|
||||
* @param cann_ctx The CANN backend context containing the graph cache.
|
||||
* @param cgraph The current ggml computation graph.
|
||||
* @return true if an update is required; false otherwise.
|
||||
* @return true if a matching cached graph exists; false otherwise.
|
||||
*/
|
||||
static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
|
||||
// The number of nodes is different, so the graph needs to be reconstructed.
|
||||
if (cann_ctx->cann_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
|
||||
cann_ctx->cann_graph->ggml_graph_properties.resize(cgraph->n_nodes);
|
||||
return true;
|
||||
}
|
||||
static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
|
||||
ggml_cann_graph_lru_cache &lru_cache = cann_ctx->graph_lru_cache;
|
||||
for (auto &graph_ptr : lru_cache.cache_list) {
|
||||
// Skip graphs with a different number of nodes.
|
||||
if (graph_ptr->ggml_graph_properties.size() != static_cast<size_t>(cgraph->n_nodes)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// The number of nodes is the same; iterate over each node to check whether they match.
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
bool has_matching_properties = ggml_graph_node_has_matching_properties(
|
||||
cgraph->nodes[i], &cann_ctx->cann_graph->ggml_graph_properties[i]);
|
||||
if(!has_matching_properties) {
|
||||
// Check if all nodes match.
|
||||
bool all_match = true;
|
||||
for (int i = 0; i < cgraph->n_nodes; ++i) {
|
||||
if (!ggml_graph_node_has_matching_properties(cgraph->nodes[i], &graph_ptr->ggml_graph_properties[i])) {
|
||||
all_match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (all_match) {
|
||||
// update cache_list && return true
|
||||
lru_cache.move_to_front(graph_ptr);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
#endif // USE_ACL_GRAPH
|
||||
@@ -2240,17 +2295,13 @@ static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx,
|
||||
* @param cann_graph_update_required Whether graph capture is needed due to graph changes.
|
||||
*/
|
||||
static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph,
|
||||
bool & use_cann_graph, bool & cann_graph_update_required) {
|
||||
bool & use_cann_graph, bool & cann_graph_update_required) {
|
||||
#ifdef USE_ACL_GRAPH
|
||||
ggml_cann_graph* matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
|
||||
if (use_cann_graph && cann_graph_update_required) {
|
||||
if (cann_ctx->cann_graph->graph != nullptr) {
|
||||
ACL_CHECK(aclmdlRIDestroy(cann_ctx->cann_graph->graph));
|
||||
cann_ctx->cann_graph->graph = nullptr;
|
||||
}
|
||||
ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
|
||||
}
|
||||
#endif // USE_ACL_GRAPH
|
||||
|
||||
// Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
|
||||
// With the use of CANN graphs, the execution will be performed by the graph launch.
|
||||
if (!use_cann_graph || cann_graph_update_required) {
|
||||
@@ -2271,12 +2322,12 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
|
||||
|
||||
#ifdef USE_ACL_GRAPH
|
||||
if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture
|
||||
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &cann_ctx->cann_graph->graph));
|
||||
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
|
||||
}
|
||||
|
||||
if (use_cann_graph) {
|
||||
// Execute graph
|
||||
ACL_CHECK(aclmdlRIExecuteAsync(cann_ctx->cann_graph->graph, cann_ctx->stream()));
|
||||
ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream()));
|
||||
}
|
||||
#endif // USE_ACL_GRAPH
|
||||
}
|
||||
@@ -2301,28 +2352,44 @@ static enum ggml_status ggml_backend_cann_graph_compute(
|
||||
ggml_cann_set_device(cann_ctx->device);
|
||||
g_nz_workspaces[cann_ctx->device].clear();
|
||||
|
||||
// calculate the rope cache for the first layer on the current device.
|
||||
cann_ctx->rope_cache.cached = false;
|
||||
|
||||
#ifdef USE_ACL_GRAPH
|
||||
bool use_cann_graph = true;
|
||||
bool cann_graph_update_required = false;
|
||||
|
||||
static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
|
||||
if (!prefill_use_graph) {
|
||||
// Do not use acl_graph for prefill.
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
// TODO: Optimize here. Currently, we can only
|
||||
// get seq_len by FA's input.
|
||||
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
|
||||
// Q -> src[0], shape: [B, S, N, D]
|
||||
use_cann_graph = (node->src[0]->ne[1] == 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!cann_ctx->acl_graph_mode) {
|
||||
use_cann_graph = false;
|
||||
}
|
||||
|
||||
if (use_cann_graph) {
|
||||
if (cann_ctx->cann_graph == nullptr) {
|
||||
cann_ctx->cann_graph.reset(new ggml_cann_graph());
|
||||
cann_graph_update_required = true;
|
||||
// If no matching graph is found, the graph needs to be recaptured.
|
||||
cann_graph_update_required = !is_matched_graph(cann_ctx, cgraph);
|
||||
if (cann_graph_update_required) {
|
||||
// If no matching graph is found, add a new ACL graph.
|
||||
add_lru_matched_graph_node_properties(cann_ctx, cgraph);
|
||||
}
|
||||
|
||||
cann_graph_update_required = is_cann_graph_update_required(cann_ctx, cgraph);
|
||||
set_ggml_graph_node_properties(cann_ctx, cgraph);
|
||||
}
|
||||
#else
|
||||
bool use_cann_graph = false;
|
||||
bool cann_graph_update_required = false;
|
||||
#endif // USE_ACL_GRAPH
|
||||
|
||||
evaluate_and_capture_cann_graph(
|
||||
cann_ctx,
|
||||
cgraph,
|
||||
@@ -2689,6 +2756,7 @@ static const ggml_backend_i ggml_backend_cann_interface = {
|
||||
/* .graph_compute = */ ggml_backend_cann_graph_compute,
|
||||
/* .event_record = */ ggml_backend_cann_event_record,
|
||||
/* .event_wait = */ ggml_backend_cann_event_wait,
|
||||
/* .graph_optimize = */ NULL,
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -224,7 +224,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
|
||||
string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
|
||||
if (NOT ${feature_pos} EQUAL -1)
|
||||
message(STATUS "ARM feature ${feature} enabled")
|
||||
# Special handling for MATMUL_INT8 when machine doesn't support i8mm
|
||||
if ("${feature}" STREQUAL "MATMUL_INT8" AND GGML_MACHINE_SUPPORTS_noi8mm)
|
||||
message(STATUS "ARM feature ${feature} detected but unsetting due to machine not supporting i8mm")
|
||||
list(APPEND ARCH_FLAGS -U__ARM_FEATURE_MATMUL_INT8)
|
||||
else()
|
||||
message(STATUS "ARM feature ${feature} enabled")
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include "ggml-cpu.h"
|
||||
#include "traits.h"
|
||||
|
||||
#if defined(__gnu_linux__)
|
||||
#if defined(__linux__)
|
||||
#include <sys/syscall.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
@@ -186,7 +186,7 @@ static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_ty
|
||||
#define XFEATURE_XTILEDATA 18
|
||||
|
||||
static bool ggml_amx_init() {
|
||||
#if defined(__gnu_linux__)
|
||||
#if defined(__linux__)
|
||||
if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
|
||||
fprintf(stderr, "AMX is not ready to be used!\n");
|
||||
return false;
|
||||
@@ -194,6 +194,8 @@ static bool ggml_amx_init() {
|
||||
return true;
|
||||
#elif defined(_WIN32)
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -28,6 +28,14 @@ static inline float bf16_to_f32(ggml_bf16_t x) {
|
||||
return GGML_BF16_TO_FP32(x);
|
||||
}
|
||||
|
||||
static inline float i32_to_f32(int32_t x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline int32_t f32_to_i32(float x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline float f32_to_f32(float x) {
|
||||
return x;
|
||||
}
|
||||
@@ -54,6 +62,12 @@ struct type_conversion_table<ggml_bf16_t> {
|
||||
static constexpr ggml_bf16_t (*from_f32)(float) = f32_to_bf16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_conversion_table<int32_t> {
|
||||
static constexpr float (*to_f32)(int32_t) = i32_to_f32;
|
||||
static constexpr int32_t (*from_f32)(float) = f32_to_i32;
|
||||
};
|
||||
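A minimal sketch of the conversion pattern the templated dup path below relies on (not the actual kernel); note that f32 -> i32 truncates toward zero, matching the ggml_cast note earlier in this diff:

// generic element conversion through an f32 intermediate, as used by ggml_compute_forward_dup_flt
template <typename src_t, typename dst_t>
static inline dst_t convert_one(src_t x) {
    const float tmp = type_conversion_table<src_t>::to_f32(x);
    return type_conversion_table<dst_t>::from_f32(tmp);
}
// e.g. convert_one<float, int32_t>(3.9f) == 3 and convert_one<int32_t, float>(7) == 7.0f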
|
||||
static std::pair<int64_t, int64_t> get_thread_range(const struct ggml_compute_params * params, const struct ggml_tensor * src0) {
|
||||
const int64_t ith = params->ith;
|
||||
const int64_t nth = params->nth;
|
||||
|
||||
@@ -373,6 +373,9 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_I32] = {
|
||||
.from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32,
|
||||
},
|
||||
};
|
||||
|
||||
const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
|
||||
@@ -2696,7 +2699,10 @@ struct ggml_cplan ggml_graph_plan(
|
||||
if (ggml_is_quantized(node->type) ||
|
||||
// F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
|
||||
(node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
|
||||
(node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
|
||||
(node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16) ||
|
||||
// conversion between F32 and I32
|
||||
(node->src[0]->type == GGML_TYPE_F32 && node->src[1] && node->src[1]->type == GGML_TYPE_I32) ||
|
||||
(node->src[0]->type == GGML_TYPE_I32 && node->src[1] && node->src[1]->type == GGML_TYPE_F32)) {
|
||||
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
|
||||
}
|
||||
} break;
|
||||
@@ -3258,6 +3264,13 @@ void ggml_cpu_fp32_to_bf16(const float * x, ggml_bf16_t * y, int64_t n) {
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cpu_fp32_to_i32(const float * x, int32_t * y, int64_t n) {
|
||||
int64_t i = 0;
|
||||
for (; i < n; ++i) {
|
||||
y[i] = x[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) {
|
||||
int64_t i = 0;
|
||||
#if defined(__AVX2__)
|
||||
|
||||
@@ -190,6 +190,7 @@ static const struct ggml_backend_i ggml_backend_cpu_i = {
|
||||
/* .graph_compute = */ ggml_backend_cpu_graph_compute,
|
||||
/* .event_record = */ NULL,
|
||||
/* .event_wait = */ NULL,
|
||||
/* .graph_optimize = */ NULL,
|
||||
};
|
||||
|
||||
static ggml_guid_t ggml_backend_cpu_guid(void) {
|
||||
|
||||
@@ -515,9 +515,6 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
||||
op->src[0]->buffer &&
|
||||
(ggml_n_dims(op->src[0]) == 2) &&
|
||||
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
|
||||
if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -41,13 +41,15 @@ static void ggml_compute_forward_dup_same_cont(
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_dup_f16(
|
||||
template<typename src_t, typename dst_t>
|
||||
static void ggml_compute_forward_dup_flt(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
||||
GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
@@ -62,6 +64,7 @@ static void ggml_compute_forward_dup_f16(
|
||||
const int ir0 = dr * ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
// case: type & row size equal
|
||||
if (src0->type == dst->type &&
|
||||
ne00 == ne0 &&
|
||||
nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
|
||||
@@ -80,11 +83,11 @@ static void ggml_compute_forward_dup_f16(
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
|
||||
|
||||
// case: dst tensor is contiguous
|
||||
if (ggml_is_contiguous(dst)) {
|
||||
if (nb00 == sizeof(ggml_fp16_t)) {
|
||||
if (dst->type == GGML_TYPE_F16) {
|
||||
if (nb00 == sizeof(src_t)) {
|
||||
if constexpr (std::is_same_v<dst_t, src_t>) {
|
||||
// same type
|
||||
size_t id = 0;
|
||||
const size_t rs = ne00 * nb00;
|
||||
char * dst_ptr = (char *) dst->data;
|
||||
@@ -100,91 +103,46 @@ static void ggml_compute_forward_dup_f16(
|
||||
id += rs * (ne01 - ir1);
|
||||
}
|
||||
}
|
||||
} else if (dst->type == GGML_TYPE_F32) {
|
||||
} else {
|
||||
// casting between non-quantized types
|
||||
size_t id = 0;
|
||||
float * dst_ptr = (float *) dst->data;
|
||||
dst_t * dst_ptr = (dst_t *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
id += ne00 * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
dst_ptr[id] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
|
||||
float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
|
||||
dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
|
||||
id++;
|
||||
}
|
||||
}
|
||||
id += ne00 * (ne01 - ir1);
|
||||
}
|
||||
}
|
||||
} else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
|
||||
ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
|
||||
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
||||
|
||||
size_t id = 0;
|
||||
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
|
||||
char * dst_ptr = (char *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
id += rs * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
src0_f32[i00] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
|
||||
}
|
||||
|
||||
quantize_row_q(src0_f32, dst_ptr + id, ne00);
|
||||
id += rs;
|
||||
}
|
||||
id += rs * (ne01 - ir1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("fatal error"); // TODO: implement
|
||||
}
|
||||
} else {
|
||||
//printf("%s: this is not optimal - fix me\n", __func__);
|
||||
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
size_t id = 0;
|
||||
float * dst_ptr = (float *) dst->data;
|
||||
size_t id = 0;
|
||||
dst_t * dst_ptr = (dst_t *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
id += ne00 * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
id += ne00 * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
dst_ptr[id] = GGML_CPU_FP16_TO_FP32(*src0_ptr);
|
||||
id++;
|
||||
}
|
||||
float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
|
||||
dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
|
||||
id++;
|
||||
}
|
||||
id += ne00 * (ne01 - ir1);
|
||||
}
|
||||
id += ne00 * (ne01 - ir1);
|
||||
}
|
||||
} else if (dst->type == GGML_TYPE_F16) {
|
||||
size_t id = 0;
|
||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
id += ne00 * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
dst_ptr[id] = *src0_ptr;
|
||||
id++;
|
||||
}
|
||||
}
|
||||
id += ne00 * (ne01 - ir1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("fatal error"); // TODO: implement
|
||||
}
|
||||
}
|
||||
return;
|
||||
@@ -196,7 +154,7 @@ static void ggml_compute_forward_dup_f16(
|
||||
int64_t i12 = 0;
|
||||
int64_t i13 = 0;
|
||||
|
||||
if (dst->type == GGML_TYPE_F16) {
|
||||
if constexpr (std::is_same_v<dst_t, src_t>) {
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
i10 += ne00 * ir0;
|
||||
@@ -217,7 +175,7 @@ static void ggml_compute_forward_dup_f16(
|
||||
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||
|
||||
memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
|
||||
memcpy(dst_ptr, src0_ptr, sizeof(dst_t));
|
||||
|
||||
if (++i10 == ne00) {
|
||||
i10 = 0;
|
||||
@@ -248,7 +206,8 @@ static void ggml_compute_forward_dup_f16(
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (dst->type == GGML_TYPE_F32) {
|
||||
|
||||
} else {
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
i10 += ne00 * ir0;
|
||||
@@ -269,7 +228,8 @@ static void ggml_compute_forward_dup_f16(
|
||||
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||
|
||||
*(float *) dst_ptr = GGML_CPU_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
|
||||
float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
|
||||
*(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);
|
||||
|
||||
if (++i10 == ne0) {
|
||||
i10 = 0;
|
||||
@@ -300,18 +260,19 @@ static void ggml_compute_forward_dup_f16(
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("fatal error"); // TODO: implement
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_dup_bf16(
|
||||
|
||||
template<typename src_t>
|
||||
static void ggml_compute_forward_dup_to_q(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
||||
GGML_ASSERT(!ggml_is_quantized(src0->type));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
@@ -326,629 +287,36 @@ static void ggml_compute_forward_dup_bf16(
|
||||
const int ir0 = dr * ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
if (src0->type == dst->type &&
|
||||
ne00 == ne0 &&
|
||||
nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
|
||||
// copy by rows
|
||||
const size_t rs = ne00*nb00;
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
||||
memcpy(
|
||||
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
|
||||
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
|
||||
rs);
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (ggml_is_contiguous(dst) &&
|
||||
nb00 == sizeof(src_t) &&
|
||||
ggml_get_type_traits_cpu(dst->type)->from_float) {
|
||||
// casting non-quantized types --> intermediate f32 --> quantized
|
||||
ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
|
||||
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
||||
|
||||
// TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
|
||||
size_t id = 0;
|
||||
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
|
||||
char * dst_ptr = (char *) dst->data;
|
||||
|
||||
if (ggml_is_contiguous(dst)) {
|
||||
if (nb00 == sizeof(ggml_bf16_t)) {
|
||||
if (dst->type == GGML_TYPE_BF16) {
|
||||
size_t id = 0;
|
||||
const size_t rs = ne00 * nb00;
|
||||
char * dst_ptr = (char *) dst->data;
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
id += rs * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
id += rs * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
||||
memcpy(dst_ptr + id, src0_ptr, rs);
|
||||
id += rs;
|
||||
}
|
||||
id += rs * (ne01 - ir1);
|
||||
}
|
||||
}
|
||||
} else if (dst->type == GGML_TYPE_F16) {
|
||||
size_t id = 0;
|
||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
id += ne00 * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
|
||||
id++;
|
||||
}
|
||||
}
|
||||
id += ne00 * (ne01 - ir1);
|
||||
}
|
||||
}
|
||||
} else if (dst->type == GGML_TYPE_F32) {
|
||||
size_t id = 0;
|
||||
float * dst_ptr = (float *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
id += ne00 * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
|
||||
id++;
|
||||
}
|
||||
}
|
||||
id += ne00 * (ne01 - ir1);
|
||||
}
|
||||
}
|
||||
} else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
|
||||
ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
|
||||
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
||||
|
||||
size_t id = 0;
|
||||
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
|
||||
char * dst_ptr = (char *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
id += rs * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
|
||||
}
|
||||
|
||||
quantize_row_q(src0_f32, dst_ptr + id, ne00);
|
||||
id += rs;
|
||||
}
|
||||
id += rs * (ne01 - ir1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("fatal error"); // TODO: implement
|
||||
}
|
||||
} else {
|
||||
//printf("%s: this is not optimal - fix me\n", __func__);
|
||||
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
size_t id = 0;
|
||||
float * dst_ptr = (float *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
id += ne00 * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
|
||||
id++;
|
||||
}
|
||||
}
|
||||
id += ne00 * (ne01 - ir1);
|
||||
}
|
||||
}
|
||||
} else if (dst->type == GGML_TYPE_BF16) {
|
||||
size_t id = 0;
|
||||
ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
id += ne00 * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
dst_ptr[id] = *src0_ptr;
|
||||
id++;
|
||||
}
|
||||
}
|
||||
id += ne00 * (ne01 - ir1);
|
||||
}
|
||||
}
|
||||
} else if (dst->type == GGML_TYPE_F16) {
|
||||
size_t id = 0;
|
||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
id += ne00 * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
|
||||
id++;
|
||||
}
|
||||
}
|
||||
id += ne00 * (ne01 - ir1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("fatal error"); // TODO: implement
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// dst counters
|
||||
int64_t i10 = 0;
|
||||
int64_t i11 = 0;
|
||||
int64_t i12 = 0;
|
||||
int64_t i13 = 0;
|
||||
|
||||
if (dst->type == GGML_TYPE_BF16) {
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
i10 += ne00 * ir0;
|
||||
while (i10 >= ne0) {
|
||||
i10 -= ne0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||
|
||||
memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
|
||||
|
||||
if (++i10 == ne00) {
|
||||
i10 = 0;
|
||||
if (++i11 == ne01) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne02) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne03) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
i10 += ne00 * (ne01 - ir1);
|
||||
while (i10 >= ne0) {
|
||||
i10 -= ne0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (dst->type == GGML_TYPE_F16) {
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
i10 += ne00 * ir0;
|
||||
while (i10 >= ne0) {
|
||||
i10 -= ne0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||
|
||||
*(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
|
||||
|
||||
if (++i10 == ne0) {
|
||||
i10 = 0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
i10 += ne00 * (ne01 - ir1);
|
||||
while (i10 >= ne0) {
|
||||
i10 -= ne0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (dst->type == GGML_TYPE_F32) {
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
i10 += ne00 * ir0;
|
||||
while (i10 >= ne0) {
|
||||
i10 -= ne0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||
|
||||
*(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
|
||||
|
||||
if (++i10 == ne0) {
|
||||
i10 = 0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
i10 += ne00 * (ne01 - ir1);
|
||||
while (i10 >= ne0) {
|
||||
i10 -= ne0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
|
||||
}
|
||||
|
||||
quantize_row_q(src0_f32, dst_ptr + id, ne00);
|
||||
id += rs;
|
||||
}
|
||||
id += rs * (ne01 - ir1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("fatal error"); // TODO: implement
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_dup_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
const int ith = params->ith; // thread index
|
||||
const int nth = params->nth; // number of threads
|
||||
|
||||
// parallelize by rows
|
||||
const int nr = ne01;
|
||||
// number of rows per thread
|
||||
const int dr = (nr + nth - 1) / nth;
|
||||
// row range for this thread
|
||||
const int ir0 = dr * ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
if (src0->type == dst->type &&
|
||||
ne00 == ne0 &&
|
||||
nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
|
||||
// copy by rows
|
||||
const size_t rs = ne00*nb00;
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
||||
memcpy(
|
||||
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
|
||||
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
|
||||
rs);
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (ggml_is_contiguous(dst)) {
|
||||
// TODO: simplify
|
||||
if (nb00 == sizeof(float)) {
|
||||
if (ggml_get_type_traits_cpu(dst->type)->from_float) {
|
||||
ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
|
||||
|
||||
size_t id = 0;
|
||||
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
|
||||
char * dst_ptr = (char *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
id += rs * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
from_float(src0_ptr, dst_ptr + id, ne00);
|
||||
id += rs;
|
||||
}
|
||||
id += rs * (ne01 - ir1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("fatal error"); // TODO: implement
|
||||
}
|
||||
} else {
|
||||
//printf("%s: this is not optimal - fix me\n", __func__);
|
||||
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
size_t id = 0;
|
||||
float * dst_ptr = (float *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
id += ne00 * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
dst_ptr[id] = *src0_ptr;
|
||||
id++;
|
||||
}
|
||||
}
|
||||
id += ne00 * (ne01 - ir1);
|
||||
}
|
||||
}
|
||||
} else if (dst->type == GGML_TYPE_F16) {
|
||||
size_t id = 0;
|
||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
id += ne00 * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
dst_ptr[id] = GGML_CPU_FP32_TO_FP16(*src0_ptr);
|
||||
id++;
|
||||
}
|
||||
}
|
||||
id += ne00 * (ne01 - ir1);
|
||||
}
|
||||
}
|
||||
} else if (dst->type == GGML_TYPE_BF16) {
|
||||
size_t id = 0;
|
||||
ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
id += ne00 * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
|
||||
id++;
|
||||
}
|
||||
}
|
||||
id += ne00 * (ne01 - ir1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("fatal error"); // TODO: implement
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// dst counters
|
||||
|
||||
int64_t i10 = 0;
|
||||
int64_t i11 = 0;
|
||||
int64_t i12 = 0;
|
||||
int64_t i13 = 0;
|
||||
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
i10 += ne00 * ir0;
|
||||
while (i10 >= ne0) {
|
||||
i10 -= ne0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||
|
||||
memcpy(dst_ptr, src0_ptr, sizeof(float));
|
||||
|
||||
if (++i10 == ne0) {
|
||||
i10 = 0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
i10 += ne00 * (ne01 - ir1);
|
||||
while (i10 >= ne0) {
|
||||
i10 -= ne0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (dst->type == GGML_TYPE_F16) {
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
i10 += ne00 * ir0;
|
||||
while (i10 >= ne0) {
|
||||
i10 -= ne0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||
|
||||
*(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr);
|
||||
|
||||
if (++i10 == ne0) {
|
||||
i10 = 0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
i10 += ne00 * (ne01 - ir1);
|
||||
while (i10 >= ne0) {
|
||||
i10 -= ne0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (dst->type == GGML_TYPE_BF16) {
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
i10 += ne00 * ir0;
|
||||
while (i10 >= ne0) {
|
||||
i10 -= ne0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||
|
||||
*(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
|
||||
|
||||
if (++i10 == ne0) {
|
||||
i10 = 0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
i10 += ne00 * (ne01 - ir1);
|
||||
while (i10 >= ne0) {
|
||||
i10 -= ne0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("fatal error"); // TODO: implement
|
||||
// printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
|
||||
GGML_ABORT("not implemented");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1102,7 +470,7 @@ static void ggml_compute_forward_dup_bytes(
    }
}

static void ggml_compute_forward_dup_q(
static void ggml_compute_forward_dup_from_q(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

@@ -1167,20 +535,35 @@ void ggml_compute_forward_dup(
    switch (src0->type) {
        case GGML_TYPE_F16:
            {
                ggml_compute_forward_dup_f16(params, dst);
                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_fp16_t>(params, dst);
                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_bf16_t>(params, dst);
                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_fp16_t, float      >(params, dst);
                else                                  ggml_compute_forward_dup_to_q<ggml_fp16_t>(params, dst);
            } break;
        case GGML_TYPE_BF16:
            {
                ggml_compute_forward_dup_bf16(params, dst);
                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_fp16_t>(params, dst);
                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(params, dst);
                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_bf16_t, float      >(params, dst);
                else                                  ggml_compute_forward_dup_to_q<ggml_bf16_t>(params, dst);
            } break;
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_dup_f32(params, dst);
                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<float, ggml_fp16_t>(params, dst);
                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<float, ggml_bf16_t>(params, dst);
                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<float, float      >(params, dst);
                else if (dst->type == GGML_TYPE_I32)  ggml_compute_forward_dup_flt<float, int32_t    >(params, dst);
                else                                  ggml_compute_forward_dup_to_q<float>(params, dst);
            } break;
        case GGML_TYPE_I32:
            {
                if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
                else                            GGML_ABORT("not implemented");
            } break;
        default:
            {
                if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
                    ggml_compute_forward_dup_q(params, dst);
                    ggml_compute_forward_dup_from_q(params, dst);
                    break;
                }
                GGML_ABORT("fatal error");
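For readers skimming the diff: the refactor above collapses the per-type dup implementations into one template dispatched on the (source, destination) element type pair. A minimal host-side sketch of that pattern, with hypothetical names (copy_convert, Type) rather than ggml's actual helpers:

// Illustrative only; ggml converts through its FP16/BF16 helper macros instead of static_cast.
#include <cstdint>
#include <cstdio>
#include <vector>

template <typename src_t, typename dst_t>
static void copy_convert(const src_t * src, dst_t * dst, int64_t n) {
    for (int64_t i = 0; i < n; i++) {
        dst[i] = static_cast<dst_t>(src[i]); // one loop body serves every type pair
    }
}

enum class Type { F32, I32 };

// One templated implementation plus a small if/else ladder per source type replaces
// the separate per-type duplication functions.
static void dup_dispatch(Type dst_type, const float * src, void * dst, int64_t n) {
    if (dst_type == Type::F32) copy_convert<float, float>  (src, (float   *) dst, n);
    else                       copy_convert<float, int32_t>(src, (int32_t *) dst, n);
}

int main() {
    std::vector<float>   src = {1.5f, 2.5f, 3.5f};
    std::vector<int32_t> dst(src.size());
    dup_dispatch(Type::I32, src.data(), dst.data(), (int64_t) src.size());
    printf("%d %d %d\n", (int) dst[0], (int) dst[1], (int) dst[2]); // 1 2 3 (truncation)
    return 0;
}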
@@ -8438,7 +7821,7 @@ static void ggml_compute_forward_timestep_embedding_f32(
            embed_data[j + half] = sinf(arg);
        }
        if (dim % 2 != 0 && ith == 0) {
            embed_data[dim] = 0.f;
            embed_data[2 * half] = 0.f;
        }
    }
}

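The one-line fix above moves the odd-dimension zero padding from index dim to index 2*half. A minimal sketch of the usual sinusoidal timestep-embedding layout, assuming the standard frequency formula (the exact formula ggml uses is not visible in this hunk):

#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<float> timestep_embedding(float t, int dim, int max_period = 10000) {
    const int half = dim / 2;
    std::vector<float> out(dim, 0.0f);
    for (int j = 0; j < half; j++) {
        const float freq = std::exp(-std::log((float) max_period) * (float) j / (float) half);
        const float arg  = t * freq;
        out[j]        = std::cos(arg); // cos values occupy [0, half)
        out[j + half] = std::sin(arg); // sin values occupy [half, 2*half)
    }
    // For odd dim the padding element sits at index 2*half (== dim - 1), which is the
    // index the fix above writes; index dim would be one past the row.
    return out;
}

int main() {
    const std::vector<float> e = timestep_embedding(25.0f, 5);
    printf("padding element: %f\n", e[4]); // stays 0
    return 0;
}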
@@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND)
        list(APPEND GGML_SOURCES_CUDA ${SRCS})
        file(GLOB SRCS "template-instances/mmq*.cu")
        list(APPEND GGML_SOURCES_CUDA ${SRCS})
        file(GLOB SRCS "template-instances/mmf*.cu")
        list(APPEND GGML_SOURCES_CUDA ${SRCS})

        if (GGML_CUDA_FA_ALL_QUANTS)
            file(GLOB SRCS "template-instances/fattn-vec*.cu")

@@ -23,28 +23,44 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
    return a / b;
}

template <float (*bin_op)(const float, const float),
          typename src0_t,
          typename src1_t,
          typename dst_t,
          typename... src1_ptrs>
static __global__ void k_bin_bcast(const src0_t * src0,
                                   const src1_t * src1,
                                   dst_t * dst,
                                   const int ne0,
                                   const int ne1,
                                   const int ne2,
                                   const uint3 ne3,
                                   const uint3 ne10,
                                   const uint3 ne11,
                                   const uint3 ne12,
                                   const uint3 ne13,
                                   /*int s0, */ const int s1,
                                   const int s2,
                                   const int s3,
                                   /*int s00,*/ const int s01,
                                   const int s02,
                                   const int s03,
                                   /*int s10,*/ const int s11,
                                   const int s12,
                                   const int s13,
                                   src1_ptrs... src1s) {
    const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
    const uint32_t i1  = (blockDim.y * blockIdx.y + threadIdx.y);
    const uint32_t i2  = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
    const uint32_t i3  = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);


template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
        const int ne0, const int ne1, const int ne2, const int ne3,
        const int ne10, const int ne11, const int ne12, const int ne13,
        /*int s0, */ const int s1, const int s2, const int s3,
        /*int s00,*/ const int s01, const int s02, const int s03,
        /*int s10,*/ const int s11, const int s12, const int s13,
        src1_ptrs... src1s) {
    const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
    const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
    const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
    const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;

    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) {
        return;
    }

    const int i11 = i1 % ne11;
    const int i12 = i2 % ne12;
    const int i13 = i3 % ne13;
    const uint32_t i11 = fastmodulo(i1, ne11);
    const uint32_t i12 = fastmodulo(i2, ne12);
    const uint32_t i13 = fastmodulo(i3, ne13);

    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -53,8 +69,8 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
    const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
    dst_t * dst_row = dst + i_dst;

    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
        const int i10 = i0 % ne10;
    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
        const uint32_t i10 = fastmodulo(i0, ne10);

        float result = src0_row ? (float) src0_row[i0] : 0.0f;
        if constexpr (sizeof...(src1_ptrs) > 0) {
@@ -67,28 +83,48 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
    }
}

template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
        const int ne0, const int ne1, const int ne2,const int ne3,
        const int ne10, const int ne11, const int ne12, const int ne13,
        /*int s0, */ const int s1, const int s2, const int s3,
        /*int s00,*/ const int s01, const int s02, const int s03,
        /*int s10,*/ const int s11, const int s12, const int s13,
        src1_ptrs ... src1s) {
template <float (*bin_op)(const float, const float),
          typename src0_t,
          typename src1_t,
          typename dst_t,
          typename... src1_ptrs>
static __global__ void k_bin_bcast_unravel(const src0_t * src0,
                                           const src1_t * src1,
                                           dst_t * dst,
                                           const uint3 ne0,
                                           const uint3 ne1,
                                           const uint3 ne2,
                                           const uint32_t ne3,
                                           const uint3 prod_012,
                                           const uint3 prod_01,
                                           const uint3 ne10,
                                           const uint3 ne11,
                                           const uint3 ne12,
                                           const uint3 ne13,
                                           /*int s0, */ const int s1,
                                           const int s2,
                                           const int s3,
                                           /*int s00,*/ const int s01,
                                           const int s02,
                                           const int s03,
                                           /*int s10,*/ const int s11,
                                           const int s12,
                                           const int s13,
                                           src1_ptrs... src1s) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    const int i3 = i/(ne2*ne1*ne0);
    const int i2 = (i/(ne1*ne0)) % ne2;
    const int i1 = (i/ne0) % ne1;
    const int i0 = i % ne0;
    const uint32_t i3 = fastdiv(i, prod_012);
    const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
    const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
    const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;

    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
    if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
        return;
    }

    const int i11 = i1 % ne11;
    const int i12 = i2 % ne12;
    const int i13 = i3 % ne13;
    const int i11 = fastmodulo(i1, ne11);
    const int i12 = fastmodulo(i2, ne12);
    const int i13 = fastmodulo(i3, ne13);

    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -97,7 +133,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t *
    const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
    dst_t * dst_row = dst + i_dst;

    const int i10 = i0 % ne10;
    const int i10 = fastmodulo(i0, ne10);

    float result = src0_row ? (float) src0_row[i0] : 0.0f;
    if constexpr (sizeof...(src1_ptrs) > 0) {
@@ -170,11 +206,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
    //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
    //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);

    int64_t ne10 = cne1[0];
    int64_t ne11 = cne1[1];
    int64_t ne12 = cne1[2];
    int64_t ne13 = cne1[3];

    size_t nb0 = cnb[0];
    size_t nb1 = cnb[1];
    size_t nb2 = cnb[2];
@@ -233,48 +264,51 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
        block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
        block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);

        dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x,
                        (ne1 + block_dims.y - 1) / block_dims.y,
        dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y,
                        (ne2 * ne3 + block_dims.z - 1) / block_dims.z);

        const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]);
        const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]);
        const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
        const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);

        if (block_nums.z > 65535) {
            int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
            int block_num        = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
            const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
            const uint3 prod_01  = init_fastdiv_values((uint32_t) (ne0 * ne1));
            const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
            const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
            const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);

            if constexpr (sizeof...(I) > 0) {
                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
                    <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
                                                           ne0, ne1, ne2, ne3,
                                                           ne10, ne11, ne12, ne13,
                                                           /* s0, */ s1, s2, s3,
                                                           /* s00,*/ s01, s02, s03,
                                                           /* s10,*/ s11, s12,s13,
                                                           (const src1_t *) dst->src[I + 1]->data...);
                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
                    src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
                    ne12, ne13,
                    /* s0, */ s1, s2, s3,
                    /* s00,*/ s01, s02, s03,
                    /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
            } else {
                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
                    <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
                                                           ne0, ne1, ne2, ne3,
                                                           ne10, ne11, ne12, ne13,
                                                           /* s0, */ s1, s2, s3,
                                                           /* s00,*/ s01, s02, s03,
                                                           /* s10,*/ s11, s12,s13);
                    <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
                                                           ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
                                                           /* s0, */ s1, s2, s3,
                                                           /* s00,*/ s01, s02, s03,
                                                           /* s10,*/ s11, s12, s13);
            }
        } else {
            const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
            if constexpr (sizeof...(I) > 0) {
                k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
                    <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
                                                            ne0, ne1, ne2, ne3,
                                                            ne10, ne11, ne12, ne13,
                                                            /* s0, */ s1, s2, s3,
                                                            /* s00,*/ s01, s02, s03,
                                                            /* s10,*/ s11, s12,s13,
                                                            (const src1_t *) dst->src[I + 1]->data...);
                k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
                    src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
                    /* s0, */ s1, s2, s3,
                    /* s00,*/ s01, s02, s03,
                    /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
            } else {
                k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
                    <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
                                                            ne0, ne1, ne2, ne3,
                                                            ne10, ne11, ne12, ne13,
                                                            /* s0, */ s1, s2, s3,
                                                            /* s00,*/ s01, s02, s03,
                                                            /* s10,*/ s11, s12,s13);
                k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
                    src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
                    /* s0, */ s1, s2, s3,
                    /* s00,*/ s01, s02, s03,
                    /* s10,*/ s11, s12, s13);
            }
        }
    }

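The launch code above precomputes uint3 "fastdiv values" on the host so the kernels can divide and take remainders by runtime-constant tensor sizes with a multiply-high and a shift instead of an integer division. A host-side sketch of the underlying magic-multiply technique; the exact constants picked by init_fastdiv_values are an assumption here, and the arithmetic is kept in 64 bits to sidestep overflow (the device helpers use __umulhi-style 32-bit operations on index ranges they control):

#include <cstdint>
#include <cstdio>

struct fastdiv_vals { uint32_t mp, L, d; };

// Precompute <mp, L, d> for a fixed divisor d (done once on the host).
static fastdiv_vals make_fastdiv(uint32_t d) {
    uint32_t L = 0;
    while ((1ull << L) < d) L++;
    const uint32_t mp = (uint32_t) (((1ull << 32) * ((1ull << L) - d)) / d + 1);
    return {mp, L, d};
}

static uint32_t fastdiv(uint32_t n, fastdiv_vals v) {
    const uint64_t hi = ((uint64_t) n * v.mp) >> 32; // mulhi
    return (uint32_t) ((hi + n) >> v.L);
}

static uint32_t fastmodulo(uint32_t n, fastdiv_vals v) {
    return n - fastdiv(n, v) * v.d; // same identity the CUDA helper uses
}

int main() {
    const fastdiv_vals v = make_fastdiv(12);
    for (uint32_t n = 0; n < 100000; n++) {
        if (fastdiv(n, v) != n / 12 || fastmodulo(n, v) != n % 12) {
            printf("mismatch at %u\n", n);
            return 1;
        }
    }
    printf("fastdiv/fastmodulo match plain / and %% for d=12\n");
    return 0;
}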
@@ -75,6 +75,8 @@
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2)
#define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3)
#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)

// Moore Threads
@@ -325,6 +327,20 @@ static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
}

// Maximum number of bytes that can be copied in a single instruction.
static constexpr __device__ int ggml_cuda_get_max_cpy_bytes() {
#ifdef GGML_USE_HIP
    return 16;
#else
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
    return 16;
#else
    return 8;
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // GGML_USE_HIP
}


[[noreturn]]
static __device__ void no_device_code(
    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
@@ -545,6 +561,45 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
#endif // defined(GGML_USE_HIP)
}

static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) {
    acc += v*u;
}

static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) {
    acc += v.x*u.x;
    acc += v.y*u.y;
}

static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
    asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
#else
#ifdef FAST_FP16_AVAILABLE
    const float2 tmp = __half22float2(v*u);
    acc += tmp.x + tmp.y;
#else
    const float2 tmpv = __half22float2(v);
    const float2 tmpu = __half22float2(u);
    acc += tmpv.x * tmpu.x;
    acc += tmpv.y * tmpu.y;
#endif // FAST_FP16_AVAILABLE
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
}

// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
template <int nbytes>
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
    if constexpr (nbytes == 4) {
        *(int *) dst = *(const int *) src;
    } else if constexpr (nbytes == 8) {
        *(int2 *) dst = *(const int2 *) src;
    } else if constexpr (nbytes == 16) {
        *(int4 *) dst = *(const int4 *) src;
    } else {
        static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
    }
}

static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#if CUDART_VERSION >= 12080
    const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
@@ -597,6 +652,14 @@ static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fa
    return n - fastdiv(n, fastdiv_values) * fastdiv_values.z;
}

// Calculate both division and modulo at once, returns <n/divisor, n%divisor>
static __device__ __forceinline__ uint2 fast_div_modulo(uint32_t n, const uint3 fastdiv_values) {
    // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
    const uint32_t div_val = fastdiv(n, fastdiv_values);
    const uint32_t mod_val = n - div_val * fastdiv_values.z;
    return make_uint2(div_val, mod_val);
}

typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);

static __device__ __forceinline__ float get_alibi_slope(

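ggml_cuda_memcpy_1 above picks the transfer width at compile time so the compiler can emit a single wide, aligned load/store. A host-side C++ sketch of the same idea, with std::memcpy standing in for the int2/int4 stores of the CUDA version:

#include <cstring>
#include <cstdint>
#include <cstdio>

template <int nbytes>
static void memcpy_fixed(void * dst, const void * src) {
    static_assert(nbytes == 4 || nbytes == 8 || nbytes == 16, "bad nbytes");
    std::memcpy(dst, src, nbytes); // constant size lets the compiler lower this to one wide move
}

int main() {
    alignas(16) uint32_t src[4] = {1, 2, 3, 4};
    alignas(16) uint32_t dst[4] = {0};
    memcpy_fixed<16>(dst, src); // one 16-byte transfer instead of four 4-byte ones
    printf("%u %u %u %u\n", dst[0], dst[1], dst[2], dst[3]);
    return 0;
}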
@@ -38,6 +38,8 @@ template<typename dst_t, typename src_t>
        return __float2bfloat16(float(x));
    } else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
        return __bfloat162float(x);
    } else if constexpr(std::is_same_v<dst_t, int32_t>) {
        return int32_t(x);
    } else {
        return float(x);
    }

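The hunk above adds an int32_t branch to the compile-time conversion helper. A plain C++ sketch of that if constexpr dispatch, with host types in place of the CUDA half/bfloat types:

#include <cstdint>
#include <cstdio>
#include <type_traits>

template <typename dst_t, typename src_t>
static dst_t convert(src_t x) {
    if constexpr (std::is_same_v<dst_t, int32_t>) {
        return (int32_t) x;   // float -> I32 copies truncate like a C cast
    } else {
        return (dst_t) x;     // everything else goes through the usual numeric cast
    }
}

int main() {
    printf("%d\n", (int) convert<int32_t>(3.9f)); // 3
    printf("%f\n", convert<float>(7));            // 7.0
    return 0;
}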
@@ -374,6 +374,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
        ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
        ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else {
        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                   ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -437,6 +441,10 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
        return (void*) cpy_flt<cpy_1_flt<float, int32_t>>;
    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
        return (void*) cpy_flt<cpy_1_flt<int32_t, float>>;
    } else {
        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                   ggml_type_name(src0->type), ggml_type_name(src1->type));

@@ -647,9 +647,7 @@ static __global__ void flash_attn_stream_k_fixup(
}

template<int D> // D == head size
#if !defined(GGML_USE_HIP)
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP)
static __global__ void flash_attn_combine_results(
        const float * __restrict__ VKQ_parts,
        const float2 * __restrict__ VKQ_meta,
@@ -692,10 +690,7 @@ static __global__ void flash_attn_combine_results(
    float VKQ_numerator   = 0.0f;
    float VKQ_denominator = 0.0f;
    for (int l = 0; l < parallel_blocks; ++l) {
        const float diff = meta[l].x - kqmax;
        float KQ_max_scale = expf(diff);
        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
        *((uint32_t *) &KQ_max_scale) &= ftz_mask;
        const float KQ_max_scale = expf(meta[l].x - kqmax);

        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
        VKQ_denominator += KQ_max_scale * meta[l].y;
@@ -836,11 +831,10 @@ void launch_fattn(
        CUDA_CHECK(cudaGetLastError());
    }

    int parallel_blocks = 1;

    const dim3 block_dim(warp_size, nwarps, 1);
    int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
    int parallel_blocks = max_blocks_per_sm;

    dim3 blocks_num;
    if (stream_k) {
@@ -862,9 +856,6 @@
        GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
        const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.

        // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
        parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);

        // parallel_blocks must not be larger than what the tensor size allows:
        parallel_blocks = std::min(parallel_blocks, ntiles_KQ);


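flash_attn_combine_results above merges the partial outputs produced by parallel_blocks independent blocks: each block reports a partial weighted sum plus a (running max, sum of exponentials) pair, and the partials are rescaled to the global maximum before being summed and normalized. A host-side C++ sketch of that merge, with a hypothetical one-value-per-block layout instead of the D-wide vectors used in the kernel:

#include <cmath>
#include <cstdio>
#include <vector>

struct partial { float vkq; float kq_max; float kq_sum; };

static float combine(const std::vector<partial> & parts) {
    float kq_max = -INFINITY;
    for (const partial & p : parts) {
        if (p.kq_max > kq_max) kq_max = p.kq_max; // global softmax maximum
    }

    float numerator = 0.0f, denominator = 0.0f;
    for (const partial & p : parts) {
        const float scale = std::exp(p.kq_max - kq_max); // <= 1, so no overflow
        numerator   += scale * p.vkq;    // rescaled partial weighted sum
        denominator += scale * p.kq_sum; // rescaled partial sum of exp()
    }
    return numerator / denominator;
}

int main() {
    const std::vector<partial> parts = {{2.0f, 1.0f, 1.5f}, {6.0f, 3.0f, 2.0f}};
    printf("%f\n", combine(parts)); // properly reweighted combination of the two blocks
    return 0;
}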
@@ -2,17 +2,30 @@
#include "fattn-common.cuh"
#include "fattn-tile.cuh"

#define FATTN_TILE_NTHREADS 256
// kq_stride == number of KQ rows to process per iteration
// kq_nbatch == number of K columns to load in parallel for KQ calculation

static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) {
    if (GGML_CUDA_CC_IS_AMD(cc)) {
        if (GGML_CUDA_CC_IS_RDNA(cc)) {
            switch (D) {
                case 64:
                    return 128;
                case 128:
                case 256:
                    return ncols <= 16 ? 128 : 64;
                default:
                    GGML_ABORT("fatal error");
                    return -1;
            }
        }
        switch (D) {
            case 64:
                return ncols <= 16 ? 32 : 64;
                return ncols == 32 ? 128 : 64;
            case 128:
                return ncols <= 16 ? 64 : warp_size;
                return ncols == 32 ? 64 : 32;
            case 256:
                return 64;
                return 32;
            default:
                GGML_ABORT("fatal error");
                return -1;
@@ -22,7 +35,6 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
        switch (D) {
            case 64:
            case 128:
                return 128;
            case 256:
                return ncols <= 16 ? 128 : 64;
            default:
@@ -41,26 +53,38 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
            GGML_ABORT("fatal error");
            return -1;
    }
    GGML_UNUSED(warp_size);
}

static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
#ifdef GGML_USE_HIP
#ifdef RDNA
    switch (D) {
        case 64:
            return ncols <= 16 ? 32 : 64;
            return 128;
        case 128:
            return ncols <= 16 ? 64 : warp_size;
        case 256:
            return 64;
            return ncols <= 16 ? 128 : 64;
        default:
            return -1;
    }
#else
    switch (D) {
        case 64:
            return ncols == 32 ? 128 : 64;
        case 128:
            return ncols == 32 ? 64 : 32;
        case 256:
            return 32;
        default:
            return -1;
    }
#endif // RDNA
#else
#ifdef FAST_FP16_AVAILABLE
    switch (D) {
        case 64:
        case 128:
            return 128;
        case 256:
            return ncols <= 16 ? 128 : 64;
        default:
@@ -88,9 +112,8 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
        case 64:
            return 64;
        case 128:
            return ncols <= 16 ? 2*warp_size : 128;
        case 256:
            return ncols <= 16 ? 128 : 2*warp_size;
            return 128;
        default:
            return -1;
    }
@@ -100,9 +123,8 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
        case 64:
            return 64;
        case 128:
            return ncols <= 16 ? 128 : 64;
        case 256:
            return ncols <= 16 ? 64 : 128;
            return 128;
        default:
            return -1;
    }
@@ -122,12 +144,27 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
    GGML_UNUSED_VARS(ncols, warp_size);
}

template<int D, int ncols, bool use_logit_softcap> // D == head size
#ifdef GGML_USE_HIP
__launch_bounds__(FATTN_TILE_NTHREADS, 1)
static int fattn_tile_get_nthreads_host(const int cc, const int ncols) {
    return 256;
    GGML_UNUSED_VARS(cc, ncols);
}

static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) {
    return 256;
    GGML_UNUSED(ncols);
}

static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) {
#ifdef RDNA
    return 3;
#else
__launch_bounds__(FATTN_TILE_NTHREADS, 2)
#endif // GGML_USE_HIP
    return ncols <= 16 ? 3 : 2;
#endif // RDNA
    GGML_UNUSED(ncols);
}

template<int D, int ncols, bool use_logit_softcap> // D == head size
__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols))
static __global__ void flash_attn_tile(
        const char * __restrict__ Q,
        const char * __restrict__ K,
@@ -173,7 +210,7 @@ static __global__ void flash_attn_tile(
    }

    constexpr int warp_size = 32;
    constexpr int nwarps    = FATTN_TILE_NTHREADS / warp_size;
    constexpr int nwarps    = fattn_tile_get_nthreads_device(ncols) / warp_size;
    constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size);
    static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size.");
    constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size);
@@ -186,97 +223,140 @@ static __global__ void flash_attn_tile(
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const float * sinksf = (const float *) (sinks);
|
||||
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const float * sinksf = (const float *) (sinks);
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
|
||||
__shared__ float KQ[ncols][kq_stride];
|
||||
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
constexpr int cpw = ncols/nwarps; // cols per warp
|
||||
|
||||
// softmax_iter_j == number of KQ columns for which to calculate softmax in parallel.
|
||||
// KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes.
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
|
||||
|
||||
__shared__ half KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
|
||||
__shared__ half2 Q_tmp[ncols][D/2];
|
||||
__shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + 1)]; // Padded to avoid memory bank conflicts.
|
||||
half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
__shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
||||
half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
#else
|
||||
constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
|
||||
|
||||
__shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
|
||||
__shared__ float Q_tmp[ncols][D];
|
||||
__shared__ float KV_tmp_f[kq_stride * (kq_nbatch + 1)]; // Padded to avoid memory bank conflicts.
|
||||
float2 * KV_tmp_f2 = (float2 *) KV_tmp_f;
|
||||
float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
__shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
||||
float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j");
|
||||
|
||||
|
||||
float kqmax[ncols/nwarps];
|
||||
float KQ_max[cpw];
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
kqmax[j0/nwarps] = -FLT_MAX/2.0f;
|
||||
KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
|
||||
}
|
||||
float kqsum[ncols/nwarps] = {0.0f};
|
||||
float KQ_sum[cpw] = {0.0f};
|
||||
|
||||
// Load Q data, convert to FP16 if fast.
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
const int j = j0 + threadIdx.y*cpw;
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
float tmp_f[cpy_ne_D] = {0.0f};
|
||||
if (ic0 + j < ne01) {
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f)>(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0 + threadIdx.x] : make_float2(0.0f, 0.0f);
|
||||
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
|
||||
tmp_f[i1] *= scale;
|
||||
}
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
Q_tmp[j][i0 + threadIdx.x] = make_half2(tmp.x * scale, tmp.y * scale);
|
||||
half2 tmp_h2[cpy_ne_D/2];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
|
||||
tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2);
|
||||
#else
|
||||
Q_tmp[j][2*i0 + threadIdx.x] = tmp.x * scale;
|
||||
Q_tmp[j][2*i0 + warp_size + threadIdx.x] = tmp.y * scale;
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f)> (&Q_tmp[j][i0 + threadIdx.x* cpy_ne_D], tmp_f);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Main loop over KV cache:
|
||||
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
|
||||
for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
float kqmax_new[ncols/nwarps];
|
||||
float KQ_max_new[cpw];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/nwarps; ++j) {
|
||||
kqmax_new[j] = kqmax[j];
|
||||
for (int j = 0; j < cpw; ++j) {
|
||||
KQ_max_new[j] = KQ_max[j];
|
||||
}
|
||||
|
||||
float sum[kq_stride/warp_size][ncols/nwarps] = {{0.0f}};
|
||||
float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication.
|
||||
|
||||
// KQ = K @ Q matrix multiplication:
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) {
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) {
|
||||
const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x];
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1 + threadIdx.x] = tmp_h2;
|
||||
#else
|
||||
const float2 tmp_f2 = __half22float2(tmp_h2);
|
||||
KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x;
|
||||
KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size);
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_kqnb*4>(
|
||||
&KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb],
|
||||
&K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]);
|
||||
}
|
||||
#else
|
||||
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size;
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) {
|
||||
half2 tmp_h2[cpy_ne_kqnb/2];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
|
||||
tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]);
|
||||
|
||||
float2 tmp_f2[cpy_ne_kqnb/2];
|
||||
#pragma unroll
|
||||
for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) {
|
||||
tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
|
||||
&KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; ++k_KQ_1) {
|
||||
half2 K_k[kq_stride/warp_size];
|
||||
half2 Q_k[ncols/nwarps];
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
|
||||
half2 K_k[kq_stride/warp_size][cpy_ne];
|
||||
half2 Q_k[cpw][cpy_ne];
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; ++k_KQ_1) {
|
||||
float K_k[kq_stride/warp_size];
|
||||
float Q_k[ncols/nwarps];
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
|
||||
float K_k[kq_stride/warp_size][cpy_ne];
|
||||
float Q_k[cpw][cpy_ne];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
#pragma unroll
|
||||
@@ -284,32 +364,30 @@ static __global__ void flash_attn_tile(
|
||||
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
K_k[i_KQ_0/warp_size] = KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1];
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
|
||||
#else
|
||||
K_k[i_KQ_0/warp_size] = KV_tmp_f [i_KQ*(kq_nbatch + 1) + k_KQ_1];
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1];
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
|
||||
#else
|
||||
Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0 + k_KQ_1];
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const float2 tmp = __half22float2(K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps]);
|
||||
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += tmp.x + tmp.y;
|
||||
#else
|
||||
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < cpy_ne; ++k) {
|
||||
ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -319,64 +397,77 @@ static __global__ void flash_attn_tile(
|
||||
}
|
||||
}
|
||||
|
||||
// Apply logit softcap, mask, update KQ_max:
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
|
||||
|
||||
if (use_logit_softcap) {
|
||||
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/warp_size][j_KQ_0/nwarps]);
|
||||
KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
|
||||
}
|
||||
|
||||
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
|
||||
kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/warp_size][j_KQ_0/nwarps]);
|
||||
|
||||
KQ[j_KQ][i_KQ] = sum[i_KQ_0/warp_size][j_KQ_0/nwarps];
|
||||
KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
half tmp[kq_stride/warp_size][softmax_iter_j];
|
||||
#else
|
||||
float tmp[kq_stride/warp_size][softmax_iter_j];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
kqmax_new[j0/nwarps] = warp_reduce_max<warp_size>(kqmax_new[j0/nwarps]);
|
||||
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]);
|
||||
kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
|
||||
KQ_max_new[j0+j1] = warp_reduce_max<warp_size>(KQ_max_new[j0+j1]);
|
||||
const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]);
|
||||
KQ_max[j0+j1] = KQ_max_new[j0+j1];
|
||||
|
||||
float KQ_sum_add = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
|
||||
const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]);
|
||||
KQ_sum_add += val;
|
||||
tmp[i0/warp_size][j1] = val;
|
||||
}
|
||||
KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
float kqsum_add = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
const float diff = KQ[j][i] - kqmax[j0/nwarps];
|
||||
const float val = expf(diff);
|
||||
kqsum_add += val;
|
||||
KQ[j][i] = val;
|
||||
ggml_cuda_memcpy_1<sizeof(tmp[0])>(
|
||||
KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]);
|
||||
}
|
||||
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D;
|
||||
// VKQ = V @ KQ matrix multiplication:
|
||||
constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K.
|
||||
static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter");
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) {
|
||||
@@ -384,66 +475,96 @@ static __global__ void flash_attn_tile(
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) {
|
||||
const int k_tile = k1 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
const half2 tmp = V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i];
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
KV_tmp_h2[k_tile*(D/2) + i] = tmp;
|
||||
#else
|
||||
KV_tmp_f2[k_tile*(D/2) + i] = __half22float2(tmp);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(
|
||||
&KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D],
|
||||
&V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#else
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
half2 tmp_h2[cpy_ne_D/2];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
|
||||
tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]);
|
||||
|
||||
float2 tmp_f2[cpy_ne_D/2];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
|
||||
tmp_f2[i1] = __half22float2(tmp_h2[i1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
|
||||
&KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
#pragma unroll
        for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
            half2 V_k[(D/2)/warp_size];
            half2 KQ_k[cpw];

            constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
                ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]);
            }
#pragma unroll
            for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
                const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);

                half tmp[softmax_iter_j];
                ggml_cuda_memcpy_1<softmax_iter_j*sizeof(half)>(
                    &tmp, KQ[j][k0 + k1]);
#pragma unroll
                for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
                    KQ_k[j0+j1] = __half2half2(tmp[j1]);
                }
            }

#pragma unroll
            for (int i0 = 0; i0 < D/2; i0 += warp_size) {
#pragma unroll
                for (int j0 = 0; j0 < cpw; ++j0) {
                    VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0];
                }
            }
        }
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
|
||||
float2 V_k[(D/2)/warp_size];
|
||||
float KQ_k[cpw];
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
|
||||
|
||||
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(float)>(
|
||||
&KQ_k[j0], KQ[j][k0 + k1]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0];
|
||||
VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
@@ -455,69 +576,92 @@ static __global__ void flash_attn_tile(
|
||||
const float sink = sinksf[head];
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
|
||||
kqmax_new_j = warp_reduce_max<warp_size>(kqmax_new_j);
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
float KQ_max_new_j = fmaxf(KQ_max[j0], sink);
|
||||
KQ_max_new_j = warp_reduce_max<warp_size>(KQ_max_new_j);
|
||||
|
||||
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
|
||||
kqmax[j0/nwarps] = kqmax_new_j;
|
||||
const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j);
|
||||
KQ_max[j0] = KQ_max_new_j;
|
||||
|
||||
const float val = expf(sink - kqmax[j0/nwarps]);
|
||||
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
|
||||
const float val = expf(sink - KQ_max[j0]);
|
||||
KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale;
|
||||
if (threadIdx.x == 0) {
|
||||
kqsum[j0/nwarps] += val;
|
||||
KQ_sum[j0] += val;
|
||||
}
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2;
|
||||
VKQ[j0][i0/warp_size] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale;
|
||||
VKQ[j0][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0][i0/warp_size].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
float2 * dst2 = (float2 *) dst;
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
||||
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
KQ_sum[j_VKQ_0] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ_0]);
|
||||
}
|
||||
if (gridDim.y == 1) {
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (D/2)/warp_size; ++i) {
|
||||
VKQ[j_VKQ_0][i] *= KQ_sum_j_inv;
|
||||
}
|
||||
#else
|
||||
const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (D/2)/warp_size; ++i) {
|
||||
VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv;
|
||||
VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
// Write back results:
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw;
|
||||
|
||||
if (ic0 + j_VKQ >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
float kqsum_j = kqsum[j_VKQ_0/nwarps];
|
||||
kqsum_j = warp_reduce_sum<warp_size>(kqsum_j);
|
||||
|
||||
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D/2; i00 += warp_size) {
|
||||
const int i0 = i00 + threadIdx.x;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
float2 dst_val = __half22float2(VKQ[j_VKQ_0/nwarps][i0/warp_size]);
|
||||
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
float2 tmp[cpy_ne_D];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
|
||||
tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
|
||||
}
|
||||
#else
|
||||
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/warp_size];
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(
|
||||
&dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
if (gridDim.y == 1) {
|
||||
dst_val.x /= kqsum_j;
|
||||
dst_val.y /= kqsum_j;
|
||||
}
|
||||
dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && threadIdx.x == 0) {
|
||||
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]);
|
||||
}
|
||||
}
|
||||
#else
|
||||
@@ -538,15 +682,29 @@ template <int D, bool use_logit_softcap>
|
||||
static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
    const int id        = ggml_cuda_get_device();
    const int cc        = ggml_cuda_info().devices[id].cc;
    const int warp_size = 32;
|
||||
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
|
||||
#ifdef GGML_USE_HIP
|
||||
if constexpr (D <= 128) {
|
||||
if (Q->ne[1] > 32) {
|
||||
constexpr int cols_per_block = 64;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif // GGML_USE_HIP
|
||||
|
||||
if (Q->ne[1] > 16) {
|
||||
constexpr int cols_per_block = 32;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
@@ -555,6 +713,7 @@ static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 16;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
|
||||
@@ -2,39 +2,39 @@
|
||||
#include "dequantize.cuh"
|
||||
#include "convert.cuh"
|
||||
|
||||
#define MAX_GRIDDIM_Y 65535
|
||||
|
||||
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void k_get_rows(
        const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
        const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
        /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
        /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
        /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {

    for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
        for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
            // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
            const int i10 = blockIdx.x;
            const int i11 = z / ne12; // TODO fastdiv
            const int i12 = z % ne12;

            const int i01 = src1[i10*s10 + i11*s11 + i12*s12];

            dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
            const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;

            const int ib       = i00/qk;          // block index
            const int iqs      = (i00%qk)/qr;     // quant index
            const int iybs     = i00 - i00%qk;    // dst block start index
            const int y_offset = qr == 1 ? 1 : qk/2;

            // dequantize
            float2 v;
            dequantize_kernel(src0_row, ib, iqs, v);

            dst_row[iybs + iqs + 0]        = ggml_cuda_cast<dst_t>(v.x);
            dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
        }
    }
}
|
||||
|
||||
@@ -42,27 +42,29 @@ template<typename src0_t, typename dst_t>
static __global__ void k_get_rows_float(
        const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
        const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
        /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
        /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
        /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {

    for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
        for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
            // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
            const int i10 = blockIdx.x;
            const int i11 = z / ne12; // TODO fastdiv
            const int i12 = z % ne12;

            if (i00 >= ne00) {
                return;
            }

            const int i01 = src1[i10*s10 + i11*s11 + i12*s12];

            dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
            const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);

            dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
        }
    }
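
The launch-side clamps further down (the MIN(..., UINT16_MAX) terms in get_rows_cuda_q and get_rows_cuda_float) pair with the new outer z loop above: CUDA caps gridDim.y and gridDim.z at 65535, so the grid is clamped and each block strides over any remaining ne11*ne12 slices. A minimal stand-alone sketch of the same clamp-plus-grid-stride pattern, with hypothetical names that are not from the repo:

// Hypothetical illustration only; names do not come from llama.cpp.
#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

__global__ void copy_slices(const float * src, float * dst, int64_t nz, int64_t n) {
    // gridDim.z may have been clamped to 65535, so stride over z to cover every slice.
    for (int64_t z = blockIdx.z; z < nz; z += gridDim.z) {
        for (int64_t i = blockIdx.x*blockDim.x + threadIdx.x; i < n; i += gridDim.x*blockDim.x) {
            dst[z*n + i] = src[z*n + i];
        }
    }
}

int main() {
    const int64_t nz = 100000; // more slices than the 65535 limit on gridDim.z
    const int64_t n  = 256;
    float * src = nullptr;
    float * dst = nullptr;
    cudaMalloc(&src, nz*n*sizeof(float));
    cudaMalloc(&dst, nz*n*sizeof(float));
    const dim3 block_dims(256, 1, 1);
    const dim3 block_nums(1, 1, (unsigned) std::min<int64_t>(nz, 65535)); // same idea as MIN(ne11*ne12, UINT16_MAX)
    copy_slices<<<block_nums, block_dims>>>(src, dst, nz, n);
    cudaDeviceSynchronize();
    cudaFree(src);
    cudaFree(dst);
    return 0;
}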
|
||||
|
||||
@@ -98,7 +100,7 @@ static void get_rows_cuda_q(
|
||||
cudaStream_t stream) {
|
||||
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
||||
const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
|
||||
const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);
|
||||
const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX));
|
||||
|
||||
// strides in elements
|
||||
// const size_t s0 = nb0 / sizeof(dst_t);
|
||||
@@ -116,7 +118,7 @@ static void get_rows_cuda_q(
|
||||
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ne00, /*ne01, ne02, ne03,*/
|
||||
/*ne10, ne11,*/ ne12, /*ne13,*/
|
||||
/*ne10,*/ ne11, ne12, /*ne13,*/
|
||||
/* s0,*/ s1, s2, s3,
|
||||
/* nb00,*/ nb01, nb02, nb03,
|
||||
s10, s11, s12/*, s13*/);
|
||||
@@ -131,7 +133,7 @@ static void get_rows_cuda_float(
|
||||
cudaStream_t stream) {
|
||||
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
||||
const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
|
||||
const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);
|
||||
const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX));
|
||||
|
||||
// strides in elements
|
||||
// const size_t s0 = nb0 / sizeof(dst_t);
|
||||
@@ -147,7 +149,7 @@ static void get_rows_cuda_float(
|
||||
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ne00, /*ne01, ne02, ne03,*/
|
||||
/*ne10, ne11,*/ ne12, /*ne13,*/
|
||||
/*ne10,*/ ne11, ne12, /*ne13,*/
|
||||
/* s0,*/ s1, s2, s3,
|
||||
/* nb00,*/ nb01, nb02, nb03,
|
||||
s10, s11, s12/*, s13*/);
|
||||
|
||||
@@ -2109,6 +2109,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2])) {
|
||||
ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
@@ -3135,6 +3140,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
|
||||
/* .graph_compute = */ ggml_backend_cuda_graph_compute,
|
||||
/* .event_record = */ ggml_backend_cuda_event_record,
|
||||
/* .event_wait = */ ggml_backend_cuda_event_wait,
|
||||
/* .graph_optimize = */ NULL,
|
||||
};
|
||||
|
||||
static ggml_guid_t ggml_backend_cuda_guid() {
|
||||
@@ -3204,6 +3210,7 @@ struct ggml_backend_cuda_device_context {
|
||||
int device;
|
||||
std::string name;
|
||||
std::string description;
|
||||
std::string pci_bus_id;
|
||||
};
|
||||
|
||||
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
|
||||
@@ -3228,9 +3235,12 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
||||
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||
|
||||
props->name = ggml_backend_cuda_device_get_name(dev);
|
||||
props->description = ggml_backend_cuda_device_get_description(dev);
|
||||
props->type = ggml_backend_cuda_device_get_type(dev);
|
||||
props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
|
||||
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
|
||||
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
|
||||
@@ -3461,6 +3471,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
|
||||
return true;
|
||||
}
|
||||
@@ -3574,9 +3590,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_PAD:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
case GGML_OP_ARANGE:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
@@ -3792,6 +3808,10 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||
dev_ctx->description = prop.name;
|
||||
|
||||
char pci_bus_id[16] = {};
|
||||
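// Format the PCI identifier as domain:bus:device.function (function fixed to 0);
// this string is later exposed through props->device_id in device_get_props.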
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
|
||||
dev_ctx->pci_bus_id = pci_bus_id;
|
||||
|
||||
ggml_backend_dev_t dev = new ggml_backend_device {
|
||||
/* .iface = */ ggml_backend_cuda_device_interface,
|
||||
/* .reg = */ ®,
|
||||
|
||||
@@ -122,11 +122,14 @@ static __global__ void im2col_3d_kernel(
|
||||
int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW,
|
||||
int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW,
|
||||
int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH,
|
||||
int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
|
||||
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {
|
||||
const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (i >= IC_KD_KH_KW) {
|
||||
return;
|
||||
}
|
||||
GGML_UNUSED(N); GGML_UNUSED(OC); GGML_UNUSED(OH_OW); GGML_UNUSED(OD); GGML_UNUSED(OW); GGML_UNUSED(KD); GGML_UNUSED(KH);
|
||||
GGML_UNUSED(ID_IH_IW); GGML_UNUSED(IH_IW); GGML_UNUSED(IC_ID_IH_IW); GGML_UNUSED(OW_KD_KH_KW);
|
||||
|
||||
const int64_t iic = i / KD_KH_KW;
|
||||
const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW;
|
||||
@@ -148,7 +151,7 @@ static __global__ void im2col_3d_kernel(
|
||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
|
||||
dst[offset_dst] = 0.0f;
|
||||
} else {
|
||||
const int64_t offset_src = in*IC_ID_IH_IW + iic*ID_IH_IW + iid*IH_IW + iih*IW + iiw;
|
||||
const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x);
|
||||
dst[offset_dst] = src[offset_src];
|
||||
}
|
||||
}
|
||||
@@ -159,6 +162,7 @@ template <typename T>
|
||||
static void im2col_3d_cuda(const float * src, T* dst,
|
||||
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
|
||||
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
|
||||
int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
|
||||
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
|
||||
const int64_t OH_OW = OH*OW;
|
||||
const int64_t KD_KH_KW = KD*KH*KW;
|
||||
@@ -179,23 +183,30 @@ static void im2col_3d_cuda(const float * src, T* dst,
|
||||
OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
|
||||
IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,
|
||||
OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH,
|
||||
stride_q, stride_z, stride_y, stride_x,
|
||||
s0, s1, s2, p0, p1, p2, d0, d1, d2);
|
||||
}
|
||||
|
||||
static void im2col_3d_cuda_f16(const float * src, half * dst,
|
||||
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
|
||||
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
|
||||
int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
|
||||
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
|
||||
|
||||
    im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
                         stride_q, stride_z, stride_y, stride_x,
                         s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
|
||||
}
|
||||
|
||||
static void im2col_3d_cuda_f32(const float * src, float * dst,
|
||||
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
|
||||
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
|
||||
int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
|
||||
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
|
||||
|
||||
    im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
                          stride_q, stride_z, stride_y, stride_x,
                          s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
@@ -235,9 +246,19 @@ void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||
const int64_t OH = ne2;
|
||||
const int64_t OW = ne1;
|
||||
|
||||
const size_t es = ggml_element_size(src1);
|
||||
const int64_t stride_x = src1->nb[0] / es;
|
||||
const int64_t stride_y = src1->nb[1] / es;
|
||||
const int64_t stride_z = src1->nb[2] / es;
|
||||
const int64_t stride_q = src1->nb[3] / es;
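    // Strides come from src1->nb converted from bytes to elements, so the kernel indexes src1 by its
    // actual memory layout instead of the previously hard-coded contiguous IC*ID*IH*IW offsets.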
|
||||
|
||||
if(dst->type == GGML_TYPE_F16) {
|
||||
        im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
                           stride_q, stride_z, stride_y, stride_x,
                           s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
|
||||
} else {
|
||||
        im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
                           stride_q, stride_z, stride_y, stride_x,
                           s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#pragma once
|
||||
// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
|
||||
// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
|
||||
// The documentation for the PTX instructions can be found under:
|
||||
|
||||
@@ -1,343 +1,12 @@
|
||||
#include "ggml.h"
|
||||
#include "common.cuh"
|
||||
#include "mma.cuh"
|
||||
#include "mmf.cuh"
|
||||
|
||||
using namespace ggml_cuda_mma;
|
||||
|
||||
#define MMF_ROWS_PER_BLOCK 32
|
||||
|
||||
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
|
||||
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
|
||||
static __global__ void mul_mat_f(
|
||||
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
||||
const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst,
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile< 8, 8, T> tile_B;
|
||||
typedef tile<16, 8, float> tile_C;
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
constexpr int ntA = rows_per_block / tile_A::I;
|
||||
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
|
||||
|
||||
const int row0 = blockIdx.x * rows_per_block;
|
||||
const int channel_dst = blockIdx.y;
|
||||
const int channel_x = channel_dst / channel_ratio;
|
||||
const int channel_y = channel_dst;
|
||||
const int sample_dst = blockIdx.z;
|
||||
const int sample_x = sample_dst / sample_ratio;
|
||||
const int sample_y = sample_dst;
|
||||
|
||||
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ;
|
||||
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
|
||||
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
|
||||
|
||||
const float2 * y2 = (const float2 *) y;
|
||||
|
||||
extern __shared__ char data_mmv[];
|
||||
|
||||
tile_C C[ntA][ntB];
|
||||
|
||||
T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded);
|
||||
|
||||
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
|
||||
tile_A A[ntA][warp_size / tile_A::J];
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < tile_A::I; ++i) {
|
||||
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
|
||||
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int itB = 0; itB < ntB; ++itB) {
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
const int j = j0 + itB*tile_B::I;
|
||||
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
|
||||
}
|
||||
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
const int j = j0 + itB*tile_B::I;
|
||||
|
||||
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
}
|
||||
} else {
|
||||
static_assert(std::is_same_v<T, void>, "unsupported type");
|
||||
}
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
|
||||
tile_B B;
|
||||
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float * buf_iw = (float *) data_mmv;
|
||||
constexpr int kiw = nwarps*rows_per_block + 4;
|
||||
|
||||
if (nwarps > 1) {
|
||||
__syncthreads();
|
||||
}
|
||||
#pragma unroll
|
||||
for (int itB = 0; itB < ntB; ++itB) {
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
|
||||
const int j = itB*tile_C::J + tile_C::get_j(l);
|
||||
buf_iw[j*kiw + i] = C[itA][itB].x[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (nwarps > 1) {
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
|
||||
return;
|
||||
}
|
||||
|
||||
float sum = 0.0f;
|
||||
static_assert(rows_per_block == warp_size, "need loop/check");
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
sum += buf_iw[j*kiw + i];
|
||||
}
|
||||
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(x, y, ids, dst,
|
||||
ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
}
|
||||
|
||||
template <typename T, int cols_per_block>
|
||||
static void mul_mat_f_cuda(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols_x, const int64_t nrows_x,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
cudaStream_t stream) {
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile< 8, 8, T> tile_B;
|
||||
|
||||
GGML_ASSERT(!ids && "mul_mat_id not implemented");
|
||||
|
||||
GGML_ASSERT(ncols_x % 2 == 0);
|
||||
GGML_ASSERT(stride_row % 2 == 0);
|
||||
GGML_ASSERT(stride_col_y % 2 == 0);
|
||||
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
|
||||
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
|
||||
const int64_t channel_ratio = nchannels_dst / nchannels_x;
|
||||
const int64_t sample_ratio = nsamples_dst / nsamples_x;
|
||||
|
||||
const int device = ggml_cuda_get_device();
|
||||
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
||||
|
||||
int64_t nwarps_best = 1;
|
||||
int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
|
||||
int64_t max_block_size = 256;
|
||||
for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
|
||||
const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
|
||||
if (niter < niter_best) {
|
||||
niter_best = niter;
|
||||
nwarps_best = nwarps;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
|
||||
const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
|
||||
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
|
||||
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
|
||||
const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst);
|
||||
const dim3 block_dims(warp_size, nwarps_best, 1);
|
||||
switch (nwarps_best) {
|
||||
case 1: {
|
||||
mul_mat_f<T, rows_per_block, cols_per_block, 1><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 2: {
|
||||
mul_mat_f<T, rows_per_block, cols_per_block, 2><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 3: {
|
||||
mul_mat_f<T, rows_per_block, cols_per_block, 3><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 4: {
|
||||
mul_mat_f<T, rows_per_block, cols_per_block, 4><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 5: {
|
||||
mul_mat_f<T, rows_per_block, cols_per_block, 5><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 6: {
|
||||
mul_mat_f<T, rows_per_block, cols_per_block, 6><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 7: {
|
||||
mul_mat_f<T, rows_per_block, cols_per_block, 7><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 8: {
|
||||
mul_mat_f<T, rows_per_block, cols_per_block, 8><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("fatal error");
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void mul_mat_f_switch_cols_per_block(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
cudaStream_t stream) {
|
||||
switch (ncols_dst) {
|
||||
case 1: {
|
||||
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 2: {
|
||||
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 3: {
|
||||
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 4: {
|
||||
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 5: {
|
||||
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 6: {
|
||||
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 7: {
|
||||
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 8: {
|
||||
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 9: {
|
||||
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 10: {
|
||||
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 11: {
|
||||
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 12: {
|
||||
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 13: {
|
||||
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 14: {
|
||||
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 15: {
|
||||
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 16: {
|
||||
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("fatal error");
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
||||
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS;
|
||||
|
||||
const size_t ts_src0 = ggml_type_size(src0->type);
|
||||
@@ -365,55 +34,72 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
|
||||
const int64_t s13 = src1->nb[3] / ts_src1;
|
||||
const int64_t s3 = dst->nb[3] / ts_dst;
|
||||
|
||||
const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
|
||||
const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
|
||||
|
||||
    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
    const int64_t ncols_dst          = ids ? ne2 : ne1;
    const int64_t nchannels_dst      = ids ? ne1 : ne2;

    GGML_ASSERT(!ids || ncols_dst == 1);
    const int64_t stride_col_dst     = ids ? s2  : s1;
    const int64_t stride_col_y       = ids ? s12 : s11;
    const int64_t stride_channel_dst = ids ? s1  : s2;

    int64_t stride_channel_y = ids ? s11  : s12;
    int64_t nchannels_y      = ids ? ne11 : ne12;

    // mul_mat_id: handle broadcast
    if (ids && nchannels_y == 1) {
        stride_channel_y = 0;
        nchannels_y      = ids->ne[0];
    }
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
constexpr int vals_per_T = 1;
|
||||
mul_mat_f_switch_cols_per_block(
|
||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
|
||||
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
|
||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
|
||||
} break;
|
||||
case GGML_TYPE_F16: {
|
||||
const half2 * src0_d = (const half2 *) src0->data;
|
||||
constexpr int vals_per_T = 2;
|
||||
mul_mat_f_switch_cols_per_block(
|
||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
|
||||
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
|
||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
|
||||
} break;
|
||||
case GGML_TYPE_BF16: {
|
||||
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
|
||||
constexpr int vals_per_T = 2;
|
||||
mul_mat_f_switch_cols_per_block(
|
||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
|
||||
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
|
||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
|
||||
}
|
||||
}
|
||||
|
||||
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols) {

    if (ggml_is_quantized(type)) {
        return false;
    }

    if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
        return false;
    }
    if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
        return false;
    }
    if (src1_ncols > 16) {
        return false;
    }
|
||||
|
||||
switch (type) {
|
||||
case GGML_TYPE_F32:
|
||||
return ampere_mma_available(cc);
|
||||
|
||||
@@ -1,5 +1,461 @@
|
||||
#pragma once
|
||||
|
||||
#include "mma.cuh"
|
||||
#include "common.cuh"
|
||||
|
||||
using namespace ggml_cuda_mma;
|
||||
|
||||
#define MMF_ROWS_PER_BLOCK 32
|
||||
|
||||
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
||||
|
||||
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);
|
||||
|
||||
template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
|
||||
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
|
||||
static __global__ void mul_mat_f(
|
||||
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
||||
const int ncols, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
|
||||
const int stride_col_id, const int stride_row_id,
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile< 8, 8, T> tile_B;
|
||||
typedef tile<16, 8, float> tile_C;
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
constexpr int ntA = rows_per_block / tile_A::I;
|
||||
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
|
||||
|
||||
const int row0 = blockIdx.x * rows_per_block;
|
||||
|
||||
const int expert_idx = has_ids ? blockIdx.y : 0;
|
||||
const int channel_dst = has_ids ? 0 : blockIdx.y;
|
||||
|
||||
const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio);
|
||||
const int channel_y = channel_dst;
|
||||
const int sample_dst = blockIdx.z;
|
||||
const int sample_x = sample_dst / sample_ratio;
|
||||
const int sample_y = sample_dst;
|
||||
|
||||
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ;
|
||||
y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y);
|
||||
dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);
|
||||
|
||||
const float2 * y2 = (const float2 *) y;
|
||||
|
||||
extern __shared__ char data_mmv[];
|
||||
|
||||
char * shmem_base = data_mmv;
|
||||
int * slot_map = (int *) shmem_base;
|
||||
char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base;
|
||||
|
||||
tile_C C[ntA][ntB];
|
||||
|
||||
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
|
||||
|
||||
if constexpr (has_ids) {
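        // With ids (mul_mat_id), blockIdx.y selects one expert. slot_map[j] caches which dst
        // channel (token slot) routes column j to this expert, or -1 if none does; if no column
        // in this block uses the expert, the whole block exits right after the search.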
|
||||
int found = 0;
|
||||
|
||||
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
const int32_t * __restrict__ id_row = ids + j*stride_row_id;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
slot_map[j] = -1;
|
||||
}
|
||||
|
||||
for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
|
||||
int match = id_row[k*stride_col_id] == expert_idx;
|
||||
|
||||
if (match) {
|
||||
slot_map[j] = k;
|
||||
found = 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!__syncthreads_or(found)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
|
||||
tile_A A[ntA][warp_size / tile_A::J];
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < tile_A::I; ++i) {
|
||||
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
|
||||
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int itB = 0; itB < ntB; ++itB) {
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
const int j = j0 + itB*tile_B::I;
|
||||
|
||||
if constexpr (!has_ids) {
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
|
||||
} else {
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
|
||||
}
|
||||
}
|
||||
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
const int j = j0 + itB*tile_B::I;
|
||||
|
||||
if constexpr (!has_ids) {
|
||||
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
} else {
|
||||
float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static_assert(std::is_same_v<T, void>, "unsupported type");
|
||||
}
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
|
||||
tile_B B;
|
||||
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float * buf_iw = (float *) compute_base;
|
||||
constexpr int kiw = nwarps*rows_per_block + 4;
|
||||
|
||||
if (nwarps > 1) {
|
||||
__syncthreads();
|
||||
}
|
||||
#pragma unroll
|
||||
for (int itB = 0; itB < ntB; ++itB) {
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
                const int j = itB*tile_C::J + tile_C::get_j(l);
                buf_iw[j*kiw + i] = C[itA][itB].x[l];
            }
        }
    }

    if (nwarps > 1) {
        __syncthreads();
    }

#pragma unroll
    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
            return;
        }

        float sum = 0.0f;
        static_assert(rows_per_block == warp_size, "need loop/check");
#pragma unroll
        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
            const int i = i0 + threadIdx.x;

            sum += buf_iw[j*kiw + i];
        }

        if constexpr (!has_ids) {
            dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
        } else {
            const int slot = (j < cols_per_block) ? slot_map[j] : -1;
            if (slot >= 0) {
                dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
            }
        }
    }
#else
    GGML_UNUSED_VARS(x, y, ids, dst,
        ncols, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
        stride_col_id, stride_row_id,
        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}

template<typename T, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const int64_t ncols_x, const int64_t nchannels_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t stride_col_id, const int64_t stride_row_id,
        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
    if (ids) {
        mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
            (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    } else {
        mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
            (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    }
}

template <typename T, int cols_per_block>
void mul_mat_f_cuda(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t stride_col_id, const int64_t stride_row_id,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        cudaStream_t stream) {
    typedef tile<16, 8, T> tile_A;
    typedef tile< 8, 8, T> tile_B;

    GGML_ASSERT(ncols_x      % 2 == 0);
    GGML_ASSERT(stride_row   % 2 == 0);
    GGML_ASSERT(stride_col_y % 2 == 0);
    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
    const int64_t channel_ratio = nchannels_dst / nchannels_x;
    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;

    const int device    = ggml_cuda_get_device();
    const int warp_size = ggml_cuda_info().devices[device].warp_size;

    int64_t nwarps_best    = 1;
    int64_t niter_best     = (ncols_x + warp_size*2 - 1) / (warp_size*2);
    int64_t max_block_size = 256;
    for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
        const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
        if (niter < niter_best) {
            niter_best  = niter;
            nwarps_best = nwarps;
        }
    }

    constexpr int rows_per_block       = MMF_ROWS_PER_BLOCK;
    const int     nbytes_shared_iter   = nwarps_best * tile_A::I * (warp_size + 4) * 4;
    const int     nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
    const int     nbytes_shared        = std::max(nbytes_shared_iter, nbytes_shared_combine);
    const int     nbytes_slotmap       = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
    const int     nbytes_shared_total  = nbytes_shared + nbytes_slotmap;
    const int64_t grid_y               = ids ? nchannels_x : nchannels_dst; // per expert when ids present

    const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
    const dim3 block_dims(warp_size, nwarps_best, 1);

    switch (nwarps_best) {
        case 1: {
            mul_mat_f_switch_ids<T, cols_per_block, 1>(
                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
        } break;
        case 2: {
            mul_mat_f_switch_ids<T, cols_per_block, 2>(
                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
        } break;
        case 3: {
            mul_mat_f_switch_ids<T, cols_per_block, 3>(
                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
        } break;
        case 4: {
            mul_mat_f_switch_ids<T, cols_per_block, 4>(
                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
        } break;
        case 5: {
            mul_mat_f_switch_ids<T, cols_per_block, 5>(
                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
        } break;
        case 6: {
            mul_mat_f_switch_ids<T, cols_per_block, 6>(
                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
        } break;
        case 7: {
            mul_mat_f_switch_ids<T, cols_per_block, 7>(
                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
        } break;
        case 8: {
            mul_mat_f_switch_ids<T, cols_per_block, 8>(
                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
        } break;
        default: {
            GGML_ABORT("fatal error");
        } break;
    }

    GGML_UNUSED_VARS(nchannels_y);
}

template <typename T>
static void mul_mat_f_switch_cols_per_block(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t stride_col_id, const int stride_row_id,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        cudaStream_t stream) {
    switch (ncols_dst) {
        case 1: {
            mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
        } break;
        case 2: {
            mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
        } break;
        case 3: {
            mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
        } break;
        case 4: {
            mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
        } break;
        case 5: {
            mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
        } break;
        case 6: {
            mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
        } break;
        case 7: {
            mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
        } break;
        case 8: {
            mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
        } break;
        case 9: {
            mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
        } break;
        case 10: {
            mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
        } break;
        case 11: {
            mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
        } break;
        case 12: {
            mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
        } break;
        case 13: {
            mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
        } break;
        case 14: {
            mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
        } break;
        case 15: {
            mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
        } break;
        case 16: {
            mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
        } break;
        default: {
            GGML_ABORT("fatal error");
        } break;
    }
}

#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
    template void mul_mat_f_cuda<T, ncols_dst>( \
        const T * x, const float * y, const int32_t * ids, float * dst, \
        const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
        const int64_t stride_col_id, const int64_t stride_row_id, \
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
        cudaStream_t stream);

#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
    extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)

#define DECL_MMF_CASE(ncols_dst) \
    DECL_MMF_CASE_HELPER(float, ncols_dst) \
    DECL_MMF_CASE_HELPER(half2, ncols_dst) \
    DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)

DECL_MMF_CASE_EXTERN(1);
DECL_MMF_CASE_EXTERN(2);
DECL_MMF_CASE_EXTERN(3);
DECL_MMF_CASE_EXTERN(4);
DECL_MMF_CASE_EXTERN(5);
DECL_MMF_CASE_EXTERN(6);
DECL_MMF_CASE_EXTERN(7);
DECL_MMF_CASE_EXTERN(8);
DECL_MMF_CASE_EXTERN(9);
DECL_MMF_CASE_EXTERN(10);
DECL_MMF_CASE_EXTERN(11);
DECL_MMF_CASE_EXTERN(12);
DECL_MMF_CASE_EXTERN(13);
DECL_MMF_CASE_EXTERN(14);
DECL_MMF_CASE_EXTERN(15);
DECL_MMF_CASE_EXTERN(16);
#else
#define DECL_MMF_CASE(ncols_dst)
#endif
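A minimal host-side sketch of the warp-count heuristic used by mul_mat_f_cuda above, for readers who want to check the arithmetic. The helper name pick_nwarps is hypothetical; the 256-thread cap mirrors max_block_size in the code, and each warp consumes warp_size*2 columns per iteration.

#include <cstdint>

static int64_t pick_nwarps(int64_t ncols_x, int64_t warp_size) {
    int64_t nwarps_best = 1;
    int64_t niter_best  = (ncols_x + warp_size*2 - 1) / (warp_size*2);
    for (int64_t nwarps = 2; nwarps <= 256/warp_size; nwarps++) {
        // iterations over the reduced dimension when nwarps warps work in parallel
        const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
        if (niter < niter_best) {
            niter_best  = niter;
            nwarps_best = nwarps;
        }
    }
    return nwarps_best;
}

For example, ncols_x = 4096 with warp_size = 32 gives nwarps_best = 8 (8 iterations instead of the 64 needed by a single warp).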
@@ -1,82 +1,89 @@
#include "pad_reflect_1d.cuh"

static __global__ void pad_reflect_1d_kernel_f32(
    const void * __restrict__ src0,
    void * __restrict__ dst,
    const int64_t ne0,
    const int64_t ne00,
    const int64_t ne01,
    const int64_t ne02,
    const int64_t ne03,
    const int64_t nb00,
    const int64_t nb01,
    const int64_t nb02,
    const int64_t nb03,
    const int64_t nb0,
    const int64_t nb1,
    const int64_t nb2,
    const int64_t nb3,
    const int p0,
    const int p1) {

static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void
    pad_reflect_1d_kernel_f32(
        const void * __restrict__ src0,
        void * __restrict__ dst,
        const int64_t ne0,
        const int64_t ne00,
        const uint3 ne01,
        const int64_t ne02,
        const int64_t ne03,
        const int64_t nb00,
        const int64_t nb01,
        const int64_t nb02,
        const int64_t nb03,
        const int64_t nb0,
        const int64_t nb1,
        const int64_t nb2,
        const int64_t nb3,
        const int p0,
        const int p1) {
    const int64_t i3 = blockIdx.z;
    const int64_t i2 = blockIdx.y;
    const int64_t i1 = blockIdx.x;

    if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
    const uint2   div_mod_packed = fast_div_modulo(blockIdx.x, ne01);
    const int64_t tile1          = div_mod_packed.y; // i1
    const int64_t tile0          = div_mod_packed.x; // nth i0 tile
    const int64_t i1             = tile1;
    const int64_t i0             = threadIdx.x + tile0 * blockDim.x;

    // ne01.z is original value of unpacked ne01 (see init_fastdiv_values in common.cuh)
    if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) {
        return;
    }

    const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01;
    char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1;
    const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
    char *       dst_ptr  = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1;

    for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
        float value;
    const int64_t rel_i0 = i0 - p0; // relative i0 in src0
    int64_t       src_idx;

        if (i0 < p0) {
            // Left padding - reflect
            value = *(const float *)(src0_ptr + (p0 - i0) * nb00);
        } else if (i0 < ne0 - p1) {
            // Middle - copy
            value = *(const float *)(src0_ptr + (i0 - p0) * nb00);
        } else {
            // Right padding - reflect
            int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1;
            value = *(const float *)(src0_ptr + src_idx * nb00);
        }

        *(float *)(dst_ptr + i0 * nb0) = value;
    if (rel_i0 < 0) {
        // Left padding - reflect
        src_idx = -rel_i0;
    } else if (rel_i0 < ne00) {
        // Middle - copy
        src_idx = rel_i0;
    } else {
        // Right padding - reflect
        src_idx = 2 * ne00 - 2 - rel_i0;
    }
    const float value = *(const float *) (src0_ptr + src_idx * nb00);
    *(float *) (dst_ptr + i0 * nb0) = value;
}

void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    cudaStream_t stream = ctx.stream();
    const ggml_tensor * src0   = dst->src[0];
    cudaStream_t        stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const int32_t * opts = (const int32_t *) dst->op_params;
    const int p0 = opts[0];
    const int p1 = opts[1];
    const int       p0   = opts[0];
    const int       p1   = opts[1];

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];
    const int64_t ne00        = src0->ne[0];
    const int64_t ne01        = src0->ne[1];
    const uint3   ne01_packed = init_fastdiv_values(ne01);
    const int64_t ne02        = src0->ne[2];
    const int64_t ne03        = src0->ne[3];

    const int64_t ne0 = dst->ne[0];

    // sanity: padded length matches
    GGML_ASSERT(ne0 == ne00 + p0 + p1);

    const dim3 block_dims(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1, 1);
    const dim3 grid_dims(ne01, ne02, ne03);
    constexpr int64_t bx     = CUDA_PAD_REFLECT_1D_BLOCK_SIZE; // threads per block (x)
    const int64_t     tiles0 = (ne0 + bx - 1) / bx;            // number of tiles along i0
    // grid.x covers i1 and all tiles of i0: [ne01 * tiles0]
    // grid.y covers i2: [ne02]
    // grid.z covers i3: [ne03]
    const dim3 grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03);
    const dim3 block_dims((unsigned) bx, 1, 1);

    pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
        src0->data, dst->data,
        ne0, ne00, ne01, ne02, ne03,
        src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
        dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
        p0, p1
    );
        src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
        dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1);
}
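A minimal sketch of the reflect index mapping introduced above: for an output position i0 the kernel computes rel_i0 = i0 - p0 and mirrors it back into [0, ne00). The helper name reflect_index is hypothetical; only the arithmetic is taken from the kernel.

#include <cstdint>
#include <cassert>

static int64_t reflect_index(int64_t i0, int64_t p0, int64_t ne00) {
    const int64_t rel_i0 = i0 - p0;            // position relative to the source row
    if (rel_i0 < 0)    return -rel_i0;         // left padding: mirror around element 0
    if (rel_i0 < ne00) return rel_i0;          // middle: plain copy
    return 2*ne00 - 2 - rel_i0;                // right padding: mirror around element ne00-1
}

int main() {
    // ne00 = 4, p0 = p1 = 2 -> output positions 0..7 read source elements 2,1,0,1,2,3,2,1
    const int64_t expected[8] = {2, 1, 0, 1, 2, 3, 2, 1};
    for (int64_t i0 = 0; i0 < 8; i0++) {
        assert(reflect_index(i0, 2, 4) == expected[i0]);
    }
    return 0;
}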
@@ -24,7 +24,7 @@ TYPES_MMQ = [
    "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
    "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
    "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
    "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
    "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4"
]

SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
@@ -34,6 +34,13 @@ SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do
DECL_MMQ_CASE({type});
"""

SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmf.cuh"

DECL_MMF_CASE({type});
"""


def get_short_name(long_quant_name):
    return long_quant_name.replace("GGML_TYPE_", "").lower()
@@ -76,3 +83,7 @@ for ncols in [8, 16, 32, 64]:
    for type in TYPES_MMQ:
        with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
            f.write(SOURCE_MMQ.format(type=type))

for type in range(1, 17):
    with open(f"mmf-instance-ncols_{type}.cu", "w") as f:
        f.write(SOURCE_MMF.format(type=type))
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmf.cuh"

DECL_MMF_CASE(1);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmf.cuh"

DECL_MMF_CASE(10);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmf.cuh"

DECL_MMF_CASE(11);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmf.cuh"

DECL_MMF_CASE(12);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmf.cuh"

DECL_MMF_CASE(13);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmf.cuh"

DECL_MMF_CASE(14);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmf.cuh"

DECL_MMF_CASE(15);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmf.cuh"

DECL_MMF_CASE(16);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmf.cuh"

DECL_MMF_CASE(2);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmf.cuh"

DECL_MMF_CASE(3);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmf.cuh"

DECL_MMF_CASE(4);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmf.cuh"

DECL_MMF_CASE(5);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmf.cuh"

DECL_MMF_CASE(6);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmf.cuh"

DECL_MMF_CASE(7);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmf.cuh"

DECL_MMF_CASE(8);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmf.cuh"

DECL_MMF_CASE(9);
@@ -7,11 +7,11 @@ static __global__ void timestep_embedding_f32(const float * timesteps, float * d
    int j = threadIdx.x + blockIdx.x * blockDim.x;
    float * embed_data = (float *)((char *)dst + i*nb1);

    if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
        embed_data[dim] = 0.f;
    int half = dim / 2;
    if (dim % 2 != 0 && j == half) {
        embed_data[2 * half] = 0.f;
    }

    int half = dim / 2;
    if (j >= half) {
        return;
    }
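A small arithmetic check of the index change in the hunk above, assuming half = dim / 2 as in the code. With an odd dim of 5, the zero pad is now written by the thread with j == half == 2 at index 2*half == 4; the previous code used the thread with j == (dim + 1)/2 == 3 and index dim == 5. If the output row holds dim floats, 2*half is the last valid slot, while dim would be one past the end.

#include <cassert>

int main() {
    const int dim  = 5;
    const int half = dim / 2;
    assert(half == 2);
    assert(2 * half == 4);       // new zero-pad index (dim - 1)
    assert((dim + 1) / 2 == 3);  // old guard thread
    assert(dim == 5);            // old zero-pad index
    return 0;
}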
ggml/src/ggml-cuda/vendors/hip.h (vendored, 36 changed lines)
@@ -158,33 +158,41 @@

#define __CUDA_ARCH__ 1300

#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
#define GCN
#endif
#if defined(__gfx900__) || defined(__gfx906__)
#define GCN5
#endif // defined(__gfx900__) || defined(__gfx906__)

#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
#define CDNA // For the entire family
#endif
#if defined(__gfx803__)
#define GCN4
#endif // defined(__gfx803__)

#if defined(GCN5) || defined(GCN4)
#define GCN
#endif // defined(GCN5) || defined(GCN4)

#if defined(__gfx942__)
#define CDNA3
#endif
#endif // defined(__gfx942__)

#if defined(__gfx90a__)
#define CDNA2
#endif
#endif // defined(__gfx90a__)

#if defined(__gfx908__)
#define CDNA1
#endif
#endif // defined(__gfx908__)

#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
#define CDNA // For the entire family
#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1)

#if defined(__GFX12__)
#define RDNA4
#endif
#endif // defined(__GFX12__)

#if defined(__GFX11__)
#define RDNA3
#endif
#endif // defined(__GFX11__)

#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
@@ -193,7 +201,11 @@

#if defined(__gfx1010__) || defined(__gfx1012__)
#define RDNA1
#endif
#endif // defined(__gfx1010__) || defined(__gfx1012__)

#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
#define RDNA // For the entire family
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)

#ifndef __has_builtin
#define __has_builtin(x) 0
@@ -5,7 +5,12 @@ find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
message(STATUS "Metal framework found")

ggml_add_backend_library(ggml-metal
                         ggml-metal.m
                         ggml-metal.cpp
                         ggml-metal-device.m
                         ggml-metal-device.cpp
                         ggml-metal-common.cpp
                         ggml-metal-context.m
                         ggml-metal-ops.cpp
                        )

target_link_libraries(ggml-metal PRIVATE
@@ -18,10 +23,6 @@ if (GGML_METAL_NDEBUG)
    add_compile_definitions(GGML_METAL_NDEBUG)
endif()

if (GGML_METAL_USE_BF16)
    add_compile_definitions(GGML_METAL_USE_BF16)
endif()

# copy metal files to bin directory
configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
ggml/src/ggml-metal/ggml-metal-common.cpp (new file, 458 lines)
@@ -0,0 +1,458 @@
#include "ggml-metal-common.h"

#include "ggml-impl.h"
#include "ggml-backend-impl.h"

#include <vector>

// represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb)
// the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it)
struct ggml_mem_range {
    uint64_t pb; // buffer id

    uint64_t p0; // begin
    uint64_t p1; // end

    ggml_mem_range_type pt;
};

struct ggml_mem_ranges {
    std::vector<ggml_mem_range> ranges;

    int debug = 0;
};

ggml_mem_ranges_t ggml_mem_ranges_init(int debug) {
    auto * res = new ggml_mem_ranges;

    res->ranges.reserve(256);
    res->debug = debug;

    return res;
}

void ggml_mem_ranges_free(ggml_mem_ranges_t mrs) {
    delete mrs;
}

void ggml_mem_ranges_reset(ggml_mem_ranges_t mrs) {
    mrs->ranges.clear();
}

static bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, ggml_mem_range mr) {
    mrs->ranges.push_back(mr);

    return true;
}

static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggml_mem_range_type pt) {
    // always use the base tensor
    tensor = tensor->view_src ? tensor->view_src : tensor;

    GGML_ASSERT(!tensor->view_src);

    ggml_mem_range mr;

    if (tensor->buffer) {
        // when the tensor is allocated, use the actual memory address range in the buffer
        //
        // take the actual allocated size with ggml_backend_buft_get_alloc_size()
        // this can be larger than the tensor size if the buffer type allocates extra memory
        // ref: https://github.com/ggml-org/llama.cpp/pull/15966
        mr = {
            /*.pb =*/ (uint64_t) tensor->buffer,
            /*.p0 =*/ (uint64_t) tensor->data,
            /*.p1 =*/ (uint64_t) tensor->data + ggml_backend_buft_get_alloc_size(tensor->buffer->buft, tensor),
            /*.pt =*/ pt,
        };
    } else {
        // otherwise, the pointer address is used as a unique id of the memory ranges
        // that the tensor will be using when it is allocated
        mr = {
            /*.pb =*/ (uint64_t) tensor,
            /*.p0 =*/ 0,    //
            /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used
            /*.pt =*/ pt,
        };
    };

    return mr;
}

static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) {
    return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_SRC);
}

static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor) {
    return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST);
}

static bool ggml_mem_ranges_add_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
    GGML_ASSERT(tensor);

    ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);

    if (mrs->debug > 2) {
        GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
    }

    return ggml_mem_ranges_add(mrs, mr);
}

static bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
    GGML_ASSERT(tensor);

    ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);

    if (mrs->debug > 2) {
        GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
    }

    return ggml_mem_ranges_add(mrs, mr);
}

bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
    for (int i = 0; i < GGML_MAX_DIMS; i++) {
        if (tensor->src[i]) {
            ggml_mem_ranges_add_src(mrs, tensor->src[i]);
        }
    }

    return ggml_mem_ranges_add_dst(mrs, tensor);
}

static bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, ggml_mem_range mr) {
    for (size_t i = 0; i < mrs->ranges.size(); i++) {
        const auto & cmp = mrs->ranges[i];

        // two memory ranges cannot intersect if they are in different buffers
        if (mr.pb != cmp.pb) {
            continue;
        }

        // intersecting source ranges are allowed
        if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
            continue;
        }

        if (mr.p0 < cmp.p1 && mr.p1 >= cmp.p0) {
            if (mrs->debug > 2) {
                GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n",
                    __func__,
                    mr.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
                    mr.pb, mr.p0, mr.p1,
                    cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
                    cmp.pb, cmp.p0, cmp.p1);
            }

            return false;
        }
    }

    return true;
}

static bool ggml_mem_ranges_check_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
    GGML_ASSERT(tensor);

    ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);

    const bool res = ggml_mem_ranges_check(mrs, mr);

    return res;
}

static bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
    GGML_ASSERT(tensor);

    ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);

    const bool res = ggml_mem_ranges_check(mrs, mr);

    return res;
}

bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
    for (int i = 0; i < GGML_MAX_DIMS; i++) {
        if (tensor->src[i]) {
            if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
                return false;
            }
        }
    }

    return ggml_mem_ranges_check_dst(mrs, tensor);
}

// TODO: move to ggml.h?
static bool is_empty(ggml_op op) {
    switch (op) {
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
            return true;
        default:
            return false;
    }
}

struct node_info {
    ggml_tensor * node;

    std::vector<ggml_tensor *> fused;

    ggml_op op() const {
        return node->op;
    }

    const ggml_tensor * dst() const {
        return fused.empty() ? node : fused.back();
    }

    bool is_empty() const {
        return ::is_empty(node->op);
    }

    void add_fused(ggml_tensor * t) {
        fused.push_back(t);
    }
};

static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
    // helper to add node src and dst ranges
    const auto & h_add = [](ggml_mem_ranges_t mrs, const node_info & node) {
        for (int i = 0; i < GGML_MAX_SRC; i++) {
            if (node.node->src[i]) {
                if (!ggml_mem_ranges_add_src(mrs, node.node->src[i])) {
                    return false;
                }
            }
        }

        // keep track of the sources of the fused nodes as well
        for (const auto * fused : node.fused) {
            for (int i = 0; i < GGML_MAX_SRC; i++) {
                if (fused->src[i]) {
                    if (!ggml_mem_ranges_add_src(mrs, fused->src[i])) {
                        return false;
                    }
                }
            }
        }

        return ggml_mem_ranges_add_dst(mrs, node.dst());
    };

    // helper to check if a node can run concurrently with the existing set of nodes
    const auto & h_check = [](ggml_mem_ranges_t mrs, const node_info & node) {
        for (int i = 0; i < GGML_MAX_SRC; i++) {
            if (node.node->src[i]) {
                if (!ggml_mem_ranges_check_src(mrs, node.node->src[i])) {
                    return false;
                }
            }
        }

        for (const auto * fused : node.fused) {
            for (int i = 0; i < GGML_MAX_SRC; i++) {
                if (fused->src[i]) {
                    if (!ggml_mem_ranges_check_src(mrs, fused->src[i])) {
                        return false;
                    }
                }
            }
        }

        return ggml_mem_ranges_check_dst(mrs, node.dst());
    };

    // perform reorders only across these types of ops
    // can be expanded when needed
    // IMPORTANT: do not add ops such as GGML_OP_CPY or GGML_OP_SET_ROWS
    //            the dependencies from such ops are not always represented in the graph
    const auto & h_safe = [](ggml_op op) {
        switch (op) {
            case GGML_OP_MUL_MAT:
            case GGML_OP_MUL_MAT_ID:
            case GGML_OP_ROPE:
            case GGML_OP_NORM:
            case GGML_OP_RMS_NORM:
            case GGML_OP_GROUP_NORM:
            case GGML_OP_SUM_ROWS:
            case GGML_OP_MUL:
            case GGML_OP_ADD:
            case GGML_OP_DIV:
            case GGML_OP_GLU:
            case GGML_OP_SCALE:
            case GGML_OP_GET_ROWS:
                return true;
            default:
                return is_empty(op);
        }
    };

    const int n = nodes.size();

    std::vector<int> res;
    res.reserve(n);

    std::vector<bool> used(n, false);

    // the memory ranges for the set of currently concurrent nodes
    ggml_mem_ranges_t mrs0 = ggml_mem_ranges_init(0);

    // the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder
    ggml_mem_ranges_t mrs1 = ggml_mem_ranges_init(0);

    for (int i0 = 0; i0 < n; i0++) {
        if (used[i0]) {
            continue;
        }

        const auto & node0 = nodes[i0];

        // the node is not concurrent with the existing concurrent set, so we have to "put a barrier" (i.e. reset mrs0)
        // but before we do that, look forward for some other nodes that can be added to the concurrent set mrs0
        //
        // note: we can always add empty nodes to the concurrent set as they don't read nor write anything
        if (!node0.is_empty() && !h_check(mrs0, node0)) {
            // this will hold the set of memory ranges from the nodes that haven't been processed yet
            // if a node is not concurrent with this set, we cannot reorder it
            ggml_mem_ranges_reset(mrs1);

            // initialize it with the current node
            h_add(mrs1, node0);

            // that many nodes forward to search for a concurrent node
            constexpr int N_FORWARD = 8;

            for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
                if (used[i1]) {
                    continue;
                }

                const auto & node1 = nodes[i1];

                // disallow reordering of certain ops
                if (!h_safe(node1.op())) {
                    break;
                }

                const bool is_empty = node1.is_empty();

                // to reorder a node and add it to the concurrent set, it has to be:
                //   + empty or concurrent with all nodes in the existing concurrent set (mrs0)
                //   + concurrent with all nodes prior to it that haven't been processed yet (mrs1)
                if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {
                    // add the node to the existing concurrent set (i.e. reorder it for early execution)
                    h_add(mrs0, node1);
                    res.push_back(i1);

                    // mark as used, so we skip re-processing it later
                    used[i1] = true;
                } else {
                    // expand the set of nodes that haven't been processed yet
                    h_add(mrs1, node1);
                }
            }

            // finalize the concurrent set and begin a new one
            ggml_mem_ranges_reset(mrs0);
        }

        // expand the concurrent set with the current node
        {
            h_add(mrs0, node0);
            res.push_back(i0);
        }
    }

    ggml_mem_ranges_free(mrs0);
    ggml_mem_ranges_free(mrs1);

    return res;
}

void ggml_graph_optimize(ggml_cgraph * gf) {
    constexpr int MAX_FUSE = 16;

    const int n = gf->n_nodes;

    enum ggml_op ops[MAX_FUSE];

    std::vector<node_info> nodes;
    nodes.reserve(gf->n_nodes);

    // fuse nodes:
    // we don't want to make reorders that break fusing, so we first pack all fusable tensors
    // and perform the reorder over the fused nodes. after the reorder is done, we unfuse
    for (int i = 0; i < n; i++) {
        node_info node = {
            /*.node  =*/ gf->nodes[i],
            /*.fused =*/ {},
        };

        // fuse only ops that start with these operations
        // can be expanded when needed
        if (node.op() == GGML_OP_ADD ||
            node.op() == GGML_OP_RMS_NORM) {
            ops[0] = node.op();

            int f = i + 1;
            while (f < n && f < i + MAX_FUSE) {
                // conservatively allow fusing only these ops
                // can be expanded when needed
                if (gf->nodes[f]->op != GGML_OP_ADD &&
                    gf->nodes[f]->op != GGML_OP_MUL &&
                    gf->nodes[f]->op != GGML_OP_RMS_NORM) {
                    break;
                }
                ops[f - i] = gf->nodes[f]->op;
                f++;
            }

            f -= i;
            for (; f > 1; f--) {
                if (ggml_can_fuse(gf, i, ops, f)) {
                    break;
                }
            }

            // add the fused tensors into the node info so we can unfuse them later
            for (int k = 1; k < f; k++) {
                ++i;

                // the .dst() becomes the last fused tensor
                node.add_fused(gf->nodes[i]);
            }
        }

        nodes.push_back(std::move(node));
    }

#if 1
    // reorder to improve concurrency
    const auto order = ggml_metal_graph_optimize_reorder(nodes);
#else
    std::vector<int> order(nodes.size());
    for (size_t i = 0; i < nodes.size(); i++) {
        order[i] = i;
    }
#endif

    // unfuse
    {
        int j = 0;
        for (const auto i : order) {
            const auto & node = nodes[i];

            gf->nodes[j++] = node.node;

            for (auto * fused : node.fused) {
                gf->nodes[j++] = fused;
            }
        }
    }
}
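A minimal standalone sketch of the per-range overlap test that ggml_mem_ranges_check applies above (the names range and conflict are hypothetical): two ranges conflict only if they live in the same buffer, at least one of them is a destination, and their intervals overlap under the same p0 < p1' && p1 >= p0' comparison used in the code.

#include <cstdint>
#include <cassert>

struct range { uint64_t pb, p0, p1; bool is_src; };

static bool conflict(const range & a, const range & b) {
    if (a.pb != b.pb)         return false; // different buffers never conflict
    if (a.is_src && b.is_src) return false; // concurrent reads are allowed
    return a.p0 < b.p1 && a.p1 >= b.p0;     // same interval test as ggml_mem_ranges_check
}

int main() {
    const range src0 = { 1, 0,   256, true  };
    const range dst0 = { 1, 0,   256, false };
    const range dst1 = { 1, 512, 768, false };
    const range dst2 = { 2, 0,   256, false };

    assert( conflict(src0, dst0)); // read vs. write of the same region -> needs a barrier
    assert(!conflict(src0, dst1)); // disjoint regions of the same buffer -> concurrent
    assert(!conflict(dst0, dst2)); // different buffers -> concurrent
    return 0;
}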
ggml/src/ggml-metal/ggml-metal-common.h (new file, 52 lines)
@@ -0,0 +1,52 @@
// helper functions for ggml-metal that are too difficult to implement in Objective-C

#pragma once

#include <stdbool.h>

#ifdef __cplusplus
extern "C" {
#endif

struct ggml_tensor;
struct ggml_cgraph;

enum ggml_mem_range_type {
    MEM_RANGE_TYPE_SRC = 0,
    MEM_RANGE_TYPE_DST = 1,
};

// a helper object that can be used for reordering operations to improve concurrency
//
// the fundamental idea is that a set of tasks (either ggml ops, or something else) can run concurrently if they
// don't write to a memory that is being read by another task or written to by another task in the set
//
// with this structure, we can add tasks to the set, setting memory constraints. we can also check if a new task
// can be added to the set without violating the constraints (i.e. if it can be executed concurrently with the
// tasks already in the set)
//
typedef struct ggml_mem_ranges * ggml_mem_ranges_t;

ggml_mem_ranges_t ggml_mem_ranges_init(int debug);
void              ggml_mem_ranges_free(ggml_mem_ranges_t mrs);

// remove all ranges from the set
void ggml_mem_ranges_reset(ggml_mem_ranges_t mrs);

// add src or dst ranges to track
bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const struct ggml_tensor * tensor);

// return false if:
//   - new src range overlaps with any existing dst range
//   - new dst range overlaps with any existing range (src or dst)
bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const struct ggml_tensor * tensor);

// reorder the nodes in the graph to improve concurrency, while respecting fusion
//
// note: this implementation is generic and not specific to metal
//       if it proves to work well, we can start using it for other backends in the future
void ggml_graph_optimize(struct ggml_cgraph * gf);

#ifdef __cplusplus
}
#endif
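A hedged sketch of how a backend could consume this API to split a node list into concurrent batches; the function encode_with_barriers and the surrounding node plumbing are assumptions for illustration, only the ggml_mem_ranges_* calls come from the header above.

#include "ggml-metal-common.h"

// hypothetical caller: walk the nodes in order, inserting a barrier whenever the
// next node touches memory already used by the current batch
static void encode_with_barriers(struct ggml_tensor ** nodes, int n_nodes) {
    ggml_mem_ranges_t mrs = ggml_mem_ranges_init(/*debug =*/ 0);

    for (int i = 0; i < n_nodes; i++) {
        if (!ggml_mem_ranges_check(mrs, nodes[i])) {
            // the node reads or writes memory touched by the current batch:
            // emit a memory barrier here, then start tracking a new batch
            ggml_mem_ranges_reset(mrs);
        }
        ggml_mem_ranges_add(mrs, nodes[i]);
        // ... encode nodes[i] ...
    }

    ggml_mem_ranges_free(mrs);
}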
ggml/src/ggml-metal/ggml-metal-context.h (new file, 33 lines)
@@ -0,0 +1,33 @@
#pragma once

#include "ggml-metal-device.h"

#ifdef __cplusplus
extern "C" {
#endif

//
// backend context
//

typedef struct ggml_metal * ggml_metal_t;

ggml_metal_t ggml_metal_init(ggml_metal_device_t dev);
void         ggml_metal_free(ggml_metal_t ctx);

void ggml_metal_synchronize(ggml_metal_t ctx);

void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);

enum ggml_status ggml_metal_graph_compute (ggml_metal_t ctx, struct ggml_cgraph * gf);
void             ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf);

void ggml_metal_set_n_cb             (ggml_metal_t ctx, int n_cb);
void ggml_metal_set_abort_callback   (ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data);
bool ggml_metal_supports_family      (ggml_metal_t ctx, int family);
void ggml_metal_capture_next_compute (ggml_metal_t ctx);

#ifdef __cplusplus
}
#endif
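A hedged sketch of the call sequence this context API suggests. The helper run_graph is hypothetical, how the ggml_metal_device_t handle is obtained is not shown in this header, and in practice these functions are reached through the ggml-backend wrappers rather than called directly; GGML_STATUS_FAILED comes from ggml.h.

#include "ggml-metal-context.h"

static enum ggml_status run_graph(ggml_metal_device_t dev, struct ggml_cgraph * gf) {
    ggml_metal_t ctx = ggml_metal_init(dev);
    if (ctx == NULL) {
        return GGML_STATUS_FAILED;
    }

    ggml_metal_set_n_cb(ctx, 1);        // one extra encoder thread
    ggml_metal_graph_optimize(ctx, gf); // reorder for concurrency, respecting fusion

    enum ggml_status st = ggml_metal_graph_compute(ctx, gf);

    ggml_metal_synchronize(ctx);        // wait for the queued command buffers
    ggml_metal_free(ctx);
    return st;
}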
ggml/src/ggml-metal/ggml-metal-context.m (new file, 575 lines)
@@ -0,0 +1,575 @@
#import "ggml-metal-context.h"

#import "ggml-impl.h"
#import "ggml-backend-impl.h"

#import "ggml-metal-impl.h"
#import "ggml-metal-common.h"
#import "ggml-metal-ops.h"

#import <Foundation/Foundation.h>

#import <Metal/Metal.h>

#undef MIN
#undef MAX
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))

// max number of MTLCommandBuffer used to submit a graph for processing
#define GGML_METAL_MAX_COMMAND_BUFFERS 8

struct ggml_metal_command_buffer {
    id<MTLCommandBuffer> obj;
};

struct ggml_metal {
    id<MTLDevice>       device;
    id<MTLCommandQueue> queue; // currently a pointer to the device queue, but might become separate queue [TAG_QUEUE_PER_BACKEND]

    ggml_metal_device_t  dev;
    ggml_metal_library_t lib;

    dispatch_queue_t d_queue;

    // additional, inference-time compiled pipelines
    ggml_metal_pipelines_t pipelines_ext;

    bool use_bfloat;
    bool use_fusion;
    bool use_concurrency;
    bool use_graph_optimize;

    int debug_graph;
    int debug_fusion;

    // how many times a given op was fused
    uint64_t fuse_cnt[GGML_OP_COUNT];

    // capture state
    bool capture_next_compute;
    bool capture_started;

    id<MTLCaptureScope> capture_scope;

    // command buffer state
    int n_cb;           // number of extra threads used to submit the command buffers
    int n_nodes_0;      // number of nodes submitted by the main thread
    int n_nodes_1;      // remaining number of nodes submitted by the n_cb threads
    int n_nodes_per_cb;

    struct ggml_cgraph * gf;

    // the callback given to the thread pool
    void (^encode_async)(size_t ith);

    // n_cb command buffers + 1 used by the main thread
    struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];

    // extra command buffers for things like getting, setting and copying tensors
    NSMutableArray * cmd_bufs_ext;

    // the last command buffer queued into the Metal queue with operations relevant to the current Metal backend
    id<MTLCommandBuffer> cmd_buf_last;

    // abort ggml_metal_graph_compute if callback returns true
    ggml_abort_callback abort_callback;
    void *              abort_callback_data;
};

ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
    GGML_LOG_INFO("%s: allocating\n", __func__);

#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
    // Show all the Metal device instances in the system
    NSArray * devices = MTLCopyAllDevices();
    for (id<MTLDevice> device in devices) {
        GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
    }
    [devices release]; // since it was created by a *Copy* C method
#endif

    // init context
    ggml_metal_t res = calloc(1, sizeof(struct ggml_metal));

    res->device = ggml_metal_device_get_obj(dev);

    GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[res->device name] UTF8String]);

    // TODO: would it be better to have one queue for the backend and one queue for the device?
    //       the graph encoders and async ops would use the backend queue while the sync ops would use the device queue?
    //res->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND]
    res->queue = ggml_metal_device_get_queue(dev);
    if (res->queue == nil) {
        GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
        return NULL;
    }

    res->dev = dev;
    res->lib = ggml_metal_device_get_library(dev);
    if (res->lib == NULL) {
        GGML_LOG_WARN("%s: the device does not have a precompiled Metal library - this is unexpected\n", __func__);
        GGML_LOG_WARN("%s: will try to compile it on the fly\n", __func__);

        res->lib = ggml_metal_library_init(dev);
        if (res->lib == NULL) {
            GGML_LOG_ERROR("%s: error: failed to initialize the Metal library\n", __func__);

            free(res);

            return NULL;
        }
    }

    const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);

    res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

    res->use_bfloat      = props_dev->has_bfloat;
    res->use_fusion      = getenv("GGML_METAL_FUSION_DISABLE") == nil;
    res->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil;

    {
        const char * val = getenv("GGML_METAL_GRAPH_DEBUG");
        res->debug_graph = val ? atoi(val) : 0;
    }

    {
        const char * val = getenv("GGML_METAL_FUSION_DEBUG");
        res->debug_fusion = val ? atoi(val) : 0;
    }

    res->use_graph_optimize = true;

    if (getenv("GGML_METAL_GRAPH_OPTIMIZE_DISABLE") != NULL) {
        res->use_graph_optimize = false;
    }

    memset(res->fuse_cnt, 0, sizeof(res->fuse_cnt));

    GGML_LOG_INFO("%s: use bfloat         = %s\n", __func__, res->use_bfloat         ? "true" : "false");
    GGML_LOG_INFO("%s: use fusion         = %s\n", __func__, res->use_fusion         ? "true" : "false");
    GGML_LOG_INFO("%s: use concurrency    = %s\n", __func__, res->use_concurrency    ? "true" : "false");
    GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false");

    res->capture_next_compute = false;
    res->capture_started      = false;
    res->capture_scope        = nil;

    res->gf           = nil;
    res->encode_async = nil;
    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
        res->cmd_bufs[i].obj = nil;
    }

    res->cmd_bufs_ext = [[NSMutableArray alloc] init];

    res->cmd_buf_last = nil;

    res->pipelines_ext = ggml_metal_pipelines_init();

    return res;
}

void ggml_metal_free(ggml_metal_t ctx) {
    GGML_LOG_INFO("%s: deallocating\n", __func__);

    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
        if (ctx->cmd_bufs[i].obj) {
            [ctx->cmd_bufs[i].obj release];
        }
    }

    for (int i = 0; i < (int) ctx->cmd_bufs_ext.count; ++i) {
        if (ctx->cmd_bufs_ext[i]) {
            [ctx->cmd_bufs_ext[i] release];
        }
    }

    [ctx->cmd_bufs_ext removeAllObjects];
    [ctx->cmd_bufs_ext release];

    if (ctx->pipelines_ext) {
        ggml_metal_pipelines_free(ctx->pipelines_ext);
        ctx->pipelines_ext = nil;
    }

    if (ctx->debug_fusion > 0) {
        GGML_LOG_DEBUG("%s: fusion stats:\n", __func__);
        for (int i = 0; i < GGML_OP_COUNT; i++) {
            if (ctx->fuse_cnt[i] == 0) {
                continue;
            }

            // note: cannot use ggml_log here
            GGML_LOG_DEBUG("%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
        }
    }

    Block_release(ctx->encode_async);

    //[ctx->queue release]; // [TAG_QUEUE_PER_BACKEND]

    dispatch_release(ctx->d_queue);

    free(ctx);
}

void ggml_metal_synchronize(ggml_metal_t ctx) {
    // wait for any backend operations to finish
    if (ctx->cmd_buf_last) {
        [ctx->cmd_buf_last waitUntilCompleted];
        ctx->cmd_buf_last = nil;
    }

    // release any completed command buffers
    if (ctx->cmd_bufs_ext.count > 0) {
        for (size_t i = 0; i < ctx->cmd_bufs_ext.count; ++i) {
            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs_ext[i];

            MTLCommandBufferStatus status = [cmd_buf status];
            if (status != MTLCommandBufferStatusCompleted) {
                GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, (int) i, (int) status);
                if (status == MTLCommandBufferStatusError) {
                    GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                }
                GGML_ABORT("fatal error");
            }

            [cmd_buf release];
        }

        [ctx->cmd_bufs_ext removeAllObjects];
    }
}

static struct ggml_metal_buffer_id ggml_metal_get_buffer_id(const struct ggml_tensor * t) {
    if (!t) {
        return (struct ggml_metal_buffer_id) { nil, 0 };
    }

    ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;

    return ggml_metal_buffer_get_id(buffer->context, t);
}

void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    @autoreleasepool {
        // wrap the source data into a Metal buffer
        id<MTLBuffer> buf_src = [ctx->device newBufferWithBytes:data
                                                         length:size
                                                        options:MTLResourceStorageModeShared];

        struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(tensor);
        if (bid_dst.metal == nil) {
            GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
        }

        bid_dst.offs += offset;

        // queue the copy operation into the queue of the Metal context
        // this will be queued at the end, after any currently ongoing GPU operations
        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];

        [encoder copyFromBuffer:buf_src
                   sourceOffset:0
                       toBuffer:bid_dst.metal
              destinationOffset:bid_dst.offs
                           size:size];

        [encoder endEncoding];
        [cmd_buf commit];

        // do not wait here for completion
        //[cmd_buf waitUntilCompleted];

        // instead, remember a reference to the command buffer and wait for it later if needed
        [ctx->cmd_bufs_ext addObject:cmd_buf];
        ctx->cmd_buf_last = cmd_buf;

        [cmd_buf retain];
    }
}

void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    @autoreleasepool {
        id<MTLBuffer> buf_dst = [ctx->device newBufferWithBytesNoCopy:data
                                                               length:size
                                                              options:MTLResourceStorageModeShared
                                                          deallocator:nil];

        struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(tensor);
        if (bid_src.metal == nil) {
            GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
        }

        bid_src.offs += offset;

        // queue the copy operation into the queue of the Metal context
        // this will be queued at the end, after any currently ongoing GPU operations
        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];

        [encoder copyFromBuffer:bid_src.metal
                   sourceOffset:bid_src.offs
                       toBuffer:buf_dst
              destinationOffset:0
                           size:size];

        [encoder endEncoding];
        [cmd_buf commit];

        // do not wait here for completion
        //[cmd_buf waitUntilCompleted];

        // instead, remember a reference to the command buffer and wait for it later if needed
        [ctx->cmd_bufs_ext addObject:cmd_buf];
        ctx->cmd_buf_last = cmd_buf;

        [cmd_buf retain];
    }
}

enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) {
    // number of nodes encoded by the main thread (empirically determined)
    const int n_main = 64;

    // number of threads in addition to the main thread
    const int n_cb = ctx->n_cb;

    // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
    // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
    // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
    // each thread creates its own command buffer and enqueues the ops in parallel
    //
    // tests on M1 Pro and M2 Ultra using LLaMA models show that optimal values for n_cb are 1 or 2

    @autoreleasepool {
        ctx->gf = gf;

        ctx->n_nodes_0 = MIN(n_main, gf->n_nodes);
        ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;

        ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;

        const bool use_capture = ctx->capture_next_compute;
        if (use_capture) {
            ctx->capture_next_compute = false;

            // make sure all previous computations have finished before starting the capture
            if (ctx->cmd_buf_last) {
                [ctx->cmd_buf_last waitUntilCompleted];
                ctx->cmd_buf_last = nil;
            }

            if (!ctx->capture_started) {
                // create capture scope
                ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];

                MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
                descriptor.captureObject = ctx->capture_scope;
                descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
                descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];

                NSError * error = nil;
                if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
                    GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
                } else {
                    [ctx->capture_scope beginScope];
                    ctx->capture_started = true;
                }
            }
        }

        // the main thread commits the first few commands immediately
        // cmd_buf[n_cb]
        {
            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
            [cmd_buf retain];

            if (ctx->cmd_bufs[n_cb].obj) {
                [ctx->cmd_bufs[n_cb].obj release];
            }
            ctx->cmd_bufs[n_cb].obj = cmd_buf;

            [cmd_buf enqueue];

            ctx->encode_async(n_cb);
        }

        // remember the command buffer for the next iteration
        ctx->cmd_buf_last = ctx->cmd_bufs[n_cb].obj;

        // prepare the rest of the command buffers asynchronously (optional)
        // cmd_buf[0.. n_cb)
        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
            [cmd_buf retain];

            if (ctx->cmd_bufs[cb_idx].obj) {
                [ctx->cmd_bufs[cb_idx].obj release];
            }
            ctx->cmd_bufs[cb_idx].obj = cmd_buf;

            // always enqueue the first two command buffers
            // enqueue all of the command buffers if we don't need to abort
            if (cb_idx < 2 || ctx->abort_callback == NULL) {
                [cmd_buf enqueue
|
||||
|
||||
// update the pointer to the last queued command buffer
|
||||
// this is needed to implement synchronize()
|
||||
ctx->cmd_buf_last = cmd_buf;
|
||||
}
|
||||
}
|
||||
|
||||
dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
|
||||
|
||||
// for debugging: block until graph is computed
|
||||
//[ctx->cmd_buf_last waitUntilCompleted];
|
||||
|
||||
// enter here only when capturing in order to wait for all computation to finish
|
||||
// otherwise, we leave the graph to compute asynchronously
|
||||
if (!use_capture && ctx->capture_started) {
|
||||
// wait for completion and check status of each command buffer
|
||||
// needed to detect if the device ran out-of-memory for example (#1881)
|
||||
{
|
||||
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
|
||||
[cmd_buf waitUntilCompleted];
|
||||
|
||||
MTLCommandBufferStatus status = [cmd_buf status];
|
||||
if (status != MTLCommandBufferStatusCompleted) {
|
||||
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
|
||||
if (status == MTLCommandBufferStatusError) {
|
||||
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
||||
}
|
||||
|
||||
return GGML_STATUS_FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_cb; ++i) {
|
||||
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
|
||||
[cmd_buf waitUntilCompleted];
|
||||
|
||||
MTLCommandBufferStatus status = [cmd_buf status];
|
||||
if (status != MTLCommandBufferStatusCompleted) {
|
||||
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
||||
if (status == MTLCommandBufferStatusError) {
|
||||
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
||||
}
|
||||
|
||||
return GGML_STATUS_FAILED;
|
||||
}
|
||||
|
||||
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
|
||||
if (!next_buffer) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
|
||||
if (next_queued) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
|
||||
GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
|
||||
return GGML_STATUS_ABORTED;
|
||||
}
|
||||
|
||||
[next_buffer commit];
|
||||
}
|
||||
|
||||
[ctx->capture_scope endScope];
|
||||
[[MTLCaptureManager sharedCaptureManager] stopCapture];
|
||||
}
|
||||
}
|
||||
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf) {
|
||||
//const int64_t t_start = ggml_time_us();
|
||||
|
||||
if (ctx->use_graph_optimize) {
|
||||
ggml_graph_optimize(gf);
|
||||
}
|
||||
|
||||
//printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0);
|
||||
}
|
||||
|
||||
void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
|
||||
if (ctx->n_cb != n_cb) {
|
||||
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS);
|
||||
|
||||
if (ctx->n_cb > 2) {
|
||||
GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb);
|
||||
}
|
||||
}
|
||||
|
||||
if (ctx->encode_async) {
|
||||
Block_release(ctx->encode_async);
|
||||
}
|
||||
|
||||
ctx->encode_async = Block_copy(^(size_t iter) {
|
||||
const int cb_idx = iter;
|
||||
const int n_cb_l = ctx->n_cb;
|
||||
|
||||
const int n_nodes_0 = ctx->n_nodes_0;
|
||||
const int n_nodes_1 = ctx->n_nodes_1;
|
||||
|
||||
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
|
||||
|
||||
int idx_start = 0;
|
||||
int idx_end = n_nodes_0;
|
||||
|
||||
if (cb_idx < n_cb_l) {
|
||||
idx_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
|
||||
idx_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
|
||||
}
|
||||
|
||||
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
|
||||
|
||||
ggml_metal_op_t ctx_op = ggml_metal_op_init(
|
||||
ctx->dev,
|
||||
cmd_buf,
|
||||
ctx->gf,
|
||||
idx_start,
|
||||
idx_end,
|
||||
ctx->use_fusion,
|
||||
ctx->use_concurrency,
|
||||
ctx->capture_next_compute,
|
||||
ctx->debug_graph,
|
||||
ctx->debug_fusion);
|
||||
|
||||
for (int idx = idx_start; idx < idx_end;) {
|
||||
const int res = ggml_metal_op_encode(ctx_op, idx);
|
||||
if (res == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
idx += res;
|
||||
}
|
||||
|
||||
ggml_metal_op_free(ctx_op);
|
||||
|
||||
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
||||
[cmd_buf commit];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_metal_set_abort_callback(ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data) {
|
||||
ctx->abort_callback = abort_callback;
|
||||
ctx->abort_callback_data = user_data;
|
||||
}
|
||||
|
||||
bool ggml_metal_supports_family(ggml_metal_t ctx, int family) {
|
||||
GGML_ASSERT(ctx->device != nil);
|
||||
|
||||
return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
||||
}
|
||||
|
||||
void ggml_metal_capture_next_compute(ggml_metal_t ctx) {
|
||||
ctx->capture_next_compute = true;
|
||||
}
|
||||
1372
ggml/src/ggml-metal/ggml-metal-device.cpp
Normal file
1372
ggml/src/ggml-metal/ggml-metal-device.cpp
Normal file
File diff suppressed because it is too large
Load Diff
227
ggml/src/ggml-metal/ggml-metal-device.h
Normal file
227
ggml/src/ggml-metal/ggml-metal-device.h
Normal file
@@ -0,0 +1,227 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct ggml_metal_buffer_id {
|
||||
void * metal; // id<MTLBuffer>
|
||||
size_t offs;
|
||||
};
|
||||
|
||||
typedef struct ggml_metal_device * ggml_metal_device_t;
|
||||
|
||||
//
|
||||
// MTLFunctionConstantValues wrapper
|
||||
//
|
||||
|
||||
typedef struct ggml_metal_cv * ggml_metal_cv_t;
|
||||
|
||||
ggml_metal_cv_t ggml_metal_cv_init(void);
|
||||
void ggml_metal_cv_free(ggml_metal_cv_t cv);
|
||||
|
||||
void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx);
|
||||
void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx);
|
||||
void ggml_metal_cv_set_bool (ggml_metal_cv_t cv, bool value, int32_t idx);
|
||||
|
||||
//
|
||||
// MTLComputePipelineState wrapper
|
||||
//
|
||||
|
||||
typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t;
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_pipeline_init(void);
|
||||
void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg);
|
||||
int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0);
|
||||
int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1);
|
||||
int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem);
|
||||
size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
// a collection of pipelines
|
||||
typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t;
|
||||
|
||||
ggml_metal_pipelines_t ggml_metal_pipelines_init(void);
|
||||
void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls);
|
||||
|
||||
void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline);
|
||||
ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name);
|
||||
|
||||
//
|
||||
// MTLCommandBuffer wrapper
|
||||
//
|
||||
|
||||
typedef void * ggml_metal_cmd_buf_t;
|
||||
|
||||
//
|
||||
// MTLComputeCommandEncoder wrapper
|
||||
//
|
||||
|
||||
typedef struct ggml_metal_encoder * ggml_metal_encoder_t;
|
||||
|
||||
ggml_metal_encoder_t ggml_metal_encoder_init(ggml_metal_cmd_buf_t cmd_buf_raw, bool concurrent);
|
||||
void ggml_metal_encoder_free(ggml_metal_encoder_t encoder);
|
||||
|
||||
void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name);
|
||||
void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder);
|
||||
|
||||
void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline);
|
||||
|
||||
void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx);
|
||||
void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx);
|
||||
|
||||
void ggml_metal_encoder_set_threadgroup_memory_size(ggml_metal_encoder_t encoder, size_t size, int idx);
|
||||
|
||||
void ggml_metal_encoder_dispatch_threadgroups(ggml_metal_encoder_t encoder, int tg0, int tg1, int tg2, int tptg0, int tptg1, int tptg2);
|
||||
|
||||
void ggml_metal_encoder_memory_barrier(ggml_metal_encoder_t encoder);
|
||||
|
||||
void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder);
|
||||
|
||||
//
|
||||
// MTLLibrary wrapper
|
||||
//
|
||||
|
||||
typedef struct ggml_metal_library * ggml_metal_library_t;
|
||||
|
||||
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev);
|
||||
void ggml_metal_library_free(ggml_metal_library_t lib);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);
|
||||
ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tdst);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
bool has_mask,
|
||||
bool has_sinks,
|
||||
bool has_bias,
|
||||
bool has_scap,
|
||||
int32_t nsg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
bool has_mask,
|
||||
bool has_sinks,
|
||||
bool has_bias,
|
||||
bool has_scap,
|
||||
int32_t nsg,
|
||||
int32_t nwg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
int32_t dv,
|
||||
int32_t nwg);
|
||||
|
||||
//
|
||||
// device
|
||||
//
|
||||
|
||||
struct ggml_metal_device_props {
|
||||
char name[128];
|
||||
|
||||
size_t max_buffer_size;
|
||||
size_t max_working_set_size;
|
||||
size_t max_theadgroup_memory_size;
|
||||
|
||||
bool has_simdgroup_reduction;
|
||||
bool has_simdgroup_mm;
|
||||
bool has_unified_memory;
|
||||
bool has_bfloat;
|
||||
bool use_residency_sets;
|
||||
bool use_shared_buffers;
|
||||
|
||||
bool supports_gpu_family_apple7;
|
||||
};
|
||||
|
||||
ggml_metal_device_t ggml_metal_device_init(void);
|
||||
void ggml_metal_device_free(ggml_metal_device_t dev);
|
||||
|
||||
// return a singleton that is automatically destroyed when the program exits
|
||||
ggml_metal_device_t ggml_metal_device_get(void);
|
||||
|
||||
void * ggml_metal_device_get_obj (ggml_metal_device_t dev); // id<MTLDevice>
|
||||
void * ggml_metal_device_get_queue(ggml_metal_device_t dev); // id<MTLCommandQueue>
|
||||
|
||||
ggml_metal_library_t ggml_metal_device_get_library(ggml_metal_device_t dev);
|
||||
|
||||
void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total);
|
||||
bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op);
|
||||
|
||||
const struct ggml_metal_device_props * ggml_metal_device_get_props(ggml_metal_device_t dev);
|
||||
|
||||
//
|
||||
// device buffers
|
||||
//
|
||||
|
||||
typedef struct ggml_metal_buffer * ggml_metal_buffer_t;
|
||||
|
||||
ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared);
|
||||
ggml_metal_buffer_t ggml_metal_buffer_map (ggml_metal_device_t dev, void * ptr, size_t size, size_t max_tensor_size);
|
||||
|
||||
void ggml_metal_buffer_free (ggml_metal_buffer_t buf);
|
||||
void * ggml_metal_buffer_get_base (ggml_metal_buffer_t buf);
|
||||
bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf);
|
||||
|
||||
void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
|
||||
void ggml_metal_buffer_set_tensor (ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||
void ggml_metal_buffer_get_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||
void ggml_metal_buffer_clear (ggml_metal_buffer_t buf, uint8_t value);
|
||||
|
||||
// finds the Metal buffer that contains the tensor data on the GPU device
|
||||
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
|
||||
// Metal buffer based on the host memory pointer
|
||||
//
|
||||
struct ggml_metal_buffer_id ggml_metal_buffer_get_id(ggml_metal_buffer_t buf, const struct ggml_tensor * t);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
1303
ggml/src/ggml-metal/ggml-metal-device.m
Normal file
1303
ggml/src/ggml-metal/ggml-metal-device.m
Normal file
File diff suppressed because it is too large
Load Diff
@@ -8,6 +8,9 @@
|
||||
//
|
||||
// TODO: for optimal performance, become function of the device and work size
|
||||
|
||||
#define N_R0_F 2
|
||||
#define N_SG_F 4
|
||||
|
||||
#define N_R0_Q4_0 4
|
||||
#define N_SG_Q4_0 2
|
||||
|
||||
@@ -20,8 +23,8 @@
|
||||
#define N_R0_Q5_1 4
|
||||
#define N_SG_Q5_1 2
|
||||
|
||||
#define N_R0_Q8_0 4
|
||||
#define N_SG_Q8_0 2
|
||||
#define N_R0_Q8_0 2
|
||||
#define N_SG_Q8_0 4
|
||||
|
||||
#define N_R0_MXFP4 2
|
||||
#define N_SG_MXFP4 2
|
||||
@@ -32,13 +35,13 @@
|
||||
#define N_R0_Q3_K 2
|
||||
#define N_SG_Q3_K 2
|
||||
|
||||
#define N_R0_Q4_K 4
|
||||
#define N_R0_Q4_K 2
|
||||
#define N_SG_Q4_K 2
|
||||
|
||||
#define N_R0_Q5_K 2
|
||||
#define N_SG_Q5_K 2
|
||||
|
||||
#define N_R0_Q6_K 1
|
||||
#define N_R0_Q6_K 2
|
||||
#define N_SG_Q6_K 2
|
||||
|
||||
#define N_R0_IQ1_S 4
|
||||
@@ -68,6 +71,12 @@
|
||||
#define N_R0_IQ4_XS 2
|
||||
#define N_SG_IQ4_XS 2
|
||||
|
||||
// function constants offsets
|
||||
#define FC_FLASH_ATTN_EXT 100
|
||||
#define FC_FLASH_ATTN_EXT_VEC 200
|
||||
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
|
||||
#define FC_MUL_MV 400
|
||||
|
||||
// kernel argument structs
|
||||
//
|
||||
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
|
||||
@@ -160,6 +169,16 @@ typedef struct {
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_repeat;
|
||||
|
||||
typedef struct {
|
||||
float scale;
|
||||
float bias;
|
||||
} ggml_metal_kargs_scale;
|
||||
|
||||
typedef struct {
|
||||
float min;
|
||||
float max;
|
||||
} ggml_metal_kargs_clamp;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
@@ -236,9 +255,11 @@ typedef struct {
|
||||
int32_t ne11;
|
||||
int32_t ne_12_2; // assume K and V are same shape
|
||||
int32_t ne_12_3;
|
||||
int32_t ns10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
int32_t ns20;
|
||||
uint64_t nb21;
|
||||
uint64_t nb22;
|
||||
uint64_t nb23;
|
||||
@@ -258,10 +279,43 @@ typedef struct {
|
||||
float logit_softcap;
|
||||
} ggml_metal_kargs_flash_attn_ext;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne11;
|
||||
int32_t ne_12_2; // assume K and V are same shape
|
||||
int32_t ne_12_3;
|
||||
int32_t ns10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
int32_t ns20;
|
||||
uint64_t nb21;
|
||||
uint64_t nb22;
|
||||
uint64_t nb23;
|
||||
int32_t ne32;
|
||||
int32_t ne33;
|
||||
uint64_t nb31;
|
||||
uint64_t nb32;
|
||||
uint64_t nb33;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
float m1;
|
||||
int32_t n_head_log2;
|
||||
float logit_softcap;
|
||||
} ggml_metal_kargs_flash_attn_ext_vec;
|
||||
|
||||
typedef struct {
|
||||
int32_t nrows;
|
||||
int32_t ne20;
|
||||
} ggml_metal_kargs_flash_attn_ext_reduce;
|
||||
} ggml_metal_kargs_flash_attn_ext_vec_reduce;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
@@ -320,9 +374,6 @@ typedef struct {
|
||||
int32_t ne1;
|
||||
int16_t r2;
|
||||
int16_t r3;
|
||||
int16_t nsg;
|
||||
int16_t nxpsg;
|
||||
int16_t r1ptg;
|
||||
} ggml_metal_kargs_mul_mv_ext;
|
||||
|
||||
typedef struct {
|
||||
@@ -413,7 +464,7 @@ typedef struct {
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
int32_t n_groups;
|
||||
int32_t ngrp;
|
||||
float eps;
|
||||
} ggml_metal_kargs_group_norm;
|
||||
|
||||
@@ -466,14 +517,6 @@ typedef struct {
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int64_t ne10;
|
||||
int64_t ne11;
|
||||
int64_t ne12;
|
||||
int64_t ne13;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
int64_t ne0;
|
||||
int64_t ne1;
|
||||
int64_t ne2;
|
||||
@@ -507,12 +550,6 @@ typedef struct {
|
||||
int32_t n_head_log2;
|
||||
} ggml_metal_kargs_soft_max;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int n_past;
|
||||
} ggml_metal_kargs_diag_mask_inf;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
@@ -539,7 +576,7 @@ typedef struct {
|
||||
int64_t n_group;
|
||||
int64_t n_seq_tokens;
|
||||
int64_t n_seqs;
|
||||
int64_t s_off;
|
||||
uint64_t s_off;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
@@ -679,7 +716,12 @@ typedef struct {
|
||||
int64_t IW;
|
||||
int64_t OH;
|
||||
int64_t OW;
|
||||
int64_t parallel_elements;
|
||||
int64_t np;
|
||||
} ggml_metal_kargs_pool_2d;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
uint64_t nb01;
|
||||
} ggml_metal_kargs_argmax;
|
||||
|
||||
#endif // GGML_METAL_IMPL
|
||||
|
||||
3191
ggml/src/ggml-metal/ggml-metal-ops.cpp
Normal file
3191
ggml/src/ggml-metal/ggml-metal-ops.cpp
Normal file
File diff suppressed because it is too large
Load Diff
81
ggml/src/ggml-metal/ggml-metal-ops.h
Normal file
81
ggml/src/ggml-metal/ggml-metal-ops.h
Normal file
@@ -0,0 +1,81 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml-metal-device.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef struct ggml_metal_op * ggml_metal_op_t;
|
||||
|
||||
ggml_metal_op_t ggml_metal_op_init(
|
||||
ggml_metal_device_t dev,
|
||||
ggml_metal_cmd_buf_t cmd_buf,
|
||||
struct ggml_cgraph * gf,
|
||||
int idx_start,
|
||||
int idx_end,
|
||||
bool use_fusion,
|
||||
bool use_concurrency,
|
||||
bool use_capture,
|
||||
int debug_graph,
|
||||
int debug_fusion);
|
||||
|
||||
void ggml_metal_op_free(ggml_metal_op_t ctx);
|
||||
|
||||
int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx);
|
||||
|
||||
//
|
||||
// available ops:
|
||||
//
|
||||
|
||||
// tokens per expert
|
||||
size_t ggml_metal_op_mul_mat_id_extra_tpe(const struct ggml_tensor * op);
|
||||
|
||||
// id map [n_tokens, n_expert]
|
||||
size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);
|
||||
|
||||
// return true if we should use the FA vector kernel for this op
|
||||
bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);
|
||||
|
||||
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
|
||||
|
||||
int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_mul_mat (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rms_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
718
ggml/src/ggml-metal/ggml-metal.cpp
Normal file
718
ggml/src/ggml-metal/ggml-metal.cpp
Normal file
@@ -0,0 +1,718 @@
|
||||
#include "ggml-metal.h"
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
|
||||
#include "ggml-metal-device.h"
|
||||
#include "ggml-metal-context.h"
|
||||
#include "ggml-metal-ops.h"
|
||||
|
||||
// globals
|
||||
|
||||
// initialized in ggml_backend_metal_reg
|
||||
static ggml_backend_reg g_ggml_metal_reg;
|
||||
static ggml_backend_device g_ggml_metal_device;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// backend interface
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// shared buffer
|
||||
|
||||
static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_free(ctx);
|
||||
}
|
||||
|
||||
static void * ggml_backend_metal_buffer_shared_get_base(ggml_backend_buffer_t buffer) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
return ggml_metal_buffer_get_base(ctx);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_shared_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_memset_tensor(ctx, tensor, value, offset, size);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_shared_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_set_tensor(ctx, tensor, data, offset, size);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_shared_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_get_tensor(ctx, tensor, data, offset, size);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_buffer_shared_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
GGML_UNUSED(src);
|
||||
GGML_UNUSED(dst);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_clear(ctx, value);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = {
|
||||
/* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer,
|
||||
/* .get_base = */ ggml_backend_metal_buffer_shared_get_base,
|
||||
/* .init_tensor = */ NULL,
|
||||
/* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor,
|
||||
/* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor,
|
||||
/* .clear = */ ggml_backend_metal_buffer_shared_clear,
|
||||
/* .reset = */ NULL,
|
||||
};
|
||||
|
||||
// private buffer
|
||||
|
||||
static void ggml_backend_metal_buffer_private_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_free(ctx);
|
||||
}
|
||||
|
||||
static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
return ggml_metal_buffer_get_base(ctx);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_private_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_memset_tensor(ctx, tensor, value, offset, size);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_private_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_set_tensor(ctx, tensor, data, offset, size);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_private_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_get_tensor(ctx, tensor, data, offset, size);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_buffer_private_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
GGML_UNUSED(src);
|
||||
GGML_UNUSED(dst);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_clear(ctx, value);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = {
|
||||
/* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer,
|
||||
/* .get_base = */ ggml_backend_metal_buffer_private_get_base,
|
||||
/* .init_tensor = */ NULL,
|
||||
/* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor,
|
||||
/* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor,
|
||||
/* .clear = */ ggml_backend_metal_buffer_private_clear,
|
||||
/* .reset = */ NULL,
|
||||
};
|
||||
|
||||
//
|
||||
// buffer types
|
||||
//
|
||||
|
||||
// common method for allocating shread or private Metal buffers
|
||||
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size, bool shared) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;
|
||||
ggml_metal_buffer_t res = ggml_metal_buffer_init(ctx_dev, size, shared);
|
||||
|
||||
ggml_backend_buffer_i buf_i = ggml_metal_buffer_is_shared(res)
|
||||
? ggml_backend_metal_buffer_shared_i
|
||||
: ggml_backend_metal_buffer_private_i;
|
||||
|
||||
return ggml_backend_buffer_init(buft, buf_i, res, size);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
||||
size_t res = ggml_nbytes(tensor);
|
||||
|
||||
// some operations require additional memory for fleeting data:
|
||||
switch (tensor->op) {
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
res += ggml_metal_op_mul_mat_id_extra_tpe(tensor);
|
||||
res += ggml_metal_op_mul_mat_id_extra_ids(tensor);
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
if (ggml_metal_op_flash_attn_ext_use_vec(tensor)) {
|
||||
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return res;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
// default (shared) buffer type
|
||||
|
||||
static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) {
|
||||
return "Metal";
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_shared_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||
return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_shared_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||
return 32;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_buffer_type_t buft) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;
|
||||
|
||||
return ggml_metal_device_get_props(ctx_dev)->max_buffer_size;
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
||||
return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) {
|
||||
return false;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) {
|
||||
static ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
|
||||
/* .iface = */ {
|
||||
/* .get_name = */ ggml_backend_metal_buffer_type_shared_get_name,
|
||||
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer,
|
||||
/* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment,
|
||||
/* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size,
|
||||
/* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size,
|
||||
/* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host,
|
||||
},
|
||||
/* .device = */ &g_ggml_metal_device,
|
||||
/* .context = */ NULL,
|
||||
};
|
||||
|
||||
return &ggml_backend_buffer_type_metal;
|
||||
}
|
||||
|
||||
// default (private) buffer type
|
||||
|
||||
static const char * ggml_backend_metal_buffer_type_private_get_name(ggml_backend_buffer_type_t buft) {
|
||||
return "Metal_Private";
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_private_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||
return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, false);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_private_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||
return 32;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_private_get_max_size(ggml_backend_buffer_type_t buft) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;
|
||||
|
||||
return ggml_metal_device_get_props(ctx_dev)->max_buffer_size;
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_private_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
||||
return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_type_t buft) {
|
||||
return false;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) {
|
||||
static ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
|
||||
/* .iface = */ {
|
||||
/* .get_name = */ ggml_backend_metal_buffer_type_private_get_name,
|
||||
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer,
|
||||
/* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment,
|
||||
/* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size,
|
||||
/* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size,
|
||||
/* .is_host = */ ggml_backend_metal_buffer_type_private_is_host,
|
||||
},
|
||||
/* .device = */ &g_ggml_metal_device,
|
||||
/* .context = */ NULL,
|
||||
};
|
||||
|
||||
return &ggml_backend_buffer_type_metal;
|
||||
}
|
||||
|
||||
// mapped buffer type
|
||||
|
||||
static const char * ggml_backend_metal_buffer_type_mapped_get_name(ggml_backend_buffer_type_t buft) {
|
||||
return "Metal_Mapped";
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_mapped_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||
// for mapped buffers, prefer shared memory
|
||||
return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_mapped_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||
return 32;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_mapped_get_max_size(ggml_backend_buffer_type_t buft) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;
|
||||
|
||||
return ggml_metal_device_get_props(ctx_dev)->max_buffer_size;
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_mapped_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
||||
return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_type_t buft) {
|
||||
return false;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) {
|
||||
// note: not obvious, but this buffer type still needs to implement .alloc_buffer:
|
||||
// https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099
|
||||
static ggml_backend_buffer_type ggml_backend_buffer_type_mapped_metal = {
|
||||
/* .iface = */ {
|
||||
/* .get_name = */ ggml_backend_metal_buffer_type_mapped_get_name,
|
||||
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer,
|
||||
/* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment,
|
||||
/* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size,
|
||||
/* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size,
|
||||
/* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host,
|
||||
},
|
||||
/* .device = */ &g_ggml_metal_device,
|
||||
/* .context = */ NULL,
|
||||
};
|
||||
|
||||
return &ggml_backend_buffer_type_mapped_metal;
|
||||
}
|
||||
|
||||
// backend
|
||||
|
||||
static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
||||
return "Metal";
|
||||
|
||||
GGML_UNUSED(backend);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_free(ggml_backend_t backend) {
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
// wait for any ongoing async operations to finish
|
||||
ggml_metal_synchronize(ctx);
|
||||
|
||||
ggml_metal_free(ctx);
|
||||
|
||||
free(backend);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
ggml_metal_synchronize(ctx);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
ggml_metal_set_tensor_async(ctx, tensor, data, offset, size);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
ggml_metal_get_tensor_async(ctx, tensor, data, offset, size);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
return false;
|
||||
|
||||
GGML_UNUSED(backend_src);
|
||||
GGML_UNUSED(backend_dst);
|
||||
GGML_UNUSED(src);
|
||||
GGML_UNUSED(dst);
|
||||
}
|
||||
|
||||
static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
return ggml_metal_graph_compute(ctx, cgraph);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
ggml_metal_graph_optimize(ctx, cgraph);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
||||
GGML_ASSERT(ggml_backend_is_metal(backend));
|
||||
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
ggml_metal_set_n_cb(ctx, n_cb);
|
||||
|
||||
}
|
||||
|
||||
static ggml_backend_i ggml_backend_metal_i = {
|
||||
/* .get_name = */ ggml_backend_metal_name,
|
||||
/* .free = */ ggml_backend_metal_free,
|
||||
/* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
|
||||
/* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
|
||||
/* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups
|
||||
/* .synchronize = */ ggml_backend_metal_synchronize,
|
||||
/* .graph_plan_create = */ NULL,
|
||||
/* .graph_plan_free = */ NULL,
|
||||
/* .graph_plan_update = */ NULL,
|
||||
/* .graph_plan_compute = */ NULL,
|
||||
/* .graph_compute = */ ggml_backend_metal_graph_compute,
|
||||
|
||||
// the events API is needed only for multi-GPU setups, so likely no need to implement it for Metal
|
||||
// in any case, these docs seem relevant if we ever decide to implement it:
|
||||
// https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events
|
||||
/* .event_record = */ NULL,
|
||||
/* .event_wait = */ NULL,
|
||||
/* .graph_optimize = */ ggml_backend_metal_graph_optimize,
|
||||
};
|
||||
|
||||
static ggml_guid_t ggml_backend_metal_guid(void) {
|
||||
static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
|
||||
return &guid;
|
||||
}
|
||||
|
||||
ggml_backend_t ggml_backend_metal_init(void) {
|
||||
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
|
||||
|
||||
ggml_metal_t ctx = ggml_metal_init(ctx_dev);
|
||||
if (ctx == NULL) {
|
||||
GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
ggml_backend_t backend = (ggml_backend_t) malloc(sizeof(ggml_backend));
|
||||
|
||||
*backend = {
|
||||
/* .guid = */ ggml_backend_metal_guid(),
|
||||
/* .interface = */ ggml_backend_metal_i,
|
||||
/* .device = */ dev,
|
||||
/* .context = */ ctx,
|
||||
};
|
||||
|
||||
ggml_backend_metal_set_n_cb(backend, 1);
|
||||
|
||||
return backend;
|
||||
}
|
||||
|
||||
bool ggml_backend_is_metal(ggml_backend_t backend) {
|
||||
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
|
||||
}
|
||||
|
||||
void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
|
||||
GGML_ASSERT(ggml_backend_is_metal(backend));
|
||||
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
ggml_metal_set_abort_callback(ctx, abort_callback, user_data);
|
||||
}
|
||||
|
||||
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
||||
GGML_ASSERT(ggml_backend_is_metal(backend));
|
||||
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
return ggml_metal_supports_family(ctx, family);
|
||||
}
|
||||
|
||||
void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
|
||||
GGML_ASSERT(ggml_backend_is_metal(backend));
|
||||
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
ggml_metal_capture_next_compute(ctx);
|
||||
}
|
||||
|
||||
// backend device
|
||||
|
||||
static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
|
||||
return "Metal";
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
|
||||
|
||||
return ggml_metal_device_get_props(ctx_dev)->name;
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
|
||||
|
||||
ggml_metal_device_get_memory(ctx_dev, free, total);
|
||||
}
|
||||
|
||||
static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) {
|
||||
return GGML_BACKEND_DEVICE_TYPE_GPU;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
||||
props->name = ggml_backend_metal_device_get_name(dev);
|
||||
props->description = ggml_backend_metal_device_get_description(dev);
|
||||
props->type = ggml_backend_metal_device_get_type(dev);
|
||||
|
||||
ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
|
||||
props->caps = {
|
||||
/* .async = */ true,
|
||||
/* .host_buffer = */ false,
|
||||
/* .buffer_from_host_ptr = */ true,
|
||||
/* .events = */ false,
|
||||
};
|
||||
}
|
||||
|
||||
static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
|
||||
|
||||
ggml_metal_t ctx = ggml_metal_init(ctx_dev);
|
||||
if (ctx == NULL) {
|
||||
GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
ggml_backend_t backend = (ggml_backend_t) malloc(sizeof(ggml_backend));
|
||||
|
||||
*backend = {
|
||||
/* .guid = */ ggml_backend_metal_guid(),
|
||||
/* .interface = */ ggml_backend_metal_i,
|
||||
/* .device = */ dev,
|
||||
/* .context = */ ctx,
|
||||
};
|
||||
|
||||
ggml_backend_metal_set_n_cb(backend, 1);
|
||||
|
||||
return backend;
|
||||
|
||||
GGML_UNUSED(params);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
|
||||
|
||||
const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev);
|
||||
|
||||
return props_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared() : ggml_backend_metal_buffer_type_private();
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
|
||||
|
||||
ggml_metal_buffer_t res = ggml_metal_buffer_map(ctx_dev, ptr, size, max_tensor_size);
|
||||
|
||||
return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(), ggml_backend_metal_buffer_shared_i, res, size);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
|
||||
|
||||
return ggml_metal_device_supports_op(ctx_dev, op);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
||||
return
|
||||
buft->iface.get_name == ggml_backend_metal_buffer_type_shared_get_name ||
|
||||
buft->iface.get_name == ggml_backend_metal_buffer_type_private_get_name ||
|
||||
buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static int64_t get_op_batch_size(const ggml_tensor * op) {
|
||||
switch (op->op) {
|
||||
case GGML_OP_MUL_MAT:
|
||||
return op->ne[1];
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return op->ne[2];
|
||||
default:
|
||||
return ggml_nrows(op);
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
||||
const int min_batch_size = 32;
|
||||
|
||||
return (op->op == GGML_OP_MUL_MAT ||
|
||||
op->op == GGML_OP_MUL_MAT_ID) &&
|
||||
get_op_batch_size(op) >= min_batch_size;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
GGML_UNUSED(op);
|
||||
}
|
||||
|
||||
static ggml_backend_device_i ggml_backend_metal_device_i = {
|
||||
/* .get_name = */ ggml_backend_metal_device_get_name,
|
||||
/* .get_description = */ ggml_backend_metal_device_get_description,
|
||||
/* .get_memory = */ ggml_backend_metal_device_get_memory,
|
||||
/* .get_type = */ ggml_backend_metal_device_get_type,
|
||||
/* .get_props = */ ggml_backend_metal_device_get_props,
|
||||
/* .init_backend = */ ggml_backend_metal_device_init,
|
||||
/* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type,
|
||||
/* .get_host_buffer_type = */ NULL,
|
||||
/* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_mapped,
|
||||
/* .supports_op = */ ggml_backend_metal_device_supports_op,
|
||||
/* .supports_buft = */ ggml_backend_metal_device_supports_buft,
|
||||
/* .offload_op = */ ggml_backend_metal_device_offload_op,
|
||||
/* .event_new = */ NULL,
|
||||
/* .event_free = */ NULL,
|
||||
/* .event_synchronize = */ NULL,
|
||||
};
|
||||
|
||||
// backend registry
|
||||
|
||||
static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) {
|
||||
return "Metal";
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) {
|
||||
return 1;
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
}
|
||||
|
||||
static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) {
|
||||
GGML_ASSERT(index == 0);
|
||||
|
||||
return &g_ggml_metal_device;
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
GGML_UNUSED(index);
|
||||
}
|
||||
|
||||
static ggml_backend_feature g_ggml_backend_metal_features[] = {
|
||||
#if defined(GGML_METAL_EMBED_LIBRARY)
|
||||
{ "EMBED_LIBRARY", "1" },
|
||||
#endif
|
||||
{ NULL, NULL },
|
||||
};
|
||||
|
||||
static ggml_backend_feature * ggml_backend_metal_get_features(ggml_backend_reg_t reg) {
|
||||
return g_ggml_backend_metal_features;
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
}
|
||||
|
||||
static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const char * name) {
|
||||
if (strcmp(name, "ggml_backend_get_features") == 0) {
|
||||
return (void *)ggml_backend_metal_get_features;
|
||||
}
|
||||
|
||||
return NULL;
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
}
|
||||
|
||||
static ggml_backend_reg_i ggml_backend_metal_reg_i = {
|
||||
/* .get_name = */ ggml_backend_metal_reg_get_name,
|
||||
/* .device_count = */ ggml_backend_metal_reg_device_count,
|
||||
/* .device_get = */ ggml_backend_metal_reg_device_get,
|
||||
/* .get_proc_address = */ ggml_backend_metal_get_proc_address,
|
||||
};
|
||||
|
||||
ggml_backend_reg_t ggml_backend_metal_reg(void) {
|
||||
{
|
||||
g_ggml_metal_reg = {
|
||||
/* .api_version = */ GGML_BACKEND_API_VERSION,
|
||||
/* .iface = */ ggml_backend_metal_reg_i,
|
||||
/* .context = */ NULL,
|
||||
};
|
||||
|
||||
g_ggml_metal_device = {
|
||||
/* .iface = */ ggml_backend_metal_device_i,
|
||||
/* .reg = */ &g_ggml_metal_reg,
|
||||
/* .context = */ ggml_metal_device_get(),
|
||||
};
|
||||
}
|
||||
|
||||
return &g_ggml_metal_reg;
|
||||
}
|
||||
|
||||
GGML_BACKEND_DL_IMPL(ggml_backend_metal_reg)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -83,8 +83,10 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mv_q4_0_f32_1d_16x_flat
|
||||
mul_mv_q6_k
|
||||
mul_mv_mxfp4_f32
|
||||
mul_mv_mxfp4_f32_flat
|
||||
mul_mv_id_q4_0_f32_8x_flat
|
||||
mul_mv_id_mxfp4_f32
|
||||
mul_mv_id_mxfp4_f32_flat
|
||||
mul_mm_f32_f32_l4_lm
|
||||
mul_mm_f16_f32_l4_lm
|
||||
mul
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user