Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2026-02-05 13:53:23 +02:00

Compare commits (173 commits)
Commits compared (SHA1):

00fa15fedc 946b1f6859 6c6e397aff afc0e89698 a5771c9eea c35f9eaf09
1f45f2890e 613c5095c3 7f97599581 bf78f5439e bbfc849274 ca0ef2dddb
89d1029559 f1a4e72de5 4762ad7316 1dc9614e06 446595b9b3 66906cd82a
11dd5a44eb 9b8f3c6c77 c7f3169cd5 793c0d7f46 ce111d39d6 e7fecba934
e2b7621e7c c1dbea752a 749e0d27f0 64bf1c3744 c12bbde372 3f4fc97f1d
2df255da3c 60f816a79d 5592f278b6 e4868d16d2 820de57d4f cb4a63aad6
86f5623d90 39cffdf188 065908cb09 4ec6291a24 a12363bbf0 a86f52b285
b284197df4 221c0e0c58 07a19e27a2 18f3b5ff9e 7233358d29 6c88b3bb25
14c28dfc50 8c988fa41d acd6cb1c41 84712b6043 d4d1522b20 d1aa0cc5d1
c8ade30036 e28c0b80c2 8e6f8bc875 adef81781a 48b86c4fdb 38d3af1b73
6c9ee3b17e cd465d823c 922042601b 2ba1333b35 c2e058f1b4 c82d48ec23
b4efd77f8a 2be60cbc27 b526ad2668 938b785764 36c153248f a979ca22db
90083283ec d4b91ea7b2 83f5872404 f0d4d176df b17230917c bf9087f59a
9fb1042ce6 2adf8d83ac 021cc28bef d498af3d5a eacdeb5bfc e0cb5c5cb8
f9a31eea06 8f974bc1e9 09651d09ff 349ea79fce 670e1360cd 760b4484e3
cb887f1bc1 d6fb3f6b49 01612b7409 086cf81e88 d9b691081c ad57d3edd2
1ba45d4982 19e5943d9e 496957e1cb 21c021745d b0f0ecc3dc 225e7a1438
ab14019821 64978340b0 6ffd4e9c44 e4841d24d3 538cc77f7f 5cae766541
4b91d6f71f cf91f217f1 79e0b68c17 c81f4192f9 4a4f426944 ba1ceb3456
10a0351a97 68e37a61a7 cbc68be51d bdca38376f 55c509daf5 9c9e4fc635
494c5899cb 0f4c6ec0f1 65a3ebb0aa 0d9226763c 982e347255 923e3ea2e3
e743cddb60 05fec5bd29 dcf7f2ea3c 84b396e051 c31e60647d 67eade1bf9
7de5c7cab6 8eff95544e 3120413ccd 215535701d 74bb294591 3e303b1107
0c1df14b5f b3ad3a0191 98197e5c98 f5e96b368f 756aa1020a aaa088d87f
0d5375d54b 576c82eda2 0aedae00e6 6bdda13981 0b8855775c 4bb625b713
11ee0fea2a a457551332 704bb7a71c 435a6d10d6 f9a867f592 ac44eb6c80
a57d1bcb3c cb9178f885 4a5686da22 98bab638fb 26a48ad699 ffd59e7d18
105554595f 04655063c4 20b7bf8a32 6efcd65945 699f4392a3 08382869a2
bb4f7a9e4e b8eeb8741d 17a1f0d2d4 8f22dc0a53 53903ae6fa
.clang-format

@@ -22,8 +22,8 @@ AllowShortIfStatementsOnASingleLine: Never
 AllowShortLambdasOnASingleLine: Inline
 AllowShortLoopsOnASingleLine: false
 AlwaysBreakBeforeMultilineStrings: true
-BinPackArguments: true
-BinPackParameters: true # OnePerLine
+BinPackArguments: false
+BinPackParameters: false # OnePerLine
 BitFieldColonSpacing: Both
 BreakBeforeBraces: Custom # Attach
 BraceWrapping:
@@ -70,15 +70,18 @@ ExperimentalAutoDetectBinPacking: false
 FixNamespaceComments: true
 IncludeBlocks: Regroup
 IncludeCategories:
-  - Regex: '^<.*\.h>'
+  - Regex: '".*"'
    Priority: 1
    SortPriority: 0
-  - Regex: '^<.*'
+  - Regex: '^<.*\.h>'
    Priority: 2
    SortPriority: 0
-  - Regex: '.*'
+  - Regex: '^<.*'
    Priority: 3
    SortPriority: 0
+  - Regex: '.*'
+    Priority: 4
+    SortPriority: 0
 IncludeIsMainRegex: '([-_](test|unittest))?$'
 IncludeIsMainSourceRegex: ''
 IndentAccessModifiers: false
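Illustration (editorial, with hypothetical file names): under the regrouped `IncludeCategories` above, a formatted include block sorts quoted headers first, then angle-bracket `.h` headers, then remaining angle-bracket headers; the final `'.*'` category catches anything else.

```cpp
// Hypothetical include block, ordered the way the new categories sort it.
#include "common.h"   // priority 1: quoted includes ('".*"'), local headers
#include "llama.h"

#include <stdint.h>   // priority 2: angle-bracket .h headers ('^<.*\.h>')
#include <stdio.h>

#include <string>     // priority 3: other angle-bracket headers ('^<.*')
#include <vector>
```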
.devops/musa.Dockerfile

@@ -1,10 +1,10 @@
 ARG UBUNTU_VERSION=22.04
 # This needs to generally match the container host's environment.
-ARG MUSA_VERSION=rc4.0.1
+ARG MUSA_VERSION=rc4.2.0
 # Target the MUSA build image
-ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-devel-ubuntu${UBUNTU_VERSION}
+ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION}-amd64

-ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-runtime-ubuntu${UBUNTU_VERSION}
+ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}-amd64

 FROM ${BASE_MUSA_DEV_CONTAINER} AS build
.devops/nix/package.nix

@@ -47,6 +47,7 @@ let
   inherit (lib)
     cmakeBool
     cmakeFeature
+    optionalAttrs
     optionals
     strings
     ;
@@ -197,7 +198,7 @@ effectiveStdenv.mkDerivation (finalAttrs: {
   ];

   # Environment variables needed for ROCm
-  env = optionals useRocm {
+  env = optionalAttrs useRocm {
     ROCM_PATH = "${rocmPackages.clr}";
     HIP_DEVICE_LIB_PATH = "${rocmPackages.rocm-device-libs}/amdgcn/bitcode";
   };
.devops/rocm.Dockerfile

@@ -1,8 +1,8 @@
 ARG UBUNTU_VERSION=24.04

 # This needs to generally match the container host's environment.
-ARG ROCM_VERSION=6.3
-ARG AMDGPU_VERSION=6.3
+ARG ROCM_VERSION=6.4
+ARG AMDGPU_VERSION=6.4

 # Target the CUDA build image
 ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
.github/workflows/build-linux-cross.yml (vendored, 238 changed lines)
@@ -48,98 +48,98 @@ jobs:

           cmake --build build --config Release -j $(nproc)

-  ubuntu-24-riscv64-vulkan-cross:
-    runs-on: ubuntu-24.04
+  # ubuntu-24-riscv64-vulkan-cross:
+  #   runs-on: ubuntu-24.04

-    steps:
-      - uses: actions/checkout@v4
-      - name: Setup Riscv
-        run: |
-          sudo dpkg --add-architecture riscv64
+  #   steps:
+  #     - uses: actions/checkout@v4
+  #     - name: Setup Riscv
+  #       run: |
+  #         sudo dpkg --add-architecture riscv64

-          # Add arch-specific repositories for non-amd64 architectures
-          cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list
-          deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
-          deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
-          deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
-          deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
-          EOF
+  #         # Add arch-specific repositories for non-amd64 architectures
+  #         cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list
+  #         deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
+  #         deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
+  #         deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
+  #         deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
+  #         EOF

-          sudo apt-get update || true ;# Prevent failure due to missing URLs.
+  #         sudo apt-get update || true ;# Prevent failure due to missing URLs.

-          sudo apt-get install -y --no-install-recommends \
-                  build-essential \
-                  glslc \
-                  gcc-14-riscv64-linux-gnu \
-                  g++-14-riscv64-linux-gnu \
-                  libvulkan-dev:riscv64
+  #         sudo apt-get install -y --no-install-recommends \
+  #                 build-essential \
+  #                 glslc \
+  #                 gcc-14-riscv64-linux-gnu \
+  #                 g++-14-riscv64-linux-gnu \
+  #                 libvulkan-dev:riscv64

-      - name: Build
-        run: |
-          cmake -B build -DLLAMA_CURL=OFF \
-                         -DCMAKE_BUILD_TYPE=Release \
-                         -DGGML_VULKAN=ON \
-                         -DGGML_OPENMP=OFF \
-                         -DLLAMA_BUILD_EXAMPLES=ON \
-                         -DLLAMA_BUILD_TOOLS=ON \
-                         -DLLAMA_BUILD_TESTS=OFF \
-                         -DCMAKE_SYSTEM_NAME=Linux \
-                         -DCMAKE_SYSTEM_PROCESSOR=riscv64 \
-                         -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
-                         -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
-                         -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
-                         -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
-                         -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
-                         -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
-                         -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
+  #     - name: Build
+  #       run: |
+  #         cmake -B build -DLLAMA_CURL=OFF \
+  #                        -DCMAKE_BUILD_TYPE=Release \
+  #                        -DGGML_VULKAN=ON \
+  #                        -DGGML_OPENMP=OFF \
+  #                        -DLLAMA_BUILD_EXAMPLES=ON \
+  #                        -DLLAMA_BUILD_TOOLS=ON \
+  #                        -DLLAMA_BUILD_TESTS=OFF \
+  #                        -DCMAKE_SYSTEM_NAME=Linux \
+  #                        -DCMAKE_SYSTEM_PROCESSOR=riscv64 \
+  #                        -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
+  #                        -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
+  #                        -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
+  #                        -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
+  #                        -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
+  #                        -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
+  #                        -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH

-          cmake --build build --config Release -j $(nproc)
+  #         cmake --build build --config Release -j $(nproc)

-  ubuntu-24-arm64-vulkan-cross:
-    runs-on: ubuntu-24.04
+  # ubuntu-24-arm64-vulkan-cross:
+  #   runs-on: ubuntu-24.04

-    steps:
-      - uses: actions/checkout@v4
-      - name: Setup Arm64
-        run: |
-          sudo dpkg --add-architecture arm64
+  #   steps:
+  #     - uses: actions/checkout@v4
+  #     - name: Setup Arm64
+  #       run: |
+  #         sudo dpkg --add-architecture arm64

-          # Add arch-specific repositories for non-amd64 architectures
-          cat << EOF | sudo tee /etc/apt/sources.list.d/arm64-ports.list
-          deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
-          deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
-          deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
-          deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
-          EOF
+  #         # Add arch-specific repositories for non-amd64 architectures
+  #         cat << EOF | sudo tee /etc/apt/sources.list.d/arm64-ports.list
+  #         deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
+  #         deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
+  #         deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
+  #         deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
+  #         EOF

-          sudo apt-get update || true ;# Prevent failure due to missing URLs.
+  #         sudo apt-get update || true ;# Prevent failure due to missing URLs.

-          sudo apt-get install -y --no-install-recommends \
-                  build-essential \
-                  glslc \
-                  crossbuild-essential-arm64 \
-                  libvulkan-dev:arm64
+  #         sudo apt-get install -y --no-install-recommends \
+  #                 build-essential \
+  #                 glslc \
+  #                 crossbuild-essential-arm64 \
+  #                 libvulkan-dev:arm64

-      - name: Build
-        run: |
-          cmake -B build -DLLAMA_CURL=OFF \
-                         -DCMAKE_BUILD_TYPE=Release \
-                         -DGGML_VULKAN=ON \
-                         -DGGML_OPENMP=OFF \
-                         -DLLAMA_BUILD_EXAMPLES=ON \
-                         -DLLAMA_BUILD_TOOLS=ON \
-                         -DLLAMA_BUILD_TESTS=OFF \
-                         -DCMAKE_SYSTEM_NAME=Linux \
-                         -DCMAKE_SYSTEM_PROCESSOR=aarch64 \
-                         -DCMAKE_C_COMPILER=aarch64-linux-gnu-gcc \
-                         -DCMAKE_CXX_COMPILER=aarch64-linux-gnu-g++ \
-                         -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
-                         -DCMAKE_FIND_ROOT_PATH=/usr/lib/aarch64-linux-gnu \
-                         -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
-                         -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
-                         -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
+  #     - name: Build
+  #       run: |
+  #         cmake -B build -DLLAMA_CURL=OFF \
+  #                        -DCMAKE_BUILD_TYPE=Release \
+  #                        -DGGML_VULKAN=ON \
+  #                        -DGGML_OPENMP=OFF \
+  #                        -DLLAMA_BUILD_EXAMPLES=ON \
+  #                        -DLLAMA_BUILD_TOOLS=ON \
+  #                        -DLLAMA_BUILD_TESTS=OFF \
+  #                        -DCMAKE_SYSTEM_NAME=Linux \
+  #                        -DCMAKE_SYSTEM_PROCESSOR=aarch64 \
+  #                        -DCMAKE_C_COMPILER=aarch64-linux-gnu-gcc \
+  #                        -DCMAKE_CXX_COMPILER=aarch64-linux-gnu-g++ \
+  #                        -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
+  #                        -DCMAKE_FIND_ROOT_PATH=/usr/lib/aarch64-linux-gnu \
+  #                        -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
+  #                        -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
+  #                        -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH

-          cmake --build build --config Release -j $(nproc)
+  #         cmake --build build --config Release -j $(nproc)

   ubuntu-24-ppc64el-cpu-cross:
     runs-on: ubuntu-24.04
@@ -185,52 +185,52 @@ jobs:

           cmake --build build --config Release -j $(nproc)

-  ubuntu-24-ppc64el-vulkan-cross:
-    runs-on: ubuntu-24.04
+  # ubuntu-24-ppc64el-vulkan-cross:
+  #   runs-on: ubuntu-24.04

-    steps:
-      - uses: actions/checkout@v4
-      - name: Setup PowerPC64le
-        run: |
-          sudo dpkg --add-architecture ppc64el
+  #   steps:
+  #     - uses: actions/checkout@v4
+  #     - name: Setup PowerPC64le
+  #       run: |
+  #         sudo dpkg --add-architecture ppc64el

-          # Add arch-specific repositories for non-amd64 architectures
-          cat << EOF | sudo tee /etc/apt/sources.list.d/ppc64el-ports.list
-          deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
-          deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
-          deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
-          deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
-          EOF
+  #         # Add arch-specific repositories for non-amd64 architectures
+  #         cat << EOF | sudo tee /etc/apt/sources.list.d/ppc64el-ports.list
+  #         deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
+  #         deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
+  #         deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
+  #         deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
+  #         EOF

-          sudo apt-get update || true ;# Prevent failure due to missing URLs.
+  #         sudo apt-get update || true ;# Prevent failure due to missing URLs.

-          sudo apt-get install -y --no-install-recommends \
-                  build-essential \
-                  glslc \
-                  gcc-14-powerpc64le-linux-gnu \
-                  g++-14-powerpc64le-linux-gnu \
-                  libvulkan-dev:ppc64el
+  #         sudo apt-get install -y --no-install-recommends \
+  #                 build-essential \
+  #                 glslc \
+  #                 gcc-14-powerpc64le-linux-gnu \
+  #                 g++-14-powerpc64le-linux-gnu \
+  #                 libvulkan-dev:ppc64el

-      - name: Build
-        run: |
-          cmake -B build -DLLAMA_CURL=OFF \
-                         -DCMAKE_BUILD_TYPE=Release \
-                         -DGGML_VULKAN=ON \
-                         -DGGML_OPENMP=OFF \
-                         -DLLAMA_BUILD_EXAMPLES=ON \
-                         -DLLAMA_BUILD_TOOLS=ON \
-                         -DLLAMA_BUILD_TESTS=OFF \
-                         -DCMAKE_SYSTEM_NAME=Linux \
-                         -DCMAKE_SYSTEM_PROCESSOR=ppc64 \
-                         -DCMAKE_C_COMPILER=powerpc64le-linux-gnu-gcc-14 \
-                         -DCMAKE_CXX_COMPILER=powerpc64le-linux-gnu-g++-14 \
-                         -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
-                         -DCMAKE_FIND_ROOT_PATH=/usr/lib/powerpc64le-linux-gnu \
-                         -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
-                         -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
-                         -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
+  #     - name: Build
+  #       run: |
+  #         cmake -B build -DLLAMA_CURL=OFF \
+  #                        -DCMAKE_BUILD_TYPE=Release \
+  #                        -DGGML_VULKAN=ON \
+  #                        -DGGML_OPENMP=OFF \
+  #                        -DLLAMA_BUILD_EXAMPLES=ON \
+  #                        -DLLAMA_BUILD_TOOLS=ON \
+  #                        -DLLAMA_BUILD_TESTS=OFF \
+  #                        -DCMAKE_SYSTEM_NAME=Linux \
+  #                        -DCMAKE_SYSTEM_PROCESSOR=ppc64 \
+  #                        -DCMAKE_C_COMPILER=powerpc64le-linux-gnu-gcc-14 \
+  #                        -DCMAKE_CXX_COMPILER=powerpc64le-linux-gnu-g++-14 \
+  #                        -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
+  #                        -DCMAKE_FIND_ROOT_PATH=/usr/lib/powerpc64le-linux-gnu \
+  #                        -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
+  #                        -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
+  #                        -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH

-          cmake --build build --config Release -j $(nproc)
+  #         cmake --build build --config Release -j $(nproc)

   debian-13-loongarch64-cpu-cross:
     runs-on: ubuntu-24.04
.github/workflows/build.yml (vendored, 131 changed lines)
@@ -135,6 +135,69 @@ jobs:
           cd build
           ctest -L main --verbose --timeout 900

+  macOS-latest-cmake-arm64-webgpu:
+    runs-on: macos-14
+
+    steps:
+      - name: Clone
+        id: checkout
+        uses: actions/checkout@v4
+
+      - name: ccache
+        uses: hendrikmuhs/ccache-action@v1.2.16
+        with:
+          key: macOS-latest-cmake-arm64-webgpu
+          evict-old-files: 1d
+
+      - name: Dependencies
+        id: depends
+        continue-on-error: true
+        run: |
+          brew update
+          brew install curl
+
+      - name: Dawn Dependency
+        id: dawn-depends
+        run: |
+          ARTIFACTS_JSON=$(curl -s -L \
+            -H "Accept: application/vnd.github+json" \
+            -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
+            -H "X-GitHub-Api-Version: 2022-11-28" \
+            "https://api.github.com/repos/google/dawn/actions/artifacts")
+          echo "Finding latest macos-latest-Release artifact..."
+          DOWNLOAD_URL=$(echo "$ARTIFACTS_JSON" | jq -r '.artifacts
+            | sort_by(.created_at)
+            | reverse
+            | map(select(.name | test("macos-latest-Release$")))
+            | .[0].archive_download_url')
+          if [ "$DOWNLOAD_URL" = "null" ] || [ -z "$DOWNLOAD_URL" ]; then
+            echo "No suitable Dawn artifact found!"
+            exit 1
+          fi
+          echo "Downloading from: $DOWNLOAD_URL"
+          curl -L \
+            -H "Accept: application/vnd.github+json" \
+            -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
+            -o artifact.zip "$DOWNLOAD_URL"
+          unzip artifact.zip
+          mkdir dawn
+          tar_file=$(find . -name '*.tar.gz' | head -n 1)
+          echo "Extracting: $tar_file"
+          tar -xvf "$tar_file" -C dawn --strip-components=1
+
+      - name: Build
+        id: cmake_build
+        run: |
+          export CMAKE_PREFIX_PATH=dawn
+          cmake -B build -DGGML_WEBGPU=ON -DGGML_METAL=OFF -DGGML_BLAS=OFF
+          cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
+
+      - name: Test
+        id: cmake_test
+        run: |
+          cd build
+          ctest -L main --verbose --timeout 900
+
   ubuntu-cpu-cmake:
     strategy:
       matrix:
@@ -342,6 +405,72 @@ jobs:
           cd build
           export GGML_VK_VISIBLE_DEVICES=0
           # This is using llvmpipe and runs slower than other backends
          ctest -L main --verbose --timeout 4200

+  ubuntu-22-cmake-webgpu:
+    runs-on: ubuntu-22.04
+
+    steps:
+      - name: Clone
+        id: checkout
+        uses: actions/checkout@v4
+
+      - name: ccache
+        uses: hendrikmuhs/ccache-action@v1.2.16
+        with:
+          key: ubuntu-22-cmake-webgpu
+          evict-old-files: 1d
+
+      - name: Vulkan SDK Dependencies
+        id: vulkan-depends
+        run: |
+          wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add -
+          sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list
+          sudo apt-get update -y
+          sudo apt-get install -y build-essential mesa-vulkan-drivers vulkan-sdk libcurl4-openssl-dev
+
+      - name: Dawn Dependency
+        id: dawn-depends
+        run: |
+          sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev
+          ARTIFACTS_JSON=$(curl -s -L \
+            -H "Accept: application/vnd.github+json" \
+            -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
+            -H "X-GitHub-Api-Version: 2022-11-28" \
+            "https://api.github.com/repos/google/dawn/actions/artifacts")
+          echo "Finding latest ubuntu-latest-Release artifact..."
+          DOWNLOAD_URL=$(echo "$ARTIFACTS_JSON" | jq -r '.artifacts
+            | sort_by(.created_at)
+            | reverse
+            | map(select(.name | test("ubuntu-latest-Release$")))
+            | .[0].archive_download_url')
+          if [ "$DOWNLOAD_URL" = "null" ] || [ -z "$DOWNLOAD_URL" ]; then
+            echo "No suitable Dawn artifact found!"
+            exit 1
+          fi
+          echo "Downloading from: $DOWNLOAD_URL"
+          curl -L \
+            -H "Accept: application/vnd.github+json" \
+            -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
+            -o artifact.zip "$DOWNLOAD_URL"
+          unzip artifact.zip
+          mkdir dawn
+          tar_file=$(find . -name '*.tar.gz' | head -n 1)
+          echo "Extracting: $tar_file"
+          tar -xvf "$tar_file" -C dawn --strip-components=1
+
+      - name: Build
+        id: cmake_build
+        run: |
+          export Dawn_DIR=dawn/lib64/cmake/Dawn
+          cmake -B build -DGGML_WEBGPU=ON
+          cmake --build build --config Release -j $(nproc)
+
+      - name: Test
+        id: cmake_test
+        run: |
+          cd build
+          # This is using llvmpipe and runs slower than other backends
+          ctest -L main --verbose --timeout 3600
+
   ubuntu-22-cmake-hip:
@@ -386,7 +515,7 @@ jobs:

   ubuntu-22-cmake-musa:
     runs-on: ubuntu-22.04
-    container: mthreads/musa:rc4.0.1-mudnn-devel-ubuntu22.04
+    container: mthreads/musa:rc4.2.0-devel-ubuntu22.04-amd64

     steps:
       - name: Clone
.github/workflows/close-issue.yml (vendored, 2 changed lines)
@@ -17,7 +17,7 @@ jobs:
     steps:
       - uses: actions/stale@v5
        with:
-          exempt-issue-labels: "refactor,help wanted,good first issue,research,bug,roadmap"
+          exempt-issue-labels: "refactoring,help wanted,good first issue,research,bug,roadmap"
           days-before-issue-stale: 30
           days-before-issue-close: 14
           stale-issue-label: "stale"
.github/workflows/update-ops-docs.yml (vendored, new file, 40 lines)
@@ -0,0 +1,40 @@
+name: Update Operations Documentation
+
+on:
+  push:
+    paths:
+      - 'docs/ops/**'
+      - 'scripts/create_ops_docs.py'
+  pull_request:
+    paths:
+      - 'docs/ops/**'
+      - 'scripts/create_ops_docs.py'
+
+jobs:
+  update-ops-docs:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.x'
+
+      - name: Generate operations documentation to temporary file
+        run: |
+          mkdir -p /tmp/ops_check
+          ./scripts/create_ops_docs.py /tmp/ops_check/ops.md
+
+      - name: Check if docs/ops.md matches generated version
+        run: |
+          if ! diff -q docs/ops.md /tmp/ops_check/ops.md; then
+            echo "Operations documentation (docs/ops.md) is not up to date with the backend CSV files."
+            echo "To fix: run ./scripts/create_ops_docs.py and commit the updated docs/ops.md along with your changes"
+            echo "Differences found:"
+            diff docs/ops.md /tmp/ops_check/ops.md || true
+            exit 1
+          fi
+          echo "Operations documentation is up to date."
.gitignore (vendored, 1 changed line)
@@ -82,6 +82,7 @@ models/*
 models-mnt
 !models/.editorconfig
 !models/ggml-vocab-*.gguf*
+!models/templates

 # Zig
 zig-out/
CMakePresets.json

@@ -55,6 +55,17 @@
         "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-apple-clang.cmake"
       }
     },
+    {
+      "name": "x64-linux-gcc", "hidden": true,
+      "cacheVariables": {
+        "CMAKE_C_COMPILER": "gcc",
+        "CMAKE_CXX_COMPILER": "g++"
+      }
+    },
+    { "name": "x64-linux-gcc-debug", "inherits": [ "base", "x64-linux-gcc", "debug" ] },
+    { "name": "x64-linux-gcc-release", "inherits": [ "base", "x64-linux-gcc", "release" ] },
+    { "name": "x64-linux-gcc-reldbg", "inherits": [ "base", "x64-linux-gcc", "reldbg" ] },
+    { "name": "x64-linux-gcc+static-release", "inherits": [ "base", "x64-linux-gcc", "release", "static" ] },

     { "name": "arm64-windows-llvm-debug", "inherits": [ "base", "arm64-windows-llvm", "debug" ] },
     { "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] },
CODEOWNERS

@@ -9,3 +9,4 @@
 /ggml/src/ggml-cuda/mmvq.* @JohannesGaessler
 /ggml/src/ggml-opt.cpp @JohannesGaessler
 /ggml/src/gguf.cpp @JohannesGaessler
+/ggml/src/ggml-vulkan/ @0cc4m
README.md (16 changed lines)
@@ -6,9 +6,9 @@
 [](https://github.com/ggml-org/llama.cpp/releases)
 [](https://github.com/ggml-org/llama.cpp/actions/workflows/server.yml)

-[Roadmap](https://github.com/users/ggerganov/projects/7) / [Manifesto](https://github.com/ggml-org/llama.cpp/discussions/205) / [ggml](https://github.com/ggml-org/ggml)
+[Manifesto](https://github.com/ggml-org/llama.cpp/discussions/205) / [ggml](https://github.com/ggml-org/ggml) / [ops](https://github.com/ggml-org/llama.cpp/blob/master/docs/ops.md)

-Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) in pure C/C++
+LLM inference in C/C++

 ## Recent API changes

@@ -17,10 +17,9 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)

 ## Hot topics

-- 🔥 Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md)
-- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli`, `gemma3-cli` ([#13012](https://github.com/ggml-org/llama.cpp/pull/13012)) and `qwen2vl-cli` ([#13141](https://github.com/ggml-org/llama.cpp/pull/13141)), `libllava` will be deprecated
+- Hot PRs: [All](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+) | [Open](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+is%3Aopen)
+- Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md)
 - VS Code extension for FIM completions: https://github.com/ggml-org/llama.vscode
 - Universal [tool call support](./docs/function-calling.md) in `llama-server` https://github.com/ggml-org/llama.cpp/pull/9639
 - Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim
 - Introducing GGUF-my-LoRA https://github.com/ggml-org/llama.cpp/discussions/10123
 - Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggml-org/llama.cpp/discussions/9669

@@ -134,6 +133,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
 - [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
 - [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
+- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)

 #### Multimodal

@@ -269,6 +269,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 | [Vulkan](docs/build.md#vulkan) | GPU |
 | [CANN](docs/build.md#cann) | Ascend NPU |
 | [OpenCL](docs/backend/OPENCL.md) | Adreno GPU |
+| [WebGPU [In Progress]](docs/build.md#webgpu) | All |
 | [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All |

 ## Obtaining and quantizing models

@@ -434,7 +435,7 @@ To learn more about model quantization, [read this documentation](tools/quantize

 ## [`llama-perplexity`](tools/perplexity)

-#### A tool for measuring the perplexity [^1][^2] (and other quality metrics) of a model over a given text.
+#### A tool for measuring the [perplexity](tools/perplexity/README.md) [^1] (and other quality metrics) of a model over a given text.

 - <details open>
   <summary>Measure the perplexity over a text file</summary>

@@ -457,8 +458,7 @@ To learn more about model quantization, [read this documentation](tools/quantize

 </details>

-[^1]: [tools/perplexity/README.md](./tools/perplexity/README.md)
-[^2]: [https://huggingface.co/docs/transformers/perplexity](https://huggingface.co/docs/transformers/perplexity)
+[^1]: [https://huggingface.co/docs/transformers/perplexity](https://huggingface.co/docs/transformers/perplexity)

 ## [`llama-bench`](tools/llama-bench)
ci/README.md

@@ -54,7 +54,7 @@ docker run --privileged -it \
     -v $HOME/llama.cpp/ci-cache:/ci-cache \
     -v $HOME/llama.cpp/ci-results:/ci-results \
     -v $PWD:/ws -w /ws \
-    mthreads/musa:rc4.0.1-mudnn-devel-ubuntu22.04
+    mthreads/musa:rc4.2.0-devel-ubuntu22.04-amd64
 ```

 Inside the container, execute the following commands:
ci/run.sh

@@ -16,6 +16,9 @@
 # # with VULKAN support
 # GG_BUILD_VULKAN=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
 #
+# # with WebGPU support
+# GG_BUILD_WEBGPU=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
+#
 # # with MUSA support
 # GG_BUILD_MUSA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
 #
@@ -81,6 +84,10 @@ if [ ! -z ${GG_BUILD_VULKAN} ]; then
     CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_VULKAN=1"
 fi

+if [ ! -z ${GG_BUILD_WEBGPU} ]; then
+    CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_WEBGPU=1"
+fi
+
 if [ ! -z ${GG_BUILD_MUSA} ]; then
     # Use qy1 by default (MTT S80)
     MUSA_ARCH=${MUSA_ARCH:-21}
common/CMakeLists.txt

@@ -86,8 +86,7 @@ if (LLAMA_CURL)
     endif()
     target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL)
     include_directories(${CURL_INCLUDE_DIRS})
-    find_library(CURL_LIBRARY curl REQUIRED)
-    set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY})
+    set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARIES})
 endif ()

 if (LLAMA_LLGUIDANCE)
@@ -112,13 +111,13 @@
     ExternalProject_Add(llguidance_ext
         GIT_REPOSITORY https://github.com/guidance-ai/llguidance
-        # v0.7.20 (+ fix to build on GCC 15):
-        GIT_TAG b5b8b64dba11c4e4ee6b1d1450d3a3ae279891e8
+        # v1.0.1:
+        GIT_TAG d795912fedc7d393de740177ea9ea761e7905774
         PREFIX ${CMAKE_BINARY_DIR}/llguidance
         SOURCE_DIR ${LLGUIDANCE_SRC}
         BUILD_IN_SOURCE TRUE
         CONFIGURE_COMMAND ""
-        BUILD_COMMAND cargo build --release
+        BUILD_COMMAND cargo build --release --package llguidance
         INSTALL_COMMAND ""
         BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/${LLGUIDANCE_LIB_NAME} ${LLGUIDANCE_PATH}/llguidance.h
         UPDATE_COMMAND ""
common/arg.cpp

@@ -1464,6 +1464,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.swa_full = true;
         }
     ).set_env("LLAMA_ARG_SWA_FULL"));
+    add_opt(common_arg(
+        {"--kv-unified", "-kvu"},
+        string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.kv_unified ? "true" : "false"),
+        [](common_params & params) {
+            params.kv_unified = true;
+        }
+    ).set_env("LLAMA_ARG_KV_SPLIT"));
     add_opt(common_arg(
         {"--no-context-shift"},
         string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
@@ -1604,7 +1612,7 @@
         [](common_params & params, const std::string & value) {
             params.antiprompt.emplace_back(value);
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-sp", "--special"},
         string_format("special tokens output enabled (default: %s)", params.special ? "true" : "false"),
@@ -2647,6 +2655,13 @@
             params.i_chunk = value;
         }
     ).set_examples({LLAMA_EXAMPLE_IMATRIX}));
+    add_opt(common_arg(
+        {"--show-statistics"},
+        string_format("show imatrix statistics and then exit (default: %s)", params.show_statistics ? "true" : "false"),
+        [](common_params & params) {
+            params.show_statistics = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_IMATRIX}));
     add_opt(common_arg(
         {"--parse-special"},
         string_format("prase special tokens (chat, tool, etc) (default: %s)", params.parse_special ? "true" : "false"),
@@ -2734,6 +2749,13 @@
             params.public_path = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH"));
+    add_opt(common_arg(
+        {"--api-prefix"}, "PREFIX",
+        string_format("prefix path the server serves from, without the trailing slash (default: %s)", params.api_prefix.c_str()),
+        [](common_params & params, const std::string & value) {
+            params.api_prefix = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_API_PREFIX"));
     add_opt(common_arg(
         {"--no-webui"},
         string_format("Disable the Web UI (default: %s)", params.webui ? "enabled" : "disabled"),
@@ -3416,5 +3438,34 @@
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));

+    // diffusion parameters
+    add_opt(common_arg(
+        { "--diffusion-steps" }, "N",
+        string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
+        [](common_params & params, int value) { params.diffusion.steps = value; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+    add_opt(common_arg(
+        { "--diffusion-eps" }, "F",
+        string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
+        [](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+    add_opt(common_arg(
+        { "--diffusion-algorithm" }, "N",
+        string_format("diffusion algorithm: 0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY (default: %d)",
+                      params.diffusion.algorithm),
+        [](common_params & params, int value) { params.diffusion.algorithm = value; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+    add_opt(common_arg(
+        { "--diffusion-alg-temp" }, "F",
+        string_format("algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
+        [](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+    add_opt(common_arg(
+        { "--diffusion-visual" },
+        string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
+                      params.diffusion.visual_mode ? "true" : "false"),
+        [](common_params & params) { params.diffusion.visual_mode = true; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));

     return ctx_arg;
 }
common/common.cpp

@@ -448,6 +448,15 @@ void string_replace_all(std::string & s, const std::string & search, const std::
 bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
     return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
 }

+bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
+    bool has_suffix = string_ends_with(str, suffix);
+    if (has_suffix) {
+        str = str.substr(0, str.size() - suffix.size());
+    }
+    return has_suffix;
+}
+
 size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
     if (!str.empty() && !stop.empty()) {
         const char text_last_char = str.back();
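For reference, an editorial sketch of how the new helper behaves; the functions are re-declared locally so the snippet compiles on its own, and the suffix string is just an example:

```cpp
#include <cassert>
#include <string>
#include <string_view>

// Same behavior as the helpers in the hunk above: check for a suffix,
// and strip it in place while reporting whether it was present.
static bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
    return str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}

static bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
    bool has_suffix = string_ends_with(str, suffix);
    if (has_suffix) {
        str = str.substr(0, str.size() - suffix.size());
    }
    return has_suffix;
}

int main() {
    std::string s = "Hello<|im_end|>";
    assert(string_remove_suffix(s, "<|im_end|>")); // suffix found and stripped
    assert(s == "Hello");
    assert(!string_remove_suffix(s, "<|im_end|>")); // second call: nothing left to strip
}
```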
@@ -1005,15 +1014,21 @@ struct common_init_result common_init_from_params(common_params & params) {
         params.sampling.ignore_eos = false;
     }

-    if (params.sampling.ignore_eos) {
-        for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
-            if (llama_vocab_is_eog(vocab, i)) {
-                LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
-                params.sampling.logit_bias.push_back({i, -INFINITY});
-            }
+    // initialize once
+    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
+        if (llama_vocab_is_eog(vocab, i)) {
+            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
+            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
+        }
     }

+    if (params.sampling.ignore_eos) {
+        // add EOG biases to the active set of logit biases
+        params.sampling.logit_bias.insert(
+                params.sampling.logit_bias.end(),
+                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
+    }
+
     if (params.sampling.penalty_last_n == -1) {
         LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
         params.sampling.penalty_last_n = llama_n_ctx(lctx);
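Editorial note on the mechanism: a bias of -INFINITY makes end-of-generation tokens unselectable once the EOG biases are appended to the active set. A minimal standalone sketch of that effect (hypothetical token ids and logits, not llama.cpp code):

```cpp
#include <cstdio>
#include <limits>
#include <vector>

// Simplified stand-in for llama_logit_bias: token id + additive bias.
struct logit_bias { int token; float bias; };

int main() {
    std::vector<float> logits = { 1.0f, 3.5f, 2.0f, 9.0f }; // say token 3 is EOG
    const float neg_inf = -std::numeric_limits<float>::infinity();

    // Pre-computed EOG biases (analogous to logit_bias_eog above) ...
    std::vector<logit_bias> eog_biases = { {3, neg_inf} };
    // ... appended to the active bias set when ignore_eos is enabled.
    std::vector<logit_bias> active;
    active.insert(active.end(), eog_biases.begin(), eog_biases.end());

    for (const auto & lb : active) {
        logits[lb.token] += lb.bias; // -inf + anything stays -inf
    }

    int best = 0;
    for (int i = 1; i < (int) logits.size(); i++) {
        if (logits[i] > logits[best]) best = i;
    }
    std::printf("sampled token %d\n", best); // EOG (token 3) can no longer win
}
```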
@@ -1157,6 +1172,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.no_perf    = params.no_perf;
     cparams.op_offload = !params.no_op_offload;
     cparams.swa_full   = params.swa_full;
+    cparams.kv_unified = params.kv_unified;

     cparams.type_k = params.cache_type_k;
     cparams.type_v = params.cache_type_v;
common/common.h

@@ -81,6 +81,7 @@ enum llama_example {
     LLAMA_EXAMPLE_LOOKUP,
     LLAMA_EXAMPLE_PARALLEL,
     LLAMA_EXAMPLE_TTS,
+    LLAMA_EXAMPLE_DIFFUSION,

     LLAMA_EXAMPLE_COUNT,
 };
@@ -177,7 +178,8 @@ struct common_params_sampling {
     std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
     std::set<llama_token> preserved_tokens;

-    std::vector<llama_logit_bias> logit_bias; // logit biases to apply
+    std::vector<llama_logit_bias> logit_bias;     // logit biases to apply
+    std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens

     // print the parameters into a string
     std::string print() const;
@@ -217,6 +219,14 @@ struct common_params_vocoder {
     bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
 };

+struct common_params_diffusion {
+    int32_t steps       = 64;    // number of diffusion steps
+    float   eps         = 1e-3f; // epsilon for timesteps
+    int32_t algorithm   = 0;     // diffusion algorithm (0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY)
+    float   alg_temp    = 0.0f;  // algorithm temperature
+    bool    visual_mode = false; // show progressive diffusion on screen
+};
+
 enum common_reasoning_format {
     COMMON_REASONING_FORMAT_NONE,
     COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
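For illustration only: a standalone sketch of how the new `--diffusion-*` flags registered in `common/arg.cpp` end up in this struct. The argv loop below is a simplified stand-in for the real `common_arg` handlers, not the actual parser:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <string>

// Mirror of the struct added above, with its defaults.
struct common_params_diffusion {
    int32_t steps       = 64;
    float   eps         = 1e-3f;
    int32_t algorithm   = 0;     // 0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY
    float   alg_temp    = 0.0f;
    bool    visual_mode = false;
};

int main(int argc, char ** argv) {
    common_params_diffusion diffusion;

    // Simplified stand-in for the option handlers; each flag writes one field.
    for (int i = 1; i < argc; i++) {
        if      (!strcmp(argv[i], "--diffusion-steps")     && i + 1 < argc) { diffusion.steps       = std::stoi(argv[++i]); }
        else if (!strcmp(argv[i], "--diffusion-eps")       && i + 1 < argc) { diffusion.eps         = std::stof(argv[++i]); }
        else if (!strcmp(argv[i], "--diffusion-algorithm") && i + 1 < argc) { diffusion.algorithm   = std::stoi(argv[++i]); }
        else if (!strcmp(argv[i], "--diffusion-alg-temp")  && i + 1 < argc) { diffusion.alg_temp    = std::stof(argv[++i]); }
        else if (!strcmp(argv[i], "--diffusion-visual"))                    { diffusion.visual_mode = true; }
    }

    std::printf("steps=%d eps=%f algorithm=%d alg_temp=%f visual=%d\n",
                diffusion.steps, (double) diffusion.eps, diffusion.algorithm,
                (double) diffusion.alg_temp, (int) diffusion.visual_mode);
}
```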
@@ -268,6 +278,7 @@ struct common_params {
     struct common_params_sampling    sampling;
     struct common_params_speculative speculative;
     struct common_params_vocoder     vocoder;
+    struct common_params_diffusion   diffusion;

     struct common_params_model model;
@@ -330,6 +341,7 @@
     bool no_perf   = false; // disable performance metrics
     bool ctx_shift = true;  // context shift on infinite text generation
     bool swa_full  = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
+    bool kv_unified = false; // enable unified KV cache

     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool use_mmap = true;          // use mmap for faster loads
@@ -370,6 +382,7 @@
     std::string hostname      = "127.0.0.1";
     std::string public_path   = ""; // NOLINT
+    std::string api_prefix    = ""; // NOLINT
     std::string chat_template = ""; // NOLINT
     bool use_jinja = false; // NOLINT
     bool enable_chat_template = true;
@@ -419,9 +432,10 @@
     int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
     int32_t i_chunk     = 0; // start processing from this chunk

-    bool process_output = false; // collect data for the output tensor
-    bool compute_ppl    = true;  // whether to compute perplexity
-    bool parse_special  = false; // whether to parse special tokens during imatrix tokenization
+    bool process_output  = false; // collect data for the output tensor
+    bool compute_ppl     = true;  // whether to compute perplexity
+    bool show_statistics = false; // show imatrix statistics per tensor
+    bool parse_special   = false; // whether to parse special tokens during imatrix tokenization

     // cvector-generator params
     int n_pca_batch = 100;
@@ -521,6 +535,7 @@ static bool string_starts_with(const std::string & str,

 // While we wait for C++20's std::string::ends_with...
 bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
+bool string_remove_suffix(std::string & str, const std::string_view & suffix);
 size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);

 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
(One file diff suppressed because it is too large.)
convert_hf_to_gguf_update.py

@@ -7,7 +7,6 @@ import pathlib
 import re

-import requests
 import sys
 import json
 import shutil
 import argparse
@@ -69,8 +68,7 @@ args = parser.parse_args()
 hf_token = args.hf_token if args.hf_token is not None else hf_token

 if hf_token is None:
-    logger.error("HF token is required. Please provide it as an argument or set it in ~/.cache/huggingface/token")
-    sys.exit(1)
+    logger.warning("HF token not found. You can provide it as an argument or set it in ~/.cache/huggingface/token")

 # TODO: this string has to exercise as much pre-tokenizer functionality as possible
 # will be updated with time - contributions welcome
@@ -128,6 +126,10 @@ models = [
     {"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
     {"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", },
     {"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", },
+    {"name": "a.x-4.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/skt/A.X-4.0", },
+    {"name": "midm-2.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/K-intelligence/Midm-2.0-Base-Instruct", },
+    {"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
+    {"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
 ]

 # some models are known to be broken upstream, so we will skip them as exceptions
@@ -137,11 +139,18 @@ pre_computed_hashes = [
     {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
     {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
     {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
+    {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
+    # falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base", "chkhsh": "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6"},
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-1B-Base", "chkhsh": "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86"},
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-7B-Base", "chkhsh": "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896"},
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
+    {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
 ]


 def download_file_with_auth(url, token, save_path):
-    headers = {"Authorization": f"Bearer {token}"}
+    headers = {"Authorization": f"Bearer {token}"} if token else None
     response = sess.get(url, headers=headers)
     response.raise_for_status()
     os.makedirs(os.path.dirname(save_path), exist_ok=True)
@@ -222,7 +231,7 @@ for model in models:
 # generate the source code for the convert_hf_to_gguf.py:get_vocab_base_pre() function:

 src_ifs = ""
-for model in [*all_models, *pre_computed_hashes]:
+for model in [*pre_computed_hashes, *all_models]:
     name = model["name"]
     tokt = model["tokt"]
     chkhsh = model.get("chkhsh")
@@ -230,11 +239,6 @@ for model in [*all_models, *pre_computed_hashes]:
     if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM:
         continue

-    # Skip if the tokenizer folder does not exist or there are other download issues previously
-    if not os.path.exists(f"models/tokenizers/{name}"):
-        logger.warning(f"Directory for tokenizer {name} not found. Skipping...")
-        continue
-
     # create the tokenizer
     if chkhsh is not None:
         # if the model has a pre-computed hash, use it
@@ -244,15 +248,19 @@
         chkhsh = existing_models[name]
     else:
         # otherwise, compute the hash of the tokenizer
+
+        # Fail if the tokenizer folder with config does not exist or there are other download issues previously
+        if not os.path.isfile(f"models/tokenizers/{name}/tokenizer_config.json"):
+            raise OSError(f"Config for tokenizer {name} not found. The model may not exist or is not accessible with the provided token.")
+
         try:
             logger.info(f"Loading tokenizer from {f'models/tokenizers/{name}'}...")
             if name == "t5":
                 tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
             else:
                 tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
         except OSError as e:
             logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}")
             continue  # Skip to the next model if the tokenizer can't be loaded
+        except Exception as e:
+            raise OSError(f"Error loading tokenizer for model {name}.") from e

         chktok = tokenizer.encode(CHK_TXT)
         chkhsh = sha256(str(chktok).encode()).hexdigest()
docs/build-s390x.md

@@ -42,14 +42,14 @@ cmake --build build --config Release -j $(nproc)
     cmake --build build --config Release -j $(nproc)
     ```

-- By default, NNPA is enabled when available. To disable it (not recommended):
+- NNPA is disabled by default. To enable it:

     ```bash
     cmake -S . -B build \
         -DCMAKE_BUILD_TYPE=Release \
         -DGGML_BLAS=ON \
         -DGGML_BLAS_VENDOR=OpenBLAS \
-        -DGGML_NNPA=OFF
+        -DGGML_NNPA=ON

     cmake --build build --config Release -j $(nproc)
     ```
@@ -84,9 +84,9 @@ All models need to be converted to Big-Endian. You can achieve this in three cas

-   You can find popular models pre-converted and verified at [s390x Ready Models](https://huggingface.co/collections/taronaeo/s390x-ready-models-672765393af438d0ccb72a08).
+   You can find popular models pre-converted and verified at [s390x Verified Models](https://huggingface.co/collections/taronaeo/s390x-verified-models-672765393af438d0ccb72a08) or [s390x Runnable Models](https://huggingface.co/collections/taronaeo/s390x-runnable-models-686e951824198df12416017e).

-   These models have already been converted from `safetensors` to `GGUF Big-Endian` and their respective tokenizers verified to run correctly on IBM z15 and later systems.
+   These models have already been converted from `safetensors` to `GGUF` Big-Endian and their respective tokenizers verified to run correctly on IBM z15 and later systems.

 2. **Convert safetensors model to GGUF Big-Endian directly (recommended)**
@@ -94,6 +94,14 @@
    The model you are trying to convert must be in `safetensors` file format (for example [IBM Granite 3.3 2B](https://huggingface.co/ibm-granite/granite-3.3-2b-instruct)). Make sure you have downloaded the model repository for this case.

+   Ensure that you have installed the required packages in advance:
+
+   ```bash
+   pip3 install -r requirements.txt
+   ```
+
    Convert the `safetensors` model to `GGUF`:

    ```bash
    python3 convert_hf_to_gguf.py \
        --outfile model-name-be.f16.gguf \
@@ -116,7 +124,7 @@

-   The model you are trying to convert must be in `gguf` file format (for example [IBM Granite 3.3 2B](https://huggingface.co/ibm-granite/granite-3.3-2b-instruct-GGUF)). Make sure you have downloaded the model file for this case.
+   The model you are trying to convert must be in `gguf` file format (for example [IBM Granite 3.3 2B GGUF](https://huggingface.co/ibm-granite/granite-3.3-2b-instruct-GGUF)). Make sure you have downloaded the model file for this case.

    ```bash
    python3 gguf-py/gguf/scripts/gguf_convert_endian.py model-name.f16.gguf BIG
@@ -141,15 +149,15 @@ Only available in IBM z15 or later system with the `-DGGML_VXE=ON` (turned on by

 ### 2. NNPA Vector Intrinsics Acceleration

-Only available in IBM z16 or later system with the `-DGGML_NNPA=ON` (turned on when available) compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z15/arch13. In such systems, the APIs can still run but will use a scalar implementation.
+Only available in IBM z16 or later system with the `-DGGML_NNPA=ON` (turned off by default) compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z15/arch13. In such systems, the APIs can still run but will use a scalar implementation.

 ### 3. zDNN Accelerator

-_Only available in IBM z16 or later system. No direction at the moment._
+_Only available in IBM z16 / LinuxONE 4 or later system. No support currently available._

 ### 4. Spyre Accelerator

-_No direction at the moment._
+_Only available with IBM z17 / LinuxONE 5 or later system. No support currently available._

 ## Performance Tuning
@@ -189,6 +197,26 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl

    Answer: Please ensure that your GCC compiler is of minimum GCC 15.1.0 version, and have `binutils` updated to the latest version. If this does not fix the problem, kindly open an issue.

+4. Failing to install the `sentencepiece` package using GCC 15+
+
+   Answer: The `sentencepiece` team is aware of this, as seen in [this issue](https://github.com/google/sentencepiece/issues/1108).
+
+   As a temporary workaround, please run the installation command with the following environment variables.
+
+   ```bash
+   export CXXFLAGS="-include cstdint"
+   ```
+
+   For example,
+
+   ```bash
+   CXXFLAGS="-include cstdint" pip3 install -r requirements.txt
+   ```
+
+5. `-DGGML_NNPA=ON` generates gibberish output
+
+   Answer: We are aware of this as detailed in [this issue](https://github.com/ggml-org/llama.cpp/issues/14877). Please either try reducing the number of threads, or disable the compile option using `-DGGML_NNPA=OFF`.
+
 ## Getting Help on IBM Z & LinuxONE

 1. **Bugs, Feature Requests**
@@ -244,3 +272,5 @@
 - ✅ - acceleration available
 - 🚫 - acceleration unavailable, will still run using scalar implementation
 - ❓ - acceleration unknown, please contribute if you can test it yourself
+
+Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on July 25, 2025.
@@ -68,6 +68,9 @@ cmake --build build --config Release
|
||||
cmake --build build-x64-windows-llvm-release
|
||||
```
|
||||
- Curl usage is enabled by default and can be turned off with `-DLLAMA_CURL=OFF`. Otherwise you need to install development libraries for libcurl.
|
||||
- **Debian / Ubuntu:** `sudo apt-get install libcurl4-openssl-dev` # (or `libcurl4-gnutls-dev` if you prefer GnuTLS)
|
||||
- **Fedora / RHEL / Rocky / Alma:** `sudo dnf install libcurl-devel`
|
||||
- **Arch / Manjaro:** `sudo pacman -S curl` # includes libcurl headers
|
||||
|
||||
## BLAS Build
|
||||
|
||||
@@ -305,9 +308,8 @@ On Linux it is possible to use unified memory architecture (UMA) to share main m
|
||||
|
||||
## Vulkan
|
||||
|
||||
**Windows**
|
||||
|
||||
### w64devkit
|
||||
### For Windows Users:
|
||||
**w64devkit**
|
||||
|
||||
Download and extract [`w64devkit`](https://github.com/skeeto/w64devkit/releases).
|
||||
|
||||
@@ -334,7 +336,7 @@ cmake -B build -DGGML_VULKAN=ON
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
### Git Bash MINGW64
|
||||
**Git Bash MINGW64**
|
||||
|
||||
Download and install [`Git-SCM`](https://git-scm.com/downloads/win) with the default settings.

@@ -357,7 +359,8 @@ Now you can load the model in conversation mode using `Vulkan`
build/bin/Release/llama-cli -m "[PATH TO MODEL]" -ngl 100 -c 16384 -t 10 -n -2 -cnv
```
**MSYS2**
Install [MSYS2](https://www.msys2.org/) and then run the following commands in a UCRT terminal to install dependencies.

```sh
pacman -S git \
@@ -373,9 +376,9 @@ cmake -B build -DGGML_VULKAN=ON
cmake --build build --config Release
```
### For Docker users:

You don't need to install the Vulkan SDK. It will be installed inside the container.
```sh
# Build the image
@@ -385,32 +388,29 @@ docker build -t llama-cpp-vulkan --target light -f .devops/vulkan.Dockerfile .
docker run -it --rm -v "$(pwd):/app:Z" --device /dev/dri/renderD128:/dev/dri/renderD128 --device /dev/dri/card1:/dev/dri/card1 llama-cpp-vulkan -m "/app/models/YOUR_MODEL_FILE" -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33
```
### For Linux users:

First, follow the official LunarG instructions for the installation and setup of the Vulkan SDK in the [Getting Started with the Linux Tarball Vulkan SDK](https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html) guide.

> [!IMPORTANT]
> After completing the first step, ensure that you have used the `source` command on the `setup_env.sh` file inside of the Vulkan SDK in your current terminal session. Otherwise, the build won't work. Additionally, if you close out of your terminal, you must perform this step again if you intend to perform a build. However, there are ways to make this persistent. Refer to the Vulkan SDK guide linked in the first step for more information about any of this.

Second, after verifying that you have followed all of the SDK installation/setup steps, use this command to make sure before proceeding:
```bash
vulkaninfo
```

Then, assuming you have `cd` into your llama.cpp folder and there are no errors with running `vulkaninfo`, you can proceed to build llama.cpp using the CMake commands below:
```bash
cmake -B build -DGGML_VULKAN=1
cmake --build build --config Release
```
Finally, after finishing your build, you should be able to do something like this:
```bash
# Test the output binary
# "-ngl 99" should offload all of the layers to GPU for most (if not all) models.
./build/bin/llama-cli -m "PATH_TO_MODEL" -p "Hi you how are you" -ngl 99

# You should see in the output, ggml_vulkan detected your GPU. For example:
# ggml_vulkan: Using Intel(R) Graphics (ADL GT2) | uma: 1 | fp16: 1 | warp size: 32
```
@@ -557,6 +557,23 @@ ninja

To read documentation for how to build on Android, [click here](./android.md)
## WebGPU [In Progress]

The WebGPU backend relies on [Dawn](https://dawn.googlesource.com/dawn). Follow the instructions [here](https://dawn.googlesource.com/dawn/+/refs/heads/main/docs/quickstart-cmake.md) to install Dawn locally so that llama.cpp can find it using CMake. The current implementation is up to date with Dawn commit `bed1a61`.

In the llama.cpp directory, build with CMake:

```
cmake -B build -DGGML_WEBGPU=ON
cmake --build build --config Release
```
### Browser Support

WebGPU allows cross-platform access to the GPU from supported browsers. We utilize [Emscripten](https://emscripten.org/) to compile ggml's WebGPU backend to WebAssembly. Emscripten does not officially support WebGPU bindings yet, but Dawn currently maintains its own WebGPU bindings called emdawnwebgpu.

Follow the instructions [here](https://dawn.googlesource.com/dawn/+/refs/heads/main/src/emdawnwebgpu/) to download or build the emdawnwebgpu package (note that it might be safer to build the emdawnwebgpu package locally, so that it stays in sync with the version of Dawn you have installed above). When building with CMake, the path to the emdawnwebgpu port file needs to be set with the flag `EMDAWNWEBGPU_DIR`.
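Putting the two flags together, an Emscripten configure step might look like the sketch below; the path to the emdawnwebgpu package is a placeholder and must point at the port you downloaded or built:

```bash
# hypothetical package location; substitute your own emdawnwebgpu checkout
emcmake cmake -B build-web -DGGML_WEBGPU=ON \
    -DEMDAWNWEBGPU_DIR=/path/to/emdawnwebgpu_pkg
cmake --build build-web --config Release
```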
## IBM Z & LinuxONE

To read documentation for how to build on IBM Z & LinuxONE, [click here](./build-s390x.md)
@@ -23,11 +23,19 @@ The convert script reads the model configuration, tokenizer, tensor names+data a

The required steps to implement for an HF model are:

1. Define the model `ModelBase.register` annotation in a new `TextModel` or `MmprojModel` subclass, example:
```python
@ModelBase.register("MyModelForCausalLM")
class MyModel(TextModel):
    model_arch = gguf.MODEL_ARCH.MYMODEL
```

or

```python
@ModelBase.register("MyModelForConditionalGeneration")
class MyModel(MmprojModel):
    model_arch = gguf.MODEL_ARCH.MYMODEL
```
@@ -75,28 +83,31 @@ block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
`transformer.blocks.{bid}.norm_1` will be mapped to `blk.{bid}.attn_norm` in GGUF.

Depending on the model configuration, tokenizer, code and tensors layout, you will have to override:
- `TextModel#set_gguf_parameters`
- `MmprojModel#set_gguf_parameters`
- `ModelBase#set_vocab`
- `ModelBase#modify_tensors`

NOTE: Tensor names must end with the `.weight` or `.bias` suffix; that is the convention, and several tools like `quantize` rely on it to locate the weights. A minimal sketch of a `modify_tensors` override is shown below.
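For illustration only, a `modify_tensors` override typically renames or splits tensors before they are written; the fused tensor name below is hypothetical and not taken from any particular model:

```python
# sketch: split a hypothetical fused QKV projection into the tensors GGUF expects
def modify_tensors(self, data_torch, name, bid):
    if name.endswith("attn.Wqkv.weight"):
        q, k, v = data_torch.chunk(3, dim=0)
        return [
            (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q),
            (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k),
            (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v),
        ]
    # default: map the HF tensor name to its GGUF equivalent
    return [(self.map_tensor_name(name), data_torch)]
```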
### 2. Define the model architecture in `llama.cpp`

The model params and tensors layout must be defined in `llama.cpp` source files:
1. Define a new `llm_arch` enum value in `src/llama-arch.h`.
2. In `src/llama-arch.cpp`:
    - Add the architecture name to the `LLM_ARCH_NAMES` map.
    - Add the tensor mappings to the `LLM_TENSOR_NAMES` map.
3. Add any non-standard metadata loading in the `llama_model_loader` constructor in `src/llama-model-loader.cpp`.
4. If the model has a RoPE operation, add a case for the architecture in the `llama_model_rope_type` function in `src/llama-model.cpp`.

A sketch of the first two steps is shown below.

NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorch` dimensions.
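As a rough sketch (the enum value, string names, and exact map types are illustrative, not copied from the tree):

```cpp
// src/llama-arch.h (sketch)
enum llm_arch {
    // ...
    LLM_ARCH_MYMODEL,   // hypothetical new architecture
    LLM_ARCH_UNKNOWN,
};

// src/llama-arch.cpp (sketch)
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    // ...
    { LLM_ARCH_MYMODEL, "mymodel" },
};

static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
    // ...
    { LLM_ARCH_MYMODEL, {
        { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
        { LLM_TENSOR_ATTN_NORM,  "blk.%d.attn_norm" },
        // ... remaining per-layer tensors ...
    } },
};
```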
### 3. Build the GGML graph implementation

This is the most fun part: you have to provide the inference graph implementation of the new model architecture in `src/llama-model.cpp`.
Create a new struct that inherits from `llm_graph_context` and implement the graph-building logic in its constructor.
Have a look at existing implementations like `llm_build_llama`, `llm_build_dbrx` or `llm_build_bert`.
Then, in the `llama_model::build_graph` method, add a case for your architecture to instantiate your new graph-building struct.

Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR.
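The skeleton of such a struct might look like the sketch below; this assumes a llama-style decoder, and the helper names follow the existing builders only loosely, so treat it as a shape rather than a drop-in implementation:

```cpp
// sketch: src/llama-model.cpp
struct llm_build_mymodel : public llm_graph_context {
    llm_build_mymodel(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf)
        : llm_graph_context(params) {
        ggml_tensor * cur = build_inp_embd(model.tok_embd);
        // ... attention + FFN blocks for each layer ...
        cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
        cur = build_lora_mm(model.output, cur);  // final projection to logits
        res->t_logits = cur;
        ggml_build_forward_expand(gf, cur);
    }
};
```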
@@ -110,7 +110,7 @@ You may want to pass in some different `ARGS`, depending on the MUSA environment

The defaults are:

- `MUSA_VERSION` set to `rc4.2.0`

The resulting images are essentially the same as the non-MUSA images:
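To override the default at build time, a `--build-arg` invocation along these lines should work; the Dockerfile path is an assumption based on the repository's `.devops` layout:

```sh
docker build -t llama-cpp-musa \
    --build-arg MUSA_VERSION=rc4.2.0 \
    -f .devops/musa.Dockerfile .
```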
@@ -97,6 +97,9 @@ NOTE: some models may require large context window, for example: `-c 8192`
# Qwen2-Audio and SeaLLM-Audio
# note: no pre-quantized GGUF for this model, as the results are very poor
# ref: https://github.com/ggml-org/llama.cpp/pull/13760

# Mistral's Voxtral
(tool_name) -hf ggml-org/Voxtral-Mini-3B-2507-GGUF
```

**Mixed modalities**:
102 docs/ops.md (new file)
@@ -0,0 +1,102 @@
# GGML Operations

List of GGML operations and backend support status.

## How to add a backend to this table:

1. Run `test-backend-ops support --output csv` with your backend name and redirect output to a csv file in `docs/ops/` (e.g., `docs/ops/CUDA.csv`)
2. Regenerate `/docs/ops.md` via `./scripts/create_ops_docs.py`

A typical invocation is sketched after the legend below.

Legend:
- ✅ Fully supported by this backend
- 🟡 Partially supported by this backend
- ❌ Not supported by this backend
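For example, refreshing a CUDA column could look like this; the binary path assumes a default CMake build, and depending on your build you may additionally need to select the backend (e.g., via a flag or environment variable):

```bash
./build/bin/test-backend-ops support --output csv > docs/ops/CUDA.csv
./scripts/create_ops_docs.py
```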
| Operation | BLAS | CPU | CUDA | Metal | SYCL | Vulkan |
|-----------|------|------|------|-------|------|--------|
| ABS | ❌ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| ADD | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ |
| ADD1 | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ |
| ARANGE | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| CLAMP | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 |
| CONCAT | ❌ | ✅ | 🟡 | ✅ | 🟡 | ✅ |
| CONT | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 |
| CONV_2D | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ |
| CONV_2D_DW | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| CONV_TRANSPOSE_2D | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ |
| DIV | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ |
| DUP | ❌ | ✅ | 🟡 | 🟡 | ✅ | 🟡 |
| ELU | ❌ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
| EXP | ❌ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
| FLASH_ATTN_EXT | ❌ | ✅ | 🟡 | 🟡 | ❌ | 🟡 |
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 |
| GEGLU_ERF | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 |
| GEGLU_QUICK | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 |
| GELU | ❌ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 |
| GELU_ERF | ❌ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 |
| GELU_QUICK | ❌ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 |
| GET_ROWS | ❌ | ✅ | 🟡 | ✅ | 🟡 | 🟡 |
| GET_ROWS_BACK | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| HARDSIGMOID | ❌ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
| HARDSWISH | ❌ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
| IM2COL | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ |
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| LOG | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| MUL | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| MUL_MAT_ID | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ |
| NEG | ❌ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
| NORM | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 |
| OPT_STEP_ADAMW | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ |
| OUT_PROD | 🟡 | 🟡 | 🟡 | ❌ | 🟡 | ❌ |
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| PAD_REFLECT_1D | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
| POOL_2D | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| REGLU | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 |
| RELU | ❌ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 |
| REPEAT | ❌ | ✅ | 🟡 | ✅ | ✅ | 🟡 |
| REPEAT_BACK | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ |
| RMS_NORM | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ |
| RMS_NORM_BACK | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ |
| RMS_NORM_MUL_ADD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| ROLL | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ |
| ROPE | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| ROPE_BACK | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ |
| RWKV_WKV6 | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| RWKV_WKV7 | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| SCALE | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| SET | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
| SET_ROWS | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| SGN | ❌ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
| SIGMOID | ❌ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 |
| SILU | ❌ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 |
| SILU_BACK | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ |
| SIN | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 |
| SOFT_MAX | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ |
| SOFT_MAX_BACK | ❌ | 🟡 | 🟡 | ❌ | ❌ | ✅ |
| SQR | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 |
| SQRT | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ |
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SSM_SCAN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| STEP | ❌ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
| SUB | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ |
| SUM | ❌ | ✅ | ✅ | ❌ | ✅ | ✅ |
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| SWIGLU | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 |
| TANH | ❌ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| UPSCALE | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ |
8133 docs/ops/BLAS.csv (new file; diff suppressed because it is too large)
7349 docs/ops/CPU.csv (new file; diff suppressed because it is too large)
7349 docs/ops/CUDA.csv (new file; diff suppressed because it is too large)
8133 docs/ops/Metal.csv (new file; diff suppressed because it is too large)
8133 docs/ops/SYCL.csv (new file; diff suppressed because it is too large)
8133 docs/ops/Vulkan.csv (new file; diff suppressed because it is too large)
@@ -33,6 +33,7 @@ else()
    add_subdirectory(speculative-simple)
    add_subdirectory(gen-docs)
    add_subdirectory(training)
    add_subdirectory(diffusion)
    if (NOT GGML_BACKEND_DL)
        # these examples use the backends directly and cannot be built with dynamic loading
        add_subdirectory(convert-llama2c-to-ggml)
5 examples/diffusion/CMakeLists.txt (new file)
@@ -0,0 +1,5 @@
set(TARGET llama-diffusion-cli)
add_executable(${TARGET} diffusion-cli.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
507 examples/diffusion/diffusion-cli.cpp (new file)
@@ -0,0 +1,507 @@
#include "arg.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <limits.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
|
||||
typedef bool (*diffusion_step_callback_t)(int32_t step,
|
||||
int32_t total_steps,
|
||||
const llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
void * user_data);
|
||||
|
||||
enum diffusion_alg {
|
||||
DIFFUSION_ALG_ORIGIN = 0,
|
||||
DIFFUSION_ALG_MASKGIT_PLUS = 1,
|
||||
DIFFUSION_ALG_TOPK_MARGIN = 2,
|
||||
DIFFUSION_ALG_ENTROPY = 3,
|
||||
};
|
||||
|
||||
struct diffusion_params {
|
||||
int32_t steps;
|
||||
float eps;
|
||||
float temperature;
|
||||
float top_p;
|
||||
int32_t top_k;
|
||||
llama_token mask_token_id;
|
||||
enum diffusion_alg algorithm;
|
||||
float alg_temp;
|
||||
diffusion_step_callback_t step_callback;
|
||||
void * step_callback_user_data;
|
||||
int32_t seed;
|
||||
};
|
||||
|
||||
|
||||
static diffusion_params diffusion_default_params() {
|
||||
diffusion_params params = {};
|
||||
params.steps = 64;
|
||||
params.eps = 1e-3f;
|
||||
params.temperature = 0.2f;
|
||||
params.top_p = 0.95f;
|
||||
params.top_k = 0;
|
||||
params.mask_token_id = LLAMA_TOKEN_NULL;
|
||||
params.algorithm = DIFFUSION_ALG_ORIGIN;
|
||||
params.alg_temp = 0.0f;
|
||||
params.step_callback = nullptr;
|
||||
params.step_callback_user_data = nullptr;
|
||||
params.seed = 0;
|
||||
return params;
|
||||
}
|
||||
|
||||
static void diffusion_generate(llama_context * ctx,
|
||||
const llama_token * input_tokens,
|
||||
llama_token * output_tokens,
|
||||
int32_t n_input,
|
||||
int32_t max_length,
|
||||
struct diffusion_params params,
|
||||
int32_t & n_generated) {
|
||||
|
||||
n_generated = 0;
|
||||
if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || max_length <= n_input) {
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
|
||||
// Initialize with input and pad with mask tokens
|
||||
std::copy(input_tokens, input_tokens + n_input, output_tokens);
|
||||
std::fill(output_tokens + n_input, output_tokens + max_length, params.mask_token_id);
|
||||
|
||||
std::mt19937 rng(params.seed);
|
||||
|
||||
std::vector<float> timesteps(params.steps + 1);
|
||||
for (int32_t i = 0; i <= params.steps; i++) {
|
||||
timesteps[i] = 1.0f - (float) i / params.steps * (1.0f - params.eps);
|
||||
}
|
||||
|
||||
llama_set_causal_attn(ctx, false);
|
||||
|
||||
int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
|
||||
|
||||
std::vector<llama_token_data> candidates(n_vocab);
|
||||
|
||||
std::vector<llama_token_data> conf_candidates;
|
||||
conf_candidates.reserve(max_length);
|
||||
|
||||
std::vector<int32_t> mask_positions;
|
||||
mask_positions.reserve(max_length);
|
||||
|
||||
struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
|
||||
if (params.top_k > 0) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
|
||||
}
|
||||
if (params.top_p < 1.0f) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
|
||||
}
|
||||
if (params.temperature > 0.0f) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
|
||||
}
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));
|
||||
|
||||
struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);
|
||||
|
||||
llama_batch batch = llama_batch_init(max_length, 0, 1);
|
||||
batch.n_tokens = max_length;
|
||||
|
||||
int64_t total_sampling_time = 0;
|
||||
int64_t total_time = 0;
|
||||
|
||||
int64_t time_start = ggml_time_us();
|
||||
for (int32_t step = 0; step < params.steps; step++) {
|
||||
if (params.step_callback) {
|
||||
if (!params.step_callback(step, params.steps, output_tokens, max_length, params.step_callback_user_data)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i < max_length; i++) {
|
||||
batch.token[i] = output_tokens[i];
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = 0;
|
||||
batch.logits[i] = 1;
|
||||
}
|
||||
|
||||
int ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, step, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
float * raw_logits = llama_get_logits(ctx);
|
||||
if (!raw_logits) {
|
||||
LOG_ERR("%s: failed to get logits at step %d\n", __func__, step);
|
||||
break;
|
||||
}
|
||||
|
||||
auto get_logits_for_pos = [&](int32_t pos) -> const float * {
|
||||
return pos == 0 ? raw_logits : raw_logits + (pos - 1) * n_vocab;
|
||||
};
|
||||
|
||||
int64_t time_start_sampling = ggml_time_us();
|
||||
|
||||
mask_positions.clear();
|
||||
for (int32_t i = 0; i < max_length; i++) {
|
||||
if (output_tokens[i] == params.mask_token_id) {
|
||||
mask_positions.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
if (mask_positions.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
float t = timesteps[step];
|
||||
float s = timesteps[step + 1];
|
||||
|
||||
if (params.algorithm == DIFFUSION_ALG_ORIGIN) {
|
||||
float p_transfer = (step < params.steps - 1) ? (1.0f - s / t) : 1.0f;
|
||||
|
||||
for (int32_t pos : mask_positions) {
|
||||
if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
|
||||
const float * pos_logits = get_logits_for_pos(pos);
|
||||
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates[token_id].id = token_id;
|
||||
candidates[token_id].logit = pos_logits[token_id];
|
||||
candidates[token_id].p = 0.0f;
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = {
|
||||
/* .data = */ candidates.data(),
|
||||
/* .size = */ (size_t) n_vocab, // Reset size to full vocab
|
||||
/* .selected = */ -1,
|
||||
/* .sorted = */ false,
|
||||
};
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
output_tokens[pos] = cur_p.data[cur_p.selected].id;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::vector<std::pair<float, int32_t>> confidences;
|
||||
std::vector<llama_token> sampled_tokens(mask_positions.size());
|
||||
|
||||
for (size_t i = 0; i < mask_positions.size(); i++) {
|
||||
int32_t pos = mask_positions[i];
|
||||
const float * pos_logits = get_logits_for_pos(pos);
|
||||
|
||||
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates[token_id].logit = pos_logits[token_id];
|
||||
candidates[token_id].p = 0.0f;
|
||||
candidates[token_id].id = token_id;
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = {
|
||||
/* .data = */ candidates.data(),
|
||||
/* .size = */ candidates.size(),
|
||||
/* .selected = */ -1,
|
||||
/* .sorted = */ false,
|
||||
};
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
|
||||
llama_token sampled_token = cur_p.data[cur_p.selected].id;
|
||||
|
||||
float confidence = 0.0f;
|
||||
if (params.algorithm == DIFFUSION_ALG_ENTROPY) {
|
||||
const float epsilon = 1e-10f;
|
||||
for (size_t j = 0; j < cur_p.size; j++) {
|
||||
float prob = cur_p.data[j].p;
|
||||
confidence += prob * logf(prob + epsilon);
|
||||
}
|
||||
} else if (params.algorithm == DIFFUSION_ALG_TOPK_MARGIN) {
|
||||
confidence = cur_p.data[0].p - cur_p.data[1].p;
|
||||
} else {
|
||||
confidence = cur_p.data[cur_p.selected].p;
|
||||
}
|
||||
|
||||
sampled_tokens[i] = sampled_token;
|
||||
confidences.emplace_back(confidence, i);
|
||||
}
|
||||
|
||||
int32_t num_transfer =
|
||||
(step < params.steps - 1) ? (int32_t) (mask_positions.size() * (1.0f - s / t)) : mask_positions.size();
|
||||
|
||||
if (num_transfer > 0) {
|
||||
if (params.alg_temp == 0.0f) {
|
||||
std::partial_sort(confidences.begin(), confidences.begin() + num_transfer, confidences.end(),
|
||||
[](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
|
||||
if (a.first != b.first) {
|
||||
return a.first > b.first;
|
||||
}
|
||||
return a.second < b.second;
|
||||
});
|
||||
} else {
|
||||
conf_candidates.clear();
|
||||
|
||||
for (int32_t pos = 0; pos < max_length; pos++) {
|
||||
float conf_logit = -std::numeric_limits<float>::infinity();
|
||||
|
||||
auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
|
||||
if (it != mask_positions.end()) {
|
||||
size_t mask_idx = std::distance(mask_positions.begin(), it);
|
||||
conf_logit = confidences[mask_idx].first / params.alg_temp; // Apply temperature scaling
|
||||
}
|
||||
|
||||
conf_candidates.emplace_back(llama_token_data{ pos, conf_logit, 0.0f });
|
||||
}
|
||||
|
||||
llama_token_data_array conf_array = {
|
||||
/* .data = */ conf_candidates.data(),
|
||||
/* .size = */ conf_candidates.size(),
|
||||
/* .selected = */ -1,
|
||||
/* .sorted = */ false,
|
||||
};
|
||||
|
||||
for (int32_t i = 0; i < num_transfer; i++) {
|
||||
// Apply distribution sampler to get selected index
|
||||
llama_sampler_apply(dist_sampler, &conf_array);
|
||||
int selected_idx = conf_array.selected;
|
||||
confidences[i].second = conf_candidates[selected_idx].id;
|
||||
|
||||
conf_candidates[selected_idx].p = 0.0f;
|
||||
conf_array.selected = -1;
|
||||
}
|
||||
}
|
||||
|
||||
if (params.alg_temp == 0.0f) {
|
||||
// Deterministic - use confidence order
|
||||
for (int32_t i = 0; i < num_transfer; i++) {
|
||||
int32_t mask_idx = confidences[i].second;
|
||||
int32_t pos = mask_positions[mask_idx];
|
||||
llama_token token = sampled_tokens[mask_idx];
|
||||
output_tokens[pos] = token;
|
||||
}
|
||||
} else {
|
||||
for (int32_t i = 0; i < num_transfer; i++) {
|
||||
int32_t pos = confidences[i].second;
|
||||
auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
|
||||
if (it != mask_positions.end()) {
|
||||
int32_t mask_idx = std::distance(mask_positions.begin(), it);
|
||||
output_tokens[pos] = sampled_tokens[mask_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
int64_t time_end_sampling = ggml_time_us();
|
||||
total_sampling_time += time_end_sampling - time_start_sampling;
|
||||
}
|
||||
int64_t time_end = ggml_time_us();
|
||||
total_time += time_end - time_start;
|
||||
|
||||
LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
|
||||
total_time / 1000.0, total_time / 1000.0 / params.steps, total_sampling_time / 1000.0 / params.steps);
|
||||
|
||||
|
||||
llama_batch_free(batch);
|
||||
llama_sampler_free(sampler);
|
||||
llama_sampler_free(dist_sampler);
|
||||
|
||||
n_generated = max_length;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
|
||||
if (!use_chat_template) {
|
||||
return prompt;
|
||||
}
|
||||
|
||||
auto chat_templates = common_chat_templates_init(model, "");
|
||||
|
||||
common_chat_templates_inputs inputs;
|
||||
common_chat_msg user_msg;
|
||||
user_msg.role = "user";
|
||||
user_msg.content = prompt;
|
||||
inputs.add_generation_prompt = true;
|
||||
inputs.messages.push_back(user_msg);
|
||||
|
||||
auto result = common_chat_templates_apply(chat_templates.get(), inputs);
|
||||
|
||||
return result.prompt;
|
||||
}
|
||||
|
||||
struct callback_data {
|
||||
const common_params_diffusion * diff_params;
|
||||
const llama_vocab * vocab;
|
||||
int32_t n_input;
|
||||
};
|
||||
|
||||
static bool diffusion_step_callback(int32_t step,
|
||||
int32_t total_steps,
|
||||
const llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
void * user_data) {
|
||||
(void)user_data;
|
||||
|
||||
callback_data * data = static_cast<callback_data *>(user_data);
|
||||
|
||||
auto print_progress_bar = [](int32_t step, int32_t total_steps) {
|
||||
int progress_percent = (step * 100) / total_steps;
|
||||
int progress_bars = (step * 50) / total_steps;
|
||||
LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
|
||||
step,
|
||||
total_steps,
|
||||
std::string(progress_bars, '=').c_str(),
|
||||
std::string(50 - progress_bars, ' ').c_str(),
|
||||
progress_percent);
|
||||
};
|
||||
|
||||
if (data->diff_params->visual_mode) {
|
||||
// Visual mode: clear
|
||||
LOG_INF("\033[2J\033[H"); // Clear screen and move cursor to top-left
|
||||
|
||||
print_progress_bar(step, total_steps);
|
||||
|
||||
LOG_INF("\n");
|
||||
|
||||
std::string current_text = " ";
|
||||
|
||||
for (int32_t i = data->n_input; i < n_tokens; i++) {
|
||||
std::string token_str;
|
||||
if (tokens[i] != llama_vocab_mask(data->vocab)) {
|
||||
char piece[256];
|
||||
int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
|
||||
if (n_chars > 0) {
|
||||
piece[n_chars] = '\0';
|
||||
token_str = piece;
|
||||
}
|
||||
} else {
|
||||
token_str = " ";
|
||||
}
|
||||
|
||||
current_text += token_str;
|
||||
}
|
||||
|
||||
LOG_INF("%s\n", current_text.c_str());
|
||||
} else {
|
||||
print_progress_bar(step, total_steps);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_time_init();
|
||||
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
const char * alg_names[] = { "ORIGIN", "MASKGIT_PLUS", "TOPK_MARGIN", "ENTROPY" };
|
||||
const char * alg_name = (params.diffusion.algorithm >= 0 && params.diffusion.algorithm <= 3) ?
|
||||
alg_names[params.diffusion.algorithm] :
|
||||
"UNKNOWN";
|
||||
|
||||
common_init();
|
||||
llama_backend_init();
|
||||
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
model_params.n_gpu_layers = params.n_gpu_layers;
|
||||
model_params.devices = params.devices.data();
|
||||
model_params.use_mmap = params.use_mmap;
|
||||
model_params.use_mlock = params.use_mlock;
|
||||
model_params.check_tensors = params.check_tensors;
|
||||
|
||||
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
|
||||
if (!model) {
|
||||
LOG_ERR("error: failed to load model '%s'\n", params.model.path.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_context_params ctx_params = llama_context_default_params();
|
||||
ctx_params.n_ctx = params.n_ctx;
|
||||
ctx_params.n_batch = params.n_batch;
|
||||
ctx_params.n_ubatch = params.n_ubatch;
|
||||
ctx_params.flash_attn = params.flash_attn;
|
||||
ctx_params.no_perf = params.no_perf;
|
||||
ctx_params.type_k = params.cache_type_k;
|
||||
ctx_params.type_v = params.cache_type_v;
|
||||
|
||||
llama_context * ctx = llama_init_from_model(model, ctx_params);
|
||||
if (!ctx) {
|
||||
LOG_ERR("error: failed to create context\n");
|
||||
llama_model_free(model);
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads);
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);
|
||||
|
||||
std::vector<llama_token> input_tokens = common_tokenize(vocab, formatted_prompt,
|
||||
/*add special tokens*/ true,
|
||||
/*parse special*/ true);
|
||||
int n_input = input_tokens.size();
|
||||
|
||||
if (n_input >= params.n_ctx) {
|
||||
LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
|
||||
llama_free(ctx);
|
||||
llama_model_free(model);
|
||||
return 1;
|
||||
}
|
||||
|
||||
struct diffusion_params ldiff_params = diffusion_default_params();
|
||||
ldiff_params.steps = params.diffusion.steps;
|
||||
ldiff_params.eps = params.diffusion.eps;
|
||||
ldiff_params.temperature = params.sampling.temp;
|
||||
ldiff_params.top_p = params.sampling.top_p;
|
||||
ldiff_params.top_k = params.sampling.top_k;
|
||||
ldiff_params.algorithm = static_cast<enum diffusion_alg>(params.diffusion.algorithm);
|
||||
ldiff_params.alg_temp = params.diffusion.alg_temp;
|
||||
ldiff_params.seed = params.sampling.seed;
|
||||
|
||||
llama_token mask_token_id = llama_vocab_mask(vocab);
|
||||
GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
|
||||
|
||||
LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
|
||||
LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", params.diffusion.steps);
|
||||
LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", params.diffusion.eps);
|
||||
LOG_INF("diffusion_params: - %-25s u32 = %d (%s)\n", "algorithm", params.diffusion.algorithm,
|
||||
alg_name);
|
||||
LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", params.diffusion.alg_temp);
|
||||
|
||||
ldiff_params.mask_token_id = mask_token_id;
|
||||
|
||||
callback_data cb_data = { ¶ms.diffusion, vocab, n_input };
|
||||
|
||||
ldiff_params.step_callback = diffusion_step_callback;
|
||||
ldiff_params.step_callback_user_data = &cb_data;
|
||||
|
||||
int32_t n_generated = 0;
|
||||
|
||||
std::vector<llama_token> output_tokens(params.n_ubatch);
|
||||
diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, params.n_ubatch,
|
||||
ldiff_params, n_generated);
|
||||
|
||||
if (n_generated > 0) {
|
||||
if (params.diffusion.visual_mode) {
|
||||
//clear screen and move cursor to top-left
|
||||
LOG_INF("\033[2J\033[H");
|
||||
}
|
||||
output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
|
||||
std::string output_data = common_detokenize(vocab, output_tokens, false);
|
||||
LOG_INF("\n%s\n", output_data.c_str());
|
||||
} else {
|
||||
LOG_INF("Error: diffusion generation failed\n");
|
||||
}
|
||||
|
||||
llama_free(ctx);
|
||||
llama_model_free(model);
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
}
|
||||
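A quick way to try the new example is sketched below; the model file and the diffusion flag names are assumptions (check `llama-diffusion-cli --help` in your build for the exact option spelling):

```bash
# build just the diffusion example, then run it with a hypothetical diffusion-LM GGUF
cmake -B build && cmake --build build --config Release --target llama-diffusion-cli
./build/bin/llama-diffusion-cli -m models/my-diffusion-lm.gguf \
    -p "Write a haiku about autumn" --diffusion-steps 64
```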
@@ -107,7 +107,7 @@ int main(int argc, char ** argv) {
    const llama_vocab * vocab = llama_model_get_vocab(model);

    const int n_ctx_train = llama_model_n_ctx_train(model);
    const int n_ctx       = llama_n_ctx(ctx);

    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
@@ -184,6 +184,9 @@ int main(int argc, char ** argv) {
    // extra text to insert in each client's prompt in order to make it larger
    const int32_t n_junk = std::max(1, params.n_junk);

    // signed seed, use negative values to indicate different seeds for the different clients
    const int32_t & sseed = params.sampling.seed;

    // init llama.cpp
    llama_backend_init();
    llama_numa_init(params.numa);
@@ -219,11 +222,21 @@ int main(int argc, char ** argv) {

    const int n_ctx = llama_n_ctx(ctx);

    if (sseed >= 0) {
        LOG_INF("%s: initializing all samplers with the same RNG seed: %d (use a negative seed to have different seeds)\n", __func__, sseed);
    } else {
        LOG_INF("%s: initializing samplers with different RNG seeds, starting from %d\n", __func__, sseed);
    }

    std::vector<client> clients(n_clients);
    for (size_t i = 0; i < clients.size(); ++i) {
        auto & client = clients[i];
        client.id     = i;
        client.smpl   = common_sampler_init(model, params.sampling);

        if (sseed < 0) {
            params.sampling.seed--;
        }
    }

    std::vector<llama_token> tokens_system;
@@ -345,7 +358,7 @@ int main(int argc, char ** argv) {
        client.n_decoded = 0;
        client.i_batch   = batch.n_tokens - 1;

        LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, prompt = %d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur, client.n_prompt);

        g_seq_id += 1;
@@ -131,7 +131,7 @@ option(GGML_RVV "ggml: enable rvv" ON)
option(GGML_RV_ZFH "ggml: enable riscv zfh" OFF)
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
option(GGML_VXE "ggml: enable vxe" ON)
option(GGML_NNPA "ggml: enable nnpa" OFF) # temp disabled by default, see: https://github.com/ggml-org/llama.cpp/issues/14877
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
@@ -174,6 +174,8 @@ option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental,
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF)
option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF)
option(GGML_MUSA_MUDNN_COPY "ggml: enable muDNN for accelerated copy" OFF)
option(GGML_VULKAN "ggml: use Vulkan" OFF)
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF)
@@ -181,6 +183,8 @@ option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug ou
option(GGML_VULKAN_SHADER_DEBUG_INFO "ggml: enable Vulkan shader debug info" OFF)
option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF)
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
option(GGML_WEBGPU "ggml: use WebGPU" OFF)
option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF)
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
option(GGML_METAL_USE_BF16 "ggml: use bfloat if available" OFF)
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
@@ -270,6 +274,7 @@ set(GGML_PUBLIC_HEADERS
    include/ggml-rpc.h
    include/ggml-sycl.h
    include/ggml-vulkan.h
    include/ggml-webgpu.h
    include/gguf.h)

set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
@@ -1,152 +1,189 @@
@GGML_VARIABLES_EXPANDED@

@PACKAGE_INIT@

# Find all dependencies before creating any target.
include(CMakeFindDependencyMacro)
find_dependency(Threads)
if (NOT GGML_SHARED_LIB)
    set(GGML_CPU_INTERFACE_LINK_LIBRARIES "")
    set(GGML_CPU_INTERFACE_LINK_OPTIONS "")

    if (APPLE AND GGML_ACCELERATE)
        find_library(ACCELERATE_FRAMEWORK Accelerate)
        if(NOT ACCELERATE_FRAMEWORK)
            set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0)
            return()
        endif()
        list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${ACCELERATE_FRAMEWORK})
    endif()

    if (GGML_OPENMP_ENABLED)
        find_dependency(OpenMP)
        list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
    endif()

    if (GGML_CPU_HBM)
        find_library(memkind memkind)
        if(NOT memkind)
            set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0)
            return()
        endif()
        list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES memkind)
    endif()

    if (GGML_BLAS)
        find_dependency(BLAS)
        list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES})
        list(APPEND GGML_CPU_INTERFACE_LINK_OPTIONS ${BLAS_LINKER_FLAGS})
    endif()

    if (GGML_CUDA)
        set(GGML_CUDA_INTERFACE_LINK_LIBRARIES "")
        find_dependency(CUDAToolkit)
        if (GGML_STATIC)
            list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cudart_static>)
            if (WIN32)
                list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cublas> $<LINK_ONLY:CUDA::cublasLt>)
            else()
                list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cublas_static> $<LINK_ONLY:CUDA::cublasLt_static>)
            endif()
        endif()
        if (NOT GGML_CUDA_NO_VMM)
            list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cuda_driver>)
        endif()
    endif()

    if (GGML_METAL)
        find_library(FOUNDATION_LIBRARY Foundation)
        find_library(METAL_FRAMEWORK Metal)
        find_library(METALKIT_FRAMEWORK MetalKit)
        if(NOT FOUNDATION_LIBRARY OR NOT METAL_FRAMEWORK OR NOT METALKIT_FRAMEWORK)
            set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0)
            return()
        endif()
        set(GGML_METAL_INTERFACE_LINK_LIBRARIES
            ${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK})
    endif()

    if (GGML_OPENCL)
        find_dependency(OpenCL)
        set(GGML_OPENCL_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:OpenCL::OpenCL>)
    endif()

    if (GGML_VULKAN)
        find_dependency(Vulkan)
        set(GGML_VULKAN_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:Vulkan::Vulkan>)
    endif()

    if (GGML_HIP)
        find_dependency(hip)
        find_dependency(hipblas)
        find_dependency(rocblas)
        set(GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas)
    endif()

    if (GGML_SYCL)
        set(GGML_SYCL_INTERFACE_LINK_LIBRARIES "")
        find_package(DNNL)
        if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
            list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES DNNL::dnnl)
        endif()
        if (WIN32)
            find_dependency(IntelSYCL)
            find_dependency(MKL)
            list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
        endif()
    endif()
endif()

set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@")
set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@")
#set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@")

if(NOT TARGET ggml::ggml)
    find_package(Threads REQUIRED)

    find_library(GGML_LIBRARY ggml
        REQUIRED
        HINTS ${GGML_LIB_DIR}
        NO_CMAKE_FIND_ROOT_PATH)

    add_library(ggml::ggml UNKNOWN IMPORTED)
    set_target_properties(ggml::ggml
        PROPERTIES
            IMPORTED_LOCATION "${GGML_LIBRARY}")

    find_library(GGML_BASE_LIBRARY ggml-base
        REQUIRED
        HINTS ${GGML_LIB_DIR}
        NO_CMAKE_FIND_ROOT_PATH)

    add_library(ggml::ggml-base UNKNOWN IMPORTED)
    set_target_properties(ggml::ggml-base
        PROPERTIES
            IMPORTED_LOCATION "${GGML_BASE_LIBRARY}")

    set(_ggml_all_targets "")
    foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
        string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}")
        string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx)

        find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend}
            REQUIRED
            HINTS ${GGML_LIB_DIR}
            NO_CMAKE_FIND_ROOT_PATH)

        message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}")

        add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED)
        set_target_properties(ggml::${_ggml_backend}
            PROPERTIES
                INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}"
                IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
                IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}"
                INTERFACE_COMPILE_FEATURES c_std_90
                POSITION_INDEPENDENT_CODE ON)

        string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}")
        if(is_cpu_variant)
            list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
            set_target_properties(ggml::${_ggml_backend}
                PROPERTIES
                    INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}")

            if(GGML_CPU_INTERFACE_LINK_OPTIONS)
                set_target_properties(ggml::${_ggml_backend}
                    PROPERTIES
                        INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}")
            endif()

        else()
            list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
            set_target_properties(ggml::${_ggml_backend}
                PROPERTIES
                    INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}")

            if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS)
                set_target_properties(ggml::${_ggml_backend}
                    PROPERTIES
                        INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}")
            endif()
        endif()

        list(APPEND _ggml_all_targets ggml::${_ggml_backend})
    endforeach()

    list(APPEND GGML_INTERFACE_LINK_LIBRARIES ggml::ggml-base "${_ggml_all_targets}")
    set_target_properties(ggml::ggml
        PROPERTIES
            INTERFACE_LINK_LIBRARIES "${GGML_INTERFACE_LINK_LIBRARIES}")

    add_library(ggml::all INTERFACE IMPORTED)
    set_target_properties(ggml::all
        PROPERTIES
            INTERFACE_LINK_LIBRARIES "${_ggml_all_targets}")

endif()

check_required_components(ggml)
19 ggml/include/ggml-webgpu.h (new file)
@@ -0,0 +1,19 @@
#pragma once

#include "ggml.h"
#include "ggml-backend.h"

#ifdef __cplusplus
extern "C" {
#endif

#define GGML_WEBGPU_NAME "WebGPU"

// Needed for examples in ggml
GGML_BACKEND_API ggml_backend_t ggml_backend_webgpu_init(void);

GGML_BACKEND_API ggml_backend_reg_t ggml_backend_webgpu_reg(void);

#ifdef __cplusplus
}
#endif
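As a rough usage sketch of this new header (assuming a build configured with `GGML_WEBGPU=ON`; graph construction and error handling elided):

```c
#include "ggml-webgpu.h"

// create the WebGPU backend and use it like any other ggml backend
ggml_backend_t backend = ggml_backend_webgpu_init();
if (backend != NULL) {
    // ... allocate buffers and compute a ggml graph on this backend ...
    ggml_backend_free(backend);
}
```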
@@ -1297,6 +1297,19 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
float s);
|
||||
|
||||
// x = s * a + b
|
||||
GGML_API struct ggml_tensor * ggml_scale_bias(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float s,
|
||||
float b);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_scale_bias_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float s,
|
||||
float b);
|
||||
|
||||
// b -> view(a,offset,nb1,nb2,3), return modified a
|
||||
GGML_API struct ggml_tensor * ggml_set(
|
||||
struct ggml_context * ctx,
|
||||
|
||||
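A minimal usage sketch of the new op (context setup elided; the tensor shape is illustrative):

```c
// y = 2*x + 1, applied element-wise
struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
struct ggml_tensor * y = ggml_scale_bias(ctx, x, 2.0f, 1.0f);
```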
@@ -370,6 +370,7 @@ ggml_add_backend(MUSA)
ggml_add_backend(RPC)
ggml_add_backend(SYCL)
ggml_add_backend(Vulkan)
ggml_add_backend(WebGPU)
ggml_add_backend(OpenCL)

foreach (target ggml-base ggml)
@@ -22,21 +22,6 @@ static bool ggml_is_view(const struct ggml_tensor * t) {
    return t->view_src != NULL;
}

static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
    if (a->type != b->type) {
        return false;
    }
    for (int i = 0; i < GGML_MAX_DIMS; i++) {
        if (a->ne[i] != b->ne[i]) {
            return false;
        }
        if (a->nb[i] != b->nb[i]) {
            return false;
        }
    }
    return true;
}

// ops that return true for this function must not use restrict pointers for their backend implementations
static bool ggml_op_can_inplace(enum ggml_op op) {
    switch (op) {
@@ -45,6 +45,10 @@
#include "ggml-vulkan.h"
#endif

#ifdef GGML_USE_WEBGPU
#include "ggml-webgpu.h"
#endif

#ifdef GGML_USE_OPENCL
#include "ggml-opencl.h"
#endif
@@ -173,6 +177,9 @@ struct ggml_backend_registry {
#ifdef GGML_USE_VULKAN
        register_backend(ggml_backend_vk_reg());
#endif
#ifdef GGML_USE_WEBGPU
        register_backend(ggml_backend_webgpu_reg());
#endif
#ifdef GGML_USE_OPENCL
        register_backend(ggml_backend_opencl_reg());
#endif
@@ -352,21 +352,6 @@ ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {

// backend copy

static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
    if (a->type != b->type) {
        return false;
    }
    for (int i = 0; i < GGML_MAX_DIMS; i++) {
        if (a->ne[i] != b->ne[i]) {
            return false;
        }
        if (a->nb[i] != b->nb[i]) {
            return false;
        }
    }
    return true;
}

void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {
    GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
@@ -662,6 +647,7 @@ struct ggml_backend_sched {
    // pipeline parallelism support
    int n_copies;
    int cur_copy;
    int next_copy;
    ggml_backend_event_t events[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES];
    struct ggml_tensor * graph_inputs[GGML_SCHED_MAX_SPLIT_INPUTS];
    int n_graph_inputs;
@@ -1448,8 +1434,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
        }
    }

    return GGML_STATUS_SUCCESS;
}
@@ -1550,10 +1534,10 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
    GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);

    ggml_backend_sched_synchronize(sched);

    ggml_backend_sched_split_graph(sched, measure_graph);

    if (!ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) {
        return false;
    }
@@ -1565,6 +1549,10 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *

bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
    GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);
    GGML_ASSERT(!sched->is_alloc);

    sched->cur_copy  = sched->next_copy;
    sched->next_copy = (sched->next_copy + 1) % sched->n_copies;

    ggml_backend_sched_split_graph(sched, graph);

@@ -1605,7 +1593,7 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
        // if the graph is not already allocated, always use copy 0 after a synchronization
        // this ensures that during generation the same copy is used every time,
        // which avoids changes in the graph that could cause CUDA or other graphs to be disabled
        sched->next_copy = 0;
    }
}
@@ -77,6 +77,8 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne,
    for (int i = 0; i < final_dims; i++) {
        acl_storage_len += (acl_ne[i] - 1) * acl_stride[i];
    }
    size_t elem_offset = offset / ggml_element_size(tensor);
    acl_storage_len += elem_offset;

    // Reverse ne and stride.
    std::reverse(acl_ne, acl_ne + final_dims);
@@ -84,7 +86,7 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne,

    aclTensor* acl_tensor = aclCreateTensor(
        acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
        elem_offset, format, &acl_storage_len, 1,
        tensor->data);

    return acl_tensor;
@@ -99,7 +99,7 @@ void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclT
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cann_unary_op(
|
||||
void ggml_cann_op_unary(
|
||||
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
|
||||
ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
ggml_tensor* src = dst->src[0];
|
||||
@@ -111,6 +111,42 @@ void ggml_cann_unary_op(
|
||||
ggml_cann_release_resources(ctx, acl_src, acl_dst);
|
||||
}
|
||||
|
||||
void ggml_cann_op_unary_gated(
|
||||
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
|
||||
ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
ggml_tensor* src0 = dst->src[0];
|
||||
ggml_tensor* src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
||||
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
||||
|
||||
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
|
||||
aclTensor *acl_src0 = nullptr, *acl_src1 = nullptr;
|
||||
if(src1) {
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
||||
GGML_ASSERT(src0->type == src1->type);
|
||||
|
||||
acl_src0 = ggml_cann_create_tensor(src0);
|
||||
acl_src1 = ggml_cann_create_tensor(src1);
|
||||
} else {
|
||||
int64_t ne[] = {src0->ne[0] / 2, src0->ne[1], src0->ne[2], src0->ne[3]};
|
||||
size_t nb[] = {src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]};
|
||||
acl_src0 = ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, 0);
|
||||
acl_src1 = ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, ne[0] * ggml_element_size(src0));
|
||||
if (swapped) {
|
||||
std::swap(acl_src0, acl_src1);
|
||||
}
|
||||
}
|
||||
|
||||
unary_op(ctx, acl_src0, acl_dst);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst, acl_src1);
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_src0, acl_dst);
|
||||
if(src1)
|
||||
ggml_cann_release_resources(ctx, acl_src1);
|
||||
}
|
||||
|
||||
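Aside: the packed-input branch above views one contiguous tensor as two halves along dimension 0. A hedged host-side illustration of that view arithmetic (plain C++, no ACL calls; ne/nb follow ggml's shape/stride conventions, values illustrative):

    #include <cstdint>
    #include <cassert>

    // For a packed GLU input of row length ne0, the value half starts at byte 0
    // and the gate half at (ne0/2) * element_size within the same row. This mirrors
    // the two ggml_cann_create_tensor() calls above.
    int main() {
        const int64_t ne0       = 8;             // packed row length
        const size_t  elem_size = sizeof(float); // element size of src0

        const int64_t half     = ne0 / 2;          // ne[0] of each half-view
        const size_t  gate_off = half * elem_size; // byte offset of the second view

        assert(half == 4 && gate_off == 16);
        // swapped (op_params[1]) merely exchanges which half is value and which is gate.
        return 0;
    }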
/**
 * @brief Repeats elements of a tensor along each dimension according to the
 * specified repeat array.
@@ -1785,8 +1821,27 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
    size_t transpose_nb[] = {bcast_weight_nb[1], bcast_weight_nb[0],
                             bcast_weight_nb[2], bcast_weight_nb[3],
                             bcast_weight_nb[4], bcast_weight_nb[5]};
    aclTensor* acl_weight_tensor =
        ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims);
    aclTensor* acl_weight_tensor;

    bool weightToNZ = false;
#ifdef ASCEND_310P
    weightToNZ = (getenv("GGML_CANN_WEIGHT_NZ") != nullptr);
#endif
    if (weightToNZ && is_matmul_weight(weight)) {
        int64_t acl_stride[2] = {1, transpose_ne[1]};

        // Reverse ne.
        std::reverse(transpose_ne, transpose_ne + n_dims);

        std::vector<int64_t> storageDims = {transpose_ne[0], transpose_ne[1]};

        acl_weight_tensor = aclCreateTensor(
            transpose_ne, n_dims, ggml_cann_type_mapping(weight->type), acl_stride,
            0, ACL_FORMAT_FRACTAL_NZ, storageDims.data(), 2, weight->data);
    } else {
        acl_weight_tensor =
            ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
    }
    aclTensor* acl_dst =
        ggml_cann_create_tensor(dst, bcast_dst_ne, bcast_dst_nb, n_dims);

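Aside: the FRACTAL_NZ weight path is opt-in via an environment variable on Ascend 310P builds. A minimal sketch of that gate (GGML_CANN_WEIGHT_NZ is the variable the diff actually reads; the helper name is illustrative):

    #include <cstdlib>

    // True when the user has requested FRACTAL_NZ weight layout. Mirrors the
    // gate in ggml_cann_mat_mul_fp: compiled in only for 310P, active only
    // when GGML_CANN_WEIGHT_NZ is set in the environment.
    static bool want_weight_nz() {
    #ifdef ASCEND_310P
        return std::getenv("GGML_CANN_WEIGHT_NZ") != nullptr;
    #else
        return false;
    #endif
    }
    // usage sketch: if (want_weight_nz() && is_matmul_weight(weight)) { /* NZ path */ }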
@@ -23,6 +23,7 @@
#ifndef CANN_ACLNN_OPS
#define CANN_ACLNN_OPS

#include <unordered_set>
#include <functional>
#include <aclnnop/aclnn_abs.h>
#include <aclnnop/aclnn_neg.h>
@@ -1020,6 +1021,37 @@ inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffe
 */
void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst);

/**
 * @brief Check whether a tensor is a weight tensor for matrix multiplication.
 *
 * @details Checks whether the given tensor serves as weight parameters in matrix multiplication operations,
 * typically within neural network layers. The function maintains a static set of canonical weight
 * naming suffixes from Transformer-based architectures. Uses substring matching to identify weight
 * tensors even with hierarchical naming patterns.
 *
 * @param tensor Pointer to the target ggml_tensor object (const-qualified).
 */
static bool is_matmul_weight(const ggml_tensor* tensor) {
    std::string name = ggml_get_name(tensor);
    static const std::unordered_set<std::string> weight_suffixes{
        "output.weight",
        "attn_q.weight",
        "attn_k.weight",
        "attn_v.weight",
        "attn_output.weight",
        "ffn_gate.weight",
        "ffn_up.weight",
        "ffn_down.weight"
    };

    for (const auto& suffix : weight_suffixes) {
        if (name.find(suffix) != std::string::npos) {
            return true;
        }
    }
    return false;
}
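Aside: because the match is a substring search, fully qualified layer names match too. A small self-contained check of the same logic (abbreviated suffix set and sample names are illustrative):

    #include <string>
    #include <unordered_set>
    #include <cassert>

    // Same substring test as is_matmul_weight(), applied to two typical names.
    static bool name_is_matmul_weight(const std::string& name) {
        static const std::unordered_set<std::string> suffixes{
            "attn_q.weight", "ffn_up.weight"}; // abbreviated set for the example
        for (const auto& s : suffixes) {
            if (name.find(s) != std::string::npos) return true;
        }
        return false;
    }

    int main() {
        assert(name_is_matmul_weight("blk.12.attn_q.weight"));  // hierarchical name matches
        assert(!name_is_matmul_weight("blk.12.attn_norm.weight"));
    }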
/**
 * @brief Applies an element-wise operation to two input tensors using the CANN
 * backend.
@@ -1066,7 +1098,7 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 * @param dst The destination tensor. Its src[0] is treated as the input tensor.
 */
template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
void ggml_cann_unary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
void ggml_cann_op_unary(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];

    aclTensor* acl_src = ggml_cann_create_tensor(src);
@@ -1077,49 +1109,125 @@ template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
}
/**
 * @brief Applies a unary operation to a ggml tensor using the CANN backend.
 * @brief Applies a unary operation to a ggml tensor using the CANN backend.
 *
 * @details This function performs a unary operation on the input tensor using
 * a user-provided lambda or callable object `unary_op`, which accepts the CANN
 * context and two ACL tensors (source and destination). Internally, this function
 * creates ACL representations of the ggml tensors and invokes the unary operation.
 * The result is stored in the destination tensor `dst`. This utility abstracts the
 * common boilerplate of tensor conversion and cleanup when implementing unary ops.
 * @details This function applies a unary operation to the input tensor using
 * a user-provided lambda or callable `unary_op`. The lambda receives the
 * CANN backend context and two ACL tensors: the source and the destination.
 *
 * @param unary_op A callable that performs the unary operation using CANN APIs.
 * @param ctx The CANN context used for operations.
 * @param dst The destination tensor where the result will be stored.
 *            The source tensor is retrieved from `dst->src[0]`.
 * Internally, this function handles the conversion from GGML tensors to ACL tensors,
 * calls the provided unary op, and manages resource cleanup. The input is assumed
 * to be `dst->src[0]`, and the result is written to `dst`.
 *
 * This utility simplifies writing unary op wrappers by abstracting tensor preparation.
 *
 * @param unary_op A callable that performs the unary operation using CANN ACL APIs.
 * @param ctx The CANN context for operation execution.
 * @param dst The destination ggml_tensor where the result will be stored.
 *            The input tensor is assumed to be `dst->src[0]`.
 *
 * @see GGML_CANN_CALL_OP_UNARY
 */
void ggml_cann_unary_op(
void ggml_cann_op_unary(
    std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
    ggml_backend_cann_context& ctx, ggml_tensor* dst);

/**
 * @brief Helper macro to invoke a unary ACL operation using ggml_cann_unary_op.
 * @brief Applies a gated (GLU-style) unary operation using the CANN backend.
 *
 * This macro defines an inline lambda wrapping a specific ACL operation name,
 * and passes it to the templated ggml_cann_unary_op function. It simplifies
 * calling unary ops by hiding the lambda boilerplate.
 * @details This function performs a gated activation such as GEGLU or ReGLU.
 * It supports two input modes:
 *
 * 1. **Dual input mode**: `dst->src[0]` and `dst->src[1]` are both valid tensors.
 *    These are used directly as the value and gate tensors.
 *
 * 2. **Packed input mode**: Only `dst->src[0]` is valid, and it is assumed to
 *    contain a concatenation of value and gate along the first dimension. This tensor
 *    will be split into two equal halves to form the value and gate inputs.
 *
 * The function applies a user-provided unary operation (e.g., GELU) to the value tensor,
 * then multiplies the result in-place with the gate tensor:
 *
 * Internally, the lambda will call:
 * @code
 * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
 * dst = unary_op(value) * gate;
 * @endcode
 *
 * The `swapped` parameter (from `dst->op_params[1]`) allows flipping the
 * order of value/gate in the packed input case.
 *
 * @param unary_op A callable that performs the unary operation using CANN ACL APIs.
 *                 It receives (ctx, acl_value_tensor, acl_output_tensor).
 * @param ctx The CANN context used for execution.
 * @param dst The destination ggml_tensor. Source tensors are in `dst->src[0]` and optionally `src[1]`.
 *
 * @see GGML_CANN_CALL_OP_UNARY_GATED
 */
void ggml_cann_op_unary_gated(
    std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
    ggml_backend_cann_context& ctx, ggml_tensor* dst);

/**
 * @brief Helper macro to call a unary ACL operator via ggml_cann_op_unary.
 *
 * This macro wraps the specified ACLNN unary operator name into a lambda expression,
 * and passes it to `ggml_cann_op_unary`, which handles the common logic for executing
 * unary ops in the CANN backend.
 *
 * Internally, this macro expands to a lambda like:
 * @code
 * [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
 *     GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
 * };
 * @endcode
 *
 * This lambda is then passed to `ggml_cann_op_unary`, which applies the operation.
 *
 * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
 *
 * @see ggml_cann_unary_op
 * @see ggml_cann_op_unary
 * @see GGML_CANN_CALL_ACLNN_OP
 */
#define GGML_CANN_CALL_UNARY_OP(OP_NAME)                             \
#define GGML_CANN_CALL_OP_UNARY(OP_NAME)                             \
    do {                                                             \
        auto lambda = [](ggml_backend_cann_context& ctx,             \
                         aclTensor* acl_src,                         \
                         aclTensor* acl_dst) {                       \
            GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
        };                                                           \
        ggml_cann_unary_op(lambda, ctx, dst);                        \
        ggml_cann_op_unary(lambda, ctx, dst);                        \
    }                                                                \
    while (0)

/**
 * @brief Helper macro to call a gated unary ACL operator via ggml_cann_op_unary_gated.
 *
 * This macro wraps the specified ACLNN unary operator name into a lambda expression,
 * and passes it to `ggml_cann_op_unary_gated`, which handles the common logic for
 * executing gated unary ops in the CANN backend.
 *
 * Internally, this macro expands to a lambda like:
 * @code
 * [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
 *     GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
 * };
 * @endcode
 *
 * This lambda is then passed to `ggml_cann_op_unary_gated`, which applies the operation.
 *
 * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
 *
 * @see ggml_cann_op_unary_gated
 * @see GGML_CANN_CALL_ACLNN_OP
 */
#define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME)                       \
    do {                                                             \
        auto lambda = [](ggml_backend_cann_context& ctx,             \
                         aclTensor* acl_src,                         \
                         aclTensor* acl_dst) {                       \
            GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
        };                                                           \
        ggml_cann_op_unary_gated(lambda, ctx, dst);                  \
    }                                                                \
    while (0)

#endif  // CANN_ACLNN_OPS

@@ -24,6 +24,7 @@

#include <acl/acl.h>
#include <stdarg.h>
#include <aclnnop/aclnn_trans_matmul_weight.h>

#include <cmath>
#include <cstdio>
@@ -1115,6 +1116,63 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
    return GGML_STATUS_SUCCESS;
}

static int CreateAclTensorWeight(const void *hostData, const std::vector<int64_t> &shape, void **deviceAddr,
                                 aclDataType dataType, aclTensor **tensor)
{
    uint64_t size = 1;
    for (auto i : shape) {
        size *= i;
    }

    const aclIntArray *mat2Size = aclCreateIntArray(shape.data(), shape.size());
    ACL_CHECK(aclnnCalculateMatmulWeightSizeV2(mat2Size, dataType, &size));

    size *= sizeof(int16_t);

    ACL_CHECK(aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST));
    aclrtMemcpy(*deviceAddr, size, hostData, size, ACL_MEMCPY_HOST_TO_DEVICE);

    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = shape.size() - 2; i >= 0; i--) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }

    *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
                              shape.data(), shape.size(), *deviceAddr);
    return 0;
}

static void weight_format_to_nz(ggml_tensor *tensor, const void *data, size_t offset) {
    aclrtStream stream;
    ACL_CHECK(aclrtCreateStream(&stream));

    std::vector<int64_t> weightTransposedShape = {tensor->ne[1], tensor->ne[0]};
    void *weightTransposedDeviceAddr = nullptr;
    aclTensor *weightTransposed = nullptr;
    CreateAclTensorWeight(data, weightTransposedShape, &weightTransposedDeviceAddr,
                          ggml_cann_type_mapping(tensor->type), &weightTransposed);

    uint64_t workspaceSize = 0;
    aclOpExecutor *executor;
    void *workspaceAddr = nullptr;

    // TransMatmulWeight
    ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed, &workspaceSize, &executor));
    std::unique_ptr<void, aclError (*)(void *)> workspaceAddrPtrTrans(nullptr, aclrtFree);
    if (workspaceSize > 0) {
        ACL_CHECK(aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));
        workspaceAddrPtrTrans.reset(workspaceAddr);
    }
    ACL_CHECK(aclnnTransMatmulWeight(workspaceAddr, workspaceSize, executor, stream));

    size_t size = ggml_nelements(tensor) * ggml_element_size(tensor);

    aclrtMemcpy((char *)tensor->data + offset, size,
                weightTransposedDeviceAddr, size, ACL_MEMCPY_HOST_TO_DEVICE);
    ACL_CHECK(aclDestroyTensor(weightTransposed));
    aclrtFree(weightTransposedDeviceAddr);
}

// TODO: need to handle tensors which have padding.
/**
 * @brief Set tensor data in a CANN buffer.
@@ -1139,9 +1197,16 @@ static void ggml_backend_cann_buffer_set_tensor(
    // For acl, synchronous functions use this default stream.
    // Why aclrtSynchronizeDevice?

    bool weightToNZ = false;
#ifdef ASCEND_310P
    weightToNZ = (getenv("GGML_CANN_WEIGHT_NZ") != nullptr);
#endif
    if (!need_transform(tensor->type)) {
        ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
                              ACL_MEMCPY_HOST_TO_DEVICE));
        if (weightToNZ && is_matmul_weight((const ggml_tensor*)tensor)) {
            weight_format_to_nz(tensor, data, offset);
        }
    } else {
        void *transform_buffer = malloc(size);
        ggml_backend_cann_transform(tensor, data, transform_buffer);
@@ -1616,16 +1681,18 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(dst)) {
                case GGML_UNARY_OP_ABS:
                    GGML_CANN_CALL_UNARY_OP(Abs);
                    GGML_CANN_CALL_OP_UNARY(Abs);
                    break;
                case GGML_UNARY_OP_NEG:
                    GGML_CANN_CALL_UNARY_OP(Neg);
                    GGML_CANN_CALL_OP_UNARY(Neg);
                    break;
                case GGML_UNARY_OP_GELU:
                    GGML_CANN_CALL_UNARY_OP(Gelu);
                case GGML_UNARY_OP_GELU_ERF:
                    // aclnnGelu internally uses the erf-based approximation.
                    GGML_CANN_CALL_OP_UNARY(Gelu);
                    break;
                case GGML_UNARY_OP_SILU:
                    GGML_CANN_CALL_UNARY_OP(Silu);
                    GGML_CANN_CALL_OP_UNARY(Silu);
                    break;
                case GGML_UNARY_OP_GELU_QUICK: {
                    auto lambda = [](ggml_backend_cann_context& ctx,
@@ -1633,31 +1700,31 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
                                     aclTensor* acl_dst) {
                        GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
                    };
                    ggml_cann_unary_op(lambda, ctx, dst);
                    ggml_cann_op_unary(lambda, ctx, dst);
                } break;
                case GGML_UNARY_OP_TANH:
                    GGML_CANN_CALL_UNARY_OP(Tanh);
                    GGML_CANN_CALL_OP_UNARY(Tanh);
                    break;
                case GGML_UNARY_OP_RELU:
                    GGML_CANN_CALL_UNARY_OP(Relu);
                    GGML_CANN_CALL_OP_UNARY(Relu);
                    break;
                case GGML_UNARY_OP_SIGMOID:
                    GGML_CANN_CALL_UNARY_OP(Sigmoid);
                    GGML_CANN_CALL_OP_UNARY(Sigmoid);
                    break;
                case GGML_UNARY_OP_HARDSIGMOID:
                    GGML_CANN_CALL_UNARY_OP(Hardsigmoid);
                    GGML_CANN_CALL_OP_UNARY(Hardsigmoid);
                    break;
                case GGML_UNARY_OP_HARDSWISH:
                    GGML_CANN_CALL_UNARY_OP(Hardswish);
                    GGML_CANN_CALL_OP_UNARY(Hardswish);
                    break;
                case GGML_UNARY_OP_EXP:
                    GGML_CANN_CALL_UNARY_OP(Exp);
                    GGML_CANN_CALL_OP_UNARY(Exp);
                    break;
                case GGML_UNARY_OP_ELU:
                    ggml_cann_elu(ctx, dst);
                    break;
                case GGML_UNARY_OP_SGN:
                    GGML_CANN_CALL_UNARY_OP(Sign);
                    GGML_CANN_CALL_OP_UNARY(Sign);
                    break;
                case GGML_UNARY_OP_STEP:
                    ggml_cann_step(ctx, dst);
@@ -1666,6 +1733,31 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
                    return false;
            }
            break;
        case GGML_OP_GLU:
            switch (ggml_get_glu_op(dst)) {
                case GGML_GLU_OP_REGLU:
                    GGML_CANN_CALL_OP_UNARY_GATED(Relu);
                    break;
                case GGML_GLU_OP_GEGLU:
                case GGML_GLU_OP_GEGLU_ERF:
                    // aclnnGelu internally uses the erf-based approximation.
                    GGML_CANN_CALL_OP_UNARY_GATED(Gelu);
                    break;
                case GGML_GLU_OP_SWIGLU:
                    GGML_CANN_CALL_OP_UNARY_GATED(Silu);
                    break;
                case GGML_GLU_OP_GEGLU_QUICK: {
                    auto lambda = [](ggml_backend_cann_context& ctx,
                                     aclTensor* acl_src,
                                     aclTensor* acl_dst) {
                        GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
                    };
                    ggml_cann_op_unary_gated(lambda, ctx, dst);
                } break;
                default:
                    return false;
            }
            break;
        case GGML_OP_NORM:
            ggml_cann_norm(ctx, dst);
            break;
@@ -1708,7 +1800,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
            ggml_cann_binary_op<aclnn_mul>(ctx, dst);
            break;
        case GGML_OP_SQRT:
            GGML_CANN_CALL_UNARY_OP(Sqrt);
            GGML_CANN_CALL_OP_UNARY(Sqrt);
            break;
        case GGML_OP_CLAMP:
            ggml_cann_clamp(ctx, dst);
@@ -1753,16 +1845,16 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
            ggml_cann_argmax(ctx, dst);
            break;
        case GGML_OP_COS:
            ggml_cann_unary_op<aclnn_cos>(ctx, dst);
            ggml_cann_op_unary<aclnn_cos>(ctx, dst);
            break;
        case GGML_OP_SIN:
            ggml_cann_unary_op<aclnn_sin>(ctx, dst);
            ggml_cann_op_unary<aclnn_sin>(ctx, dst);
            break;
        case GGML_OP_CONV_TRANSPOSE_1D:
            ggml_cann_conv_transpose_1d(ctx, dst);
            break;
        case GGML_OP_LOG:
            GGML_CANN_CALL_UNARY_OP(Log);
            GGML_CANN_CALL_OP_UNARY(Log);
            break;
        case GGML_OP_MEAN:
            ggml_cann_mean(ctx, dst);
@@ -2036,10 +2128,23 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                case GGML_UNARY_OP_ELU:
                case GGML_UNARY_OP_SGN:
                case GGML_UNARY_OP_STEP:
                case GGML_UNARY_OP_GELU_ERF:
                    return true;
                default:
                    return false;
            }
        case GGML_OP_GLU:
            switch (ggml_get_glu_op(op)) {
                case GGML_GLU_OP_REGLU:
                case GGML_GLU_OP_GEGLU:
                case GGML_GLU_OP_SWIGLU:
                case GGML_GLU_OP_GEGLU_ERF:
                case GGML_GLU_OP_GEGLU_QUICK:
                    return true;
                default:
                    return false;
            }
            break;
        case GGML_OP_MUL_MAT: {
            switch (op->src[0]->type) {
                case GGML_TYPE_F16:
@@ -2090,6 +2195,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
        {
            // TODO: add support
            // ref: https://github.com/ggml-org/llama.cpp/pull/14274
#pragma message("TODO: implement F32, F16, BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
            return false;
        } break;
        case GGML_OP_CPY: {
@@ -2188,7 +2294,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_RMS_NORM:
        case GGML_OP_SCALE:
        case GGML_OP_SQR:
        case GGML_OP_SQRT:
        case GGML_OP_CLAMP:
@@ -2210,6 +2315,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
        case GGML_OP_PAD_REFLECT_1D:
        case GGML_OP_COUNT_EQUAL:
            return true;
        case GGML_OP_SCALE:
            float bias;
            memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
            return bias == 0.0f; // TODO: support bias != 0.0f
        case GGML_OP_SOFT_MAX:
            // TODO: support broadcast
            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
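Aside: GGML_OP_SCALE now carries a bias in op_params[1] alongside the scale in op_params[0]; backends that cannot fuse the bias reject non-zero values as above. A hedged sketch of reading both parameters with the raw memcpy pattern ggml uses (the fake_op struct is an illustrative stand-in for ggml_tensor):

    #include <cstring>
    #include <cstdint>

    // Minimal stand-in for the relevant part of ggml_tensor.
    struct fake_op { int32_t op_params[16]; };

    // GGML_OP_SCALE stores two floats bit-cast into op_params: [0] = scale, [1] = bias.
    static void read_scale_params(const fake_op * op, float * s, float * b) {
        std::memcpy(s, (const float *) op->op_params + 0, sizeof(float));
        std::memcpy(b, (const float *) op->op_params + 1, sizeof(float));
    }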
@@ -70,10 +70,12 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
    if (GGML_OPENMP)
        find_package(OpenMP)
        if (OpenMP_FOUND)
            set(GGML_OPENMP_ENABLED "ON" CACHE INTERNAL "")
            target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP)

            target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
        else()
            set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "")
            message(WARNING "OpenMP not found")
        endif()
    endif()
@@ -456,6 +458,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
        list(APPEND ARCH_FLAGS -march=z16)
    elseif (${S390X_M} MATCHES "9175|9176")
        # NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version.
        # binutils must also be updated to the latest for the -march=z17 flag to work. Otherwise, use -march=arch15.
        message(STATUS "z17 target")
        list(APPEND ARCH_FLAGS -march=z17)
    else()
@@ -494,9 +497,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)

    # Fetch KleidiAI sources:
    include(FetchContent)
    set(KLEIDIAI_COMMIT_TAG "v1.9.0")
    set(KLEIDIAI_COMMIT_TAG "v1.11.0")
    set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
    set(KLEIDIAI_ARCHIVE_MD5  "2a8e1bb55d201557553545536489a017")
    set(KLEIDIAI_ARCHIVE_MD5  "3fe9e5ab964c375c53839296eb71eaa2")

    if (POLICY CMP0135)
        cmake_policy(SET CMP0135 NEW)
@@ -544,7 +544,7 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
    __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) );
    max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
    __m128 tmp = max4;
    max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
    max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x1 ));
    const float max_scalar = ((v4f32)max4)[0];

    // Quantize these floats
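Aside: the LoongArch fix above corrects the lane selector used in the final horizontal-max step. A scalar reference for what the whole reduction computes (plain C++; illustrative):

    #include <cmath>
    #include <algorithm>

    // Scalar equivalent of the SIMD horizontal reduction above:
    // max_scalar = max over i of |x[i]| for one block of values.
    static float horizontal_abs_max(const float * x, int n) {
        float m = 0.0f;
        for (int i = 0; i < n; ++i) {
            m = std::max(m, std::fabs(x[i]));
        }
        return m;
    }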
@@ -22,9 +22,94 @@

#include "kai_common.h"

#include "simd-mappings.h"

#include "kernels.h"

#define NELEMS(x) sizeof(x) / sizeof(*x)

static const size_t INT4_PER_BYTE   = 2;
static const size_t INT4_BITS       = 4;
static const int    Q4_0_ZERO_POINT = 8;
const size_t INT4_PER_UINT16 = 4;

static void dequantize_row_qsi4c32pscalef16(
    const void *packed_data,
    int32_t row_idx,
    int64_t nc,
    float *out,
    size_t nr_pack,
    size_t packed_row_stride,
    size_t kr,
    size_t bl,
    size_t num_bytes_multiplier
) {
    size_t group_idx    = row_idx / nr_pack;
    size_t row_in_group = row_idx % nr_pack;
    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
    size_t num_blocks = nc / bl;
    const uint8_t *block_ptr = packed_group;

    for (size_t b = 0; b < num_blocks; ++b) {
        uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier));
        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);

        const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier;
        size_t num_segments = bl / kr;
        size_t num_bytes_per_segment = kr / INT4_PER_BYTE;

        for (size_t s = 0; s < num_segments; ++s) {
            const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment;
            const uint8_t *qbytes   = seg_base + row_in_group * num_bytes_per_segment;
            for (size_t k = 0; k < num_bytes_per_segment; ++k) {
                uint8_t byte = qbytes[k] ^ 0x88;
                int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT;
                int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT;
                out[b * bl + s * num_bytes_per_segment + k]        = x0 * scale;
                out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale;
            }
        }
        block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment;
    }
}

static void dequantize_row_qsi4c32ps1s0scalef16(
    const void *packed_data,
    int32_t row_idx,
    int64_t k,
    float *out,
    size_t nr,
    size_t packed_row_stride,
    size_t kr,
    size_t bl,
    size_t num_bytes_multiplier
) {
    const size_t num_blocks = k / bl;
    const size_t bl4 = bl / INT4_PER_UINT16;

    size_t group_idx    = row_idx / nr;
    size_t row_in_group = row_idx % nr;

    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
    const uint16_t *qdata  = (const uint16_t *)packed_group;
    const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier));

    for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {
        uint16_t scale_f16 = scales[row_in_group + block_idx * nr];
        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);

        for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) {
            uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group];

            for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) {
                int v = ((q >> (qidx * 4)) & 0xF) - Q4_0_ZERO_POINT;
                out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale;
            }
        }
    }
    GGML_UNUSED(kr);
}

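Aside: both dequantizers recover signed int4 values by subtracting the Q4_0 zero point of 8; the first additionally XORs with 0x88 to undo the sign-bit flip applied at pack time. A tiny worked example of the nibble decode (self-contained C++; the sample byte is illustrative):

    #include <cstdint>
    #include <cassert>

    int main() {
        const int Q4_0_ZERO_POINT = 8;

        // One packed byte holds two 4-bit values; 0x88 flips the sign bit of each nibble.
        uint8_t byte = uint8_t(0x2B) ^ 0x88;      // -> 0xA3
        int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT; // low nibble:  0x3 - 8 = -5
        int x1 = (byte >> 4)   - Q4_0_ZERO_POINT; // high nibble: 0xA - 8 =  2

        assert(x0 == -5 && x1 == 2);
        // Multiplying by the per-block fp16 scale then yields the dequantized floats.
    }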
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
#if defined(__ARM_FEATURE_SME)
    {
@@ -63,8 +148,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
            /* .pack_func   = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
        },
        /* .rhs_info = */ {
            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
            /* .to_float      = */ dequantize_row_qsi4c32ps1s0scalef16,
        },
        /* .required_cpu = */ CPU_FEATURE_SME,
        /* .lhs_type     = */ GGML_TYPE_F32,
@@ -107,8 +194,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
            /* .pack_func   = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
        },
        /* .rhs_info = */ {
            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
            /* .pack_func   = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
            /* .packed_stride = */ NULL,
            /* .pack_func     = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
            /* .to_float      = */ NULL,
        },
        /* .required_cpu = */ CPU_FEATURE_SME,
        /* .lhs_type     = */ GGML_TYPE_F32,
@@ -154,8 +243,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
            /* .pack_func   = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
        },
        /* .rhs_info = */ {
            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
        },
        /* .required_cpu = */ CPU_FEATURE_DOTPROD,
        /* .lhs_type     = */ GGML_TYPE_F32,
@@ -200,8 +291,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
            /* .pack_func   = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
        },
        /* .rhs_info = */ {
            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
        },
        /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
        /* .lhs_type     = */ GGML_TYPE_F32,
@@ -247,8 +340,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
            /* .pack_func   = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
        },
        /* .rhs_info = */ {
            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
        },
        /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
        /* .lhs_type     = */ GGML_TYPE_F32,
@@ -293,8 +388,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
            /* .pack_func   = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
        },
        /* .rhs_info = */ {
            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
        },
        /* .required_cpu = */ CPU_FEATURE_DOTPROD,
        /* .lhs_type     = */ GGML_TYPE_F32,
@@ -71,12 +71,15 @@ struct rhs_packing_info {
        std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
        std::function<size_t(size_t n, size_t k)>
    > packed_size;
    size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);
    std::variant<
        std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
                           const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
        std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
                           const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
    > pack_func;
    void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride,
                     size_t kr, size_t bl, size_t num_bytes_multiplier);
};

struct ggml_kleidiai_kernels {

@@ -40,6 +40,17 @@ struct ggml_kleidiai_context {
    ggml_kleidiai_kernels * kernels;
} static ctx = { CPU_FEATURE_NONE, NULL };

static const char* cpu_feature_to_string(cpu_feature f) {
    switch (f) {
        case CPU_FEATURE_NONE:    return "NONE";
        case CPU_FEATURE_DOTPROD: return "DOTPROD";
        case CPU_FEATURE_I8MM:    return "I8MM";
        case CPU_FEATURE_SVE:     return "SVE";
        case CPU_FEATURE_SME:     return "SME";
        default:                  return "UNKNOWN";
    }
}

static void init_kleidiai_context(void) {

    ggml_critical_section_start();
@@ -62,6 +73,11 @@ static void init_kleidiai_context(void) {
        ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
        }
        ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
#ifndef NDEBUG
        if (ctx.kernels) {
            GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
        }
#endif
    }
    ggml_critical_section_end();
}
@@ -102,6 +118,9 @@ static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint1

class tensor_traits : public ggml::cpu::tensor_traits {
    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
        if (op->op != GGML_OP_MUL_MAT) {
            return false;
        }
        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
        GGML_ASSERT(kernels);
        kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
@@ -135,6 +154,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
            } else if (dst->src[0]->type == GGML_TYPE_F16) {
                return compute_forward_kv_cache(params, dst);
            }
        } else if (dst->op == GGML_OP_GET_ROWS) {
            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
                return compute_forward_get_rows(params, dst);
            }
        }
        return false;
    }
@@ -270,6 +293,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
    }

    bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);

        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

@@ -342,8 +367,49 @@ class tensor_traits : public ggml::cpu::tensor_traits {
        return true;
    }

    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
        GGML_ASSERT(ctx.kernels);

        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
        kernel_info * kernel = &ctx.kernels->gemm;

        const int64_t nc = ne00;
        const int64_t nr = ggml_nelements(src1);

        const size_t block_rows = kernel->get_nr();
        const size_t kr = kernel->get_kr();

        const size_t num_bytes_multiplier = sizeof(uint16_t);
        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);

        const int ith = params->ith;
        const int nth = params->nth;

        const int dr  = (nr + nth - 1) / nth;
        const int ir0 = dr * ith;
        const int ir1 = MIN(ir0 + dr, nr);

        for (int64_t i = ir0; i < ir1; ++i) {
            GGML_ASSERT(src1->type == GGML_TYPE_I32);
            int64_t row_idx = ((const int32_t *)src1->data)[i];
            GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);

            float *out = (float *)((char *)dst->data + i * nb1);
            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
        }

        return true;
    }
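Aside: the row loop above is split across threads with ggml's usual ceil-divide partition. A self-contained illustration of the dr/ir0/ir1 arithmetic (plain C++; values illustrative):

    #include <algorithm>
    #include <cassert>

    int main() {
        const int nr  = 10; // rows to process
        const int nth = 4;  // threads

        int covered = 0;
        for (int ith = 0; ith < nth; ++ith) {
            const int dr  = (nr + nth - 1) / nth;   // ceil(nr / nth) rows per thread
            const int ir0 = dr * ith;               // first row for this thread
            const int ir1 = std::min(ir0 + dr, nr); // one past the last row
            covered += std::max(0, ir1 - ir0);
        }
        assert(covered == nr); // every row is handled exactly once
    }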
public:
    int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
        GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
        GGML_ASSERT(ctx.kernels);
        const size_t n = tensor->ne[1];
        const size_t k = tensor->ne[0];
@@ -351,17 +417,12 @@ public:
        size_t kr = ctx.kernels->gemm.get_kr();
        size_t sr = ctx.kernels->gemm.get_sr();

#ifndef NDEBUG
        const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
        GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
#endif
        struct kai_rhs_pack_qs4cxs1s0_param params;
        params.lhs_zero_point = 1;
        params.rhs_zero_point = 8;
        variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);

        return 0;

        GGML_UNUSED(data_size);
    }
};
@@ -375,8 +436,8 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
    tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);

    GGML_UNUSED(buffer);
    return GGML_STATUS_SUCCESS;
    GGML_UNUSED(buffer);
}

static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
@@ -418,18 +479,35 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
    GGML_UNUSED(buft);
}

static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
    GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
    GGML_ASSERT(ctx.kernels);

    const size_t n  = tensor->ne[1];
    const size_t k  = tensor->ne[0];
    const size_t nr = ctx.kernels->gemm.get_nr();
    const size_t kr = ctx.kernels->gemm.get_kr();

    return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);

    GGML_UNUSED(buft);
}

namespace ggml::cpu::kleidiai {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
        if (op->op == GGML_OP_MUL_MAT &&
        if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
            op->src[0]->type == GGML_TYPE_Q4_0 &&
            op->src[0]->buffer &&
            (ggml_n_dims(op->src[0]) == 2) &&
            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
            if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
                return false;
            }
            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                return false;
            }
            if (op->src[1]->type == GGML_TYPE_F32 &&
            if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
                ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
                return true;
            }
@@ -438,7 +516,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
    }

    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
        if (op->op == GGML_OP_MUL_MAT) {
        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
            if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                return (ggml::cpu::tensor_traits *) op->src[0]->extra;
            }
@@ -469,7 +547,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
            /* .alloc_buffer   = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
            /* .get_alignment  = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
            /* .get_max_size   = */ nullptr, // defaults to SIZE_MAX
            /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
            /* .get_alloc_size = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
            /* .is_host        = */ nullptr,
        },
        /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
@@ -4015,6 +4015,9 @@ static void ggml_compute_forward_rms_norm_f32(

                const float scale = 1.0f/sqrtf(mean + eps);

                // if you hit this, likely you got an inf somewhere earlier
                assert(scale > 0.0f);

                ggml_vec_scale_f32(ne00, y, scale);
            }
        }
@@ -4643,9 +4646,11 @@ static void ggml_compute_forward_scale_f32(
    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_are_same_shape(src0, dst));

    // scale factor
    float v;
    memcpy(&v, dst->op_params, sizeof(float));
    float s; // scale factor
    float b; // bias

    memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
    memcpy(&b, (float *) dst->op_params + 1, sizeof(float));

    const int ith = params->ith;
    const int nth = params->nth;
@@ -4664,12 +4669,22 @@ static void ggml_compute_forward_scale_f32(

    const size_t nb1 = dst->nb[1];

    for (int i1 = ir0; i1 < ir1; i1++) {
        if (dst->data != src0->data) {
            // src0 is same shape as dst => same indices
            memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
    if (b == 0.0f) {
        for (int i1 = ir0; i1 < ir1; i1++) {
            if (dst->data != src0->data) {
                // src0 is same shape as dst => same indices
                // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
                memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
            }
            ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
        }
    } else {
        for (int i1 = ir0; i1 < ir1; i1++) {
            ggml_vec_mad1_f32(nc,
                    (float *) ((char *) dst->data + i1*nb1),
                    (float *) ((char *) src0->data + i1*nb1),
                    s, b);
        }
        ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
    }
}
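Aside: the reworked scale op computes dst = src0*s + b, keeping the memcpy+scale fast path when b is zero. A scalar reference of the semantics (self-contained C++; illustrative):

    #include <vector>
    #include <cassert>

    // Scalar reference for GGML_OP_SCALE with bias: dst[i] = src[i]*s + b.
    static void scale_ref(const std::vector<float>& src, std::vector<float>& dst, float s, float b) {
        for (size_t i = 0; i < src.size(); ++i) {
            dst[i] = src[i]*s + b; // b == 0.0f reduces to the plain scale path
        }
    }

    int main() {
        std::vector<float> x{1.0f, 2.0f}, y(2);
        scale_ref(x, y, 2.0f, 0.5f);
        assert(y[0] == 2.5f && y[1] == 4.5f);
    }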
@@ -14,7 +14,6 @@
#include <cmath>
#include <cstring>
#include <cassert>
#include <cstdlib> // for qsort
#include <cstdio>  // for GGML_ASSERT

#include "repack.h"

@@ -221,6 +221,9 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
    for (int i = np; i < n; ++i) {
        sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
    }

    // if you hit this, you are likely running outside the FP range
    assert(!isnan(sumf) && !isinf(sumf));
#else
    for (int i = 0; i < n; ++i) {
        sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
@@ -351,6 +351,45 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int
#endif
}

inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
#if defined(GGML_USE_ACCELERATE)
    vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
#elif defined(GGML_SIMD)
    #if defined(__ARM_FEATURE_SVE)
        // scalar; TODO: write SVE code
        for (int i = 0; i < n; ++i) {
            y[i] = x[i]*s + b;
        }
    #else
        const int np = (n & ~(GGML_F32_STEP - 1));

        GGML_F32_VEC vs = GGML_F32_VEC_SET1(s);
        GGML_F32_VEC vb = GGML_F32_VEC_SET1(b);

        GGML_F32_VEC ay[GGML_F32_ARR];

        for (int i = 0; i < np; i += GGML_F32_STEP) {
            for (int j = 0; j < GGML_F32_ARR; j++) {
                ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
                ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb);

                GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
            }
        }

        // leftovers
        for (int i = np; i < n; ++i) {
            y[i] = x[i]*s + b;
        }
    #endif
#else
    // scalar
    for (int i = 0; i < n; ++i) {
        y[i] = x[i]*s + b;
    }
#endif
}

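Aside: the SIMD body handles np = n & ~(GGML_F32_STEP - 1) elements vectorized and finishes the tail with scalar code. A quick illustration of that rounding mask (plain C++; the step value is illustrative):

    #include <cassert>

    int main() {
        const int step = 16;              // stand-in for GGML_F32_STEP
        const int n    = 37;
        const int np   = n & ~(step - 1); // largest multiple of step <= n
        assert(np == 32);                 // 32 lanes vectorized, 5 leftovers scalar
    }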
//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
#if defined(GGML_USE_ACCELERATE)

@@ -102,12 +102,12 @@ if (CUDAToolkit_FOUND)
    if (GGML_STATIC)
        if (WIN32)
            # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
            target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
            target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
        else ()
            target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
            target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static)
        endif()
    else()
        target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt)
        target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
    endif()

    if (GGML_CUDA_NO_VMM)
@@ -56,7 +56,7 @@
#define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 0x803)  // Tonga, Fiji, Polaris, minimum for fast fp16
#define GGML_CUDA_CC_VEGA       (GGML_CUDA_CC_OFFSET_AMD + 0x900)  // Vega56/64, minimum for fp16 dual issue
#define GGML_CUDA_CC_VEGA20     (GGML_CUDA_CC_OFFSET_AMD + 0x906)  // MI50/Radeon VII, minimum for dp4a
#define GGML_CUDA_CC_CDNA       (GGML_CUDA_CC_OFFSET_AMD + 0x908)  // MI100, minimum for MFMA, acc registers
#define GGML_CUDA_CC_CDNA1      (GGML_CUDA_CC_OFFSET_AMD + 0x908)  // MI100, minimum for MFMA, acc registers
#define GGML_CUDA_CC_CDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x910)  // MI210, minimum acc register renaming
#define GGML_CUDA_CC_CDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x942)  // MI300

@@ -72,8 +72,9 @@
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_GCN(cc)   (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
#define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_GCN(cc)   (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
#define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)

// Moore Threads
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
@@ -226,6 +227,10 @@ typedef float2 dfloat2;
#define FP16_MMA_AVAILABLE
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))

#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3)
#define AMD_MFMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3)

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define NEW_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -288,6 +293,11 @@ static bool fp32_mma_hardware_available(const int cc) {
    return GGML_CUDA_CC_IS_CDNA(cc);
}

// AMD CDNA3 matrix cores. Will add support for other CDNA generations later.
static bool amd_mfma_available(const int cc) {
    return cc >= GGML_CUDA_CC_OFFSET_AMD && GGML_CUDA_CC_IS_CDNA3(cc);
}

// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static bool new_mma_available(const int cc) {
    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
@@ -765,7 +775,7 @@ struct ggml_tensor_extra_gpu {
};


#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)) || defined(GGML_MUSA_GRAPHS)
#define USE_CUDA_GRAPH
#endif

@@ -6,24 +6,33 @@
#define CUDA_Q8_0_NE_ALIGN 2048

template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
    const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
        const int64_t ne00, const int64_t ne01, const int64_t ne02,
        const int64_t s01, const int64_t s02, const int64_t s03) {
    const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);

    if (i >= k) {
    if (i00 >= ne00) {
        return;
    }

    const int64_t ib = i/qk; // block index
    const int64_t iqs = (i%qk)/qr; // quant index
    const int64_t iybs = i - i%qk; // y block start index
    const int64_t i01 = blockIdx.y;
    const int64_t i02 = blockIdx.z % ne02;
    const int64_t i03 = blockIdx.z / ne02;

    const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;

    const int64_t ib = ibx0 + i00/qk; // block index
    const int64_t iqs = (i00%qk)/qr; // quant index
    const int64_t iybs = i00 - i00%qk; // y block start index
    const int64_t y_offset = qr == 1 ? 1 : qk/2;

    // dequantize
    dfloat2 v;
    dequantize_kernel(vx, ib, iqs, v);

    y[iybs + iqs + 0] = v.x;
    y[iybs + iqs + y_offset] = v.y;
    const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
    y[iy0 + 0] = float(v.x);
    y[iy0 + y_offset] = float(v.y);
}

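Aside: each thread above dequantizes a pair of values; the grid's y/z dimensions carry the row and batch indices, and s01/s02/s03 are strides measured in quant blocks. A host-side check of the per-row index arithmetic for qk=32, qr=2 (plain C++; values illustrative):

    #include <cstdint>
    #include <cassert>

    int main() {
        const int64_t qk = 32, qr = 2; // e.g. Q4_0-style blocks
        const int64_t i00 = 70;        // element index within the row

        const int64_t ib   = i00 / qk;        // quant block 2
        const int64_t iqs  = (i00 % qk) / qr; // quant index 3 within the block
        const int64_t iybs = i00 - i00 % qk;  // output block start 64
        const int64_t y_offset = qr == 1 ? 1 : qk/2;

        assert(ib == 2 && iqs == 3 && iybs == 64 && y_offset == 16);
        // The two dequantized values land at iybs+iqs and iybs+iqs+y_offset.
    }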
template <bool need_check>
@@ -457,9 +466,17 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
    const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
static void dequantize_block_cuda(const void * vx, dst_t * y,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
    const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
        (vx, y, ne00, ne01, ne02, s01, s02, s03);
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cont_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
    dequantize_block_cuda<qk, qr, dequantize_kernel, dst_t>(vx, y, k, 1, 1, 1, k/qk, k/qk, k/qk, stream);
}

static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
@@ -624,14 +641,14 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
        case GGML_TYPE_Q4_1:
            return dequantize_row_q4_1_cuda;
        case GGML_TYPE_Q5_0:
            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
            return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1:
            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
            return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0:
            if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
                return dequantize_block_q8_0_f16_cuda;
            }
            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
            return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_Q2_K:
            return dequantize_row_q2_K_cuda;
        case GGML_TYPE_Q3_K:
@@ -676,11 +693,11 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
        case GGML_TYPE_Q4_1:
            return dequantize_row_q4_1_cuda;
        case GGML_TYPE_Q5_0:
            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
            return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1:
            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
            return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0:
            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
            return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_Q2_K:
            return dequantize_row_q2_K_cuda;
        case GGML_TYPE_Q3_K:
@@ -722,6 +739,16 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32:
            return convert_unary_cuda<float>;
        case GGML_TYPE_Q4_0:
            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
        case GGML_TYPE_Q4_1:
            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
        case GGML_TYPE_Q5_0:
            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1:
            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0:
            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_BF16:
            return convert_unary_cuda<nv_bfloat16>;
        default:
@@ -733,6 +760,16 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32:
            return convert_unary_cuda<float, nv_bfloat16>;
        case GGML_TYPE_Q4_0:
            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
        case GGML_TYPE_Q4_1:
            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
        case GGML_TYPE_Q5_0:
            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1:
            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0:
            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_F16:
            return convert_unary_cuda<half, nv_bfloat16>;
        default:
@@ -744,6 +781,16 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F16:
            return convert_unary_cuda<half, float>;
        case GGML_TYPE_Q4_0:
            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
        case GGML_TYPE_Q4_1:
            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
        case GGML_TYPE_Q5_0:
            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1:
            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0:
            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_BF16:
            return convert_unary_cuda<nv_bfloat16, float>;
        default:

ggml/src/ggml-cuda/cpy-utils.cuh (new file, 225 lines)
@@ -0,0 +1,225 @@
#pragma once

#include "ggml-common.h"

template<typename src_t, typename dst_t>
static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
    if constexpr (std::is_same_v<src_t, dst_t>) {
        *dst = *src;
    } else {
        *dst = float(*src);
    }
}

static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
    if (x <= val[0]) return 0;
    if (x >= val[n-1]) return n-1;
    int ml = 0, mu = n-1;
    while (mu-ml > 1) {
        int mav = (ml+mu)/2;
        if (x < val[mav]) mu = mav; else ml = mav;
    }
    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}

static __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) {
|
||||
float amax = 0.0f;
|
||||
float vmax = 0.0f;
|
||||
|
||||
for (int j = 0; j < QK4_0; ++j) {
|
||||
const float v = x[j];
|
||||
if (amax < fabsf(v)) {
|
||||
amax = fabsf(v);
|
||||
vmax = v;
|
||||
}
|
||||
}
|
||||
|
||||
const float d = vmax / -8;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
y->d = d;
|
||||
|
||||
for (int j = 0; j < QK4_0/2; ++j) {
|
||||
const float x0 = x[0 + j]*id;
|
||||
const float x1 = x[QK4_0/2 + j]*id;
|
||||
|
||||
const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
|
||||
const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
|
||||
|
||||
y->qs[j] = xi0;
|
||||
y->qs[j] |= xi1 << 4;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) {
|
||||
float vmin = FLT_MAX;
|
||||
float vmax = -FLT_MAX;
|
||||
|
||||
for (int j = 0; j < QK4_1; ++j) {
|
||||
const float v = x[j];
|
||||
if (v < vmin) vmin = v;
|
||||
if (v > vmax) vmax = v;
|
||||
}
|
||||
|
||||
const float d = (vmax - vmin) / ((1 << 4) - 1);
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
y->dm.x = d;
|
||||
y->dm.y = vmin;
|
||||
|
||||
for (int j = 0; j < QK4_1/2; ++j) {
|
||||
const float x0 = (x[0 + j] - vmin)*id;
|
||||
const float x1 = (x[QK4_1/2 + j] - vmin)*id;
|
||||
|
||||
const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
|
||||
const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
|
||||
|
||||
y->qs[j] = xi0;
|
||||
y->qs[j] |= xi1 << 4;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) {
|
||||
float amax = 0.0f;
|
||||
float vmax = 0.0f;
|
||||
|
||||
for (int j = 0; j < QK5_0; ++j) {
|
||||
const float v = x[j];
|
||||
if (amax < fabsf(v)) {
|
||||
amax = fabsf(v);
|
||||
vmax = v;
|
||||
}
|
||||
}
|
||||
|
||||
const float d = vmax / -16;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
y->d = d;
|
||||
|
||||
uint32_t qh = 0;
|
||||
for (int j = 0; j < QK5_0/2; ++j) {
|
||||
const float x0 = x[0 + j]*id;
|
||||
const float x1 = x[QK5_0/2 + j]*id;
|
||||
|
||||
const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
|
||||
const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
|
||||
|
||||
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
||||
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
||||
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
|
||||
}
|
||||
memcpy(y->qh, &qh, sizeof(qh));
|
||||
}
|
||||
|
||||
static __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) {
|
||||
float min = x[0];
|
||||
float max = x[0];
|
||||
|
||||
for (int j = 1; j < QK5_1; ++j) {
|
||||
const float v = x[j];
|
||||
min = v < min ? v : min;
|
||||
max = v > max ? v : max;
|
||||
}
|
||||
|
||||
const float d = (max - min) / 31;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
y->dm.x = d;
|
||||
y->dm.y = min;
|
||||
|
||||
uint32_t qh = 0;
|
||||
for (int j = 0; j < QK5_1/2; ++j) {
|
||||
const float x0 = (x[0 + j] - min)*id;
|
||||
const float x1 = (x[QK5_1/2 + j] - min)*id;
|
||||
|
||||
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
|
||||
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
|
||||
|
||||
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
||||
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
||||
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
|
||||
}
|
||||
memcpy(y->qh, &qh, sizeof(qh));
|
||||
}
|
||||
|
||||
static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) {
|
||||
float amax = 0.0f; // absolute max
|
||||
|
||||
for (int j = 0; j < QK8_0; j++) {
|
||||
const float v = x[j];
|
||||
amax = fmaxf(amax, fabsf(v));
|
||||
}
|
||||
|
||||
const float d = amax / ((1 << 7) - 1);
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
y->d = d;
|
||||
|
||||
for (int j = 0; j < QK8_0; ++j) {
|
||||
const float x0 = x[j]*id;
|
||||
y->qs[j] = roundf(x0);
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) {
|
||||
float amax = 0.0f;
|
||||
float vmax = 0.0f;
|
||||
|
||||
for (int j = 0; j < QK4_NL; ++j) {
|
||||
const float v = x[j];
|
||||
if (amax < fabsf(v)) {
|
||||
amax = fabsf(v);
|
||||
vmax = v;
|
||||
}
|
||||
}
|
||||
|
||||
float d = vmax / kvalues_iq4nl[0];
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
float sumqx = 0, sumq2 = 0;
|
||||
for (int j = 0; j < QK4_NL/2; ++j) {
|
||||
const float x0 = x[0 + j]*id;
|
||||
const float x1 = x[QK4_NL/2 + j]*id;
|
||||
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
|
||||
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
|
||||
y->qs[j] = xi0 | (xi1 << 4);
|
||||
const float v0 = kvalues_iq4nl[xi0];
|
||||
const float v1 = kvalues_iq4nl[xi1];
|
||||
const float w0 = x[0 + j]*x[0 + j];
|
||||
const float w1 = x[QK4_NL/2 + j]*x[QK4_NL/2 + j];
|
||||
sumqx += w0*v0*x[j] + w1*v1*x[QK4_NL/2 + j];
|
||||
sumq2 += w0*v0*v0 + w1*v1*v1;
|
||||
}
|
||||
|
||||
y->d = sumq2 > 0 ? sumqx/sumq2 : d;
|
||||
}
|
||||
|
||||
// Wrapper functions for cpy.cu compatibility
|
||||
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
|
||||
quantize_f32_q4_0_block((const float *)cxi, (block_q4_0 *)cdsti);
|
||||
}
|
||||
|
||||
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
|
||||
quantize_f32_q4_1_block((const float *)cxi, (block_q4_1 *)cdsti);
|
||||
}
|
||||
|
||||
static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
|
||||
quantize_f32_q5_0_block((const float *)cxi, (block_q5_0 *)cdsti);
|
||||
}
|
||||
|
||||
static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
|
||||
quantize_f32_q5_1_block((const float *)cxi, (block_q5_1 *)cdsti);
|
||||
}
|
||||
|
||||
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
||||
quantize_f32_q8_0_block((const float *)cxi, (block_q8_0 *)cdsti);
|
||||
}
|
||||
|
||||
static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
|
||||
quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
|
||||
convert_flt((const src_t *)cxi, (dst_t *)cdsti);
|
||||
}
|
||||
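
All of the quantizers above share one shape: pick a scale over the block, then round each value to its code. A host-side reference for the Q8_0 case, usable as a mental model or unit-test oracle (a sketch, not part of the diff):

    // Reference Q8_0 quantization of one 32-value block (host C++).
    #include <math.h>
    #include <stdint.h>

    static void q8_0_reference(const float * x, int8_t * qs, float * d_out) {
        float amax = 0.0f;                    // absolute max of the block
        for (int j = 0; j < 32; ++j) {
            amax = fmaxf(amax, fabsf(x[j]));
        }
        const float d  = amax / 127.0f;       // (1 << 7) - 1, as in the kernel
        const float id = d ? 1.0f/d : 0.0f;
        for (int j = 0; j < 32; ++j) {
            qs[j] = (int8_t) roundf(x[j]*id); // dequantizes back as qs[j]*d
        }
        *d_out = d;
    }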
@@ -1,51 +1,17 @@
#include "cpy.cuh"
#include "dequantize.cuh"
#ifdef GGML_USE_MUSA
#include "cpy-utils.cuh"
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
#include "ggml-musa/mudnn.cuh"
#endif // GGML_USE_MUSA
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY

typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    float * dsti = (float *) cdsti;

    *dsti = *xi;
}

static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti;

    *dsti = *xi;
}

static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    half * dsti = (half *) cdsti;

    *dsti = __float2half(*xi);
}

static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
    const half * xi = (const half *) cxi;
    half * dsti = (half *) cdsti;

    *dsti = *xi;
}

static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
    const half * xi = (const half *) cxi;
    float * dsti = (float *) cdsti;

    *dsti = *xi;
}

template <cpy_kernel_t cpy_1>
static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
    const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
    const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= ne) {
@@ -71,29 +37,6 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in
    cpy_1(cx + x_offset, cdst + dst_offset);
}

static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    block_q8_0 * dsti = (block_q8_0 *) cdsti;

    float amax = 0.0f; // absolute max

    for (int j = 0; j < QK8_0; j++) {
        const float v = xi[j];
        amax = fmaxf(amax, fabsf(v));
    }

    const float d  = amax / ((1 << 7) - 1);
    const float id = d ? 1.0f/d : 0.0f;

    dsti->d = d;

    for (int j = 0; j < QK8_0; ++j) {
        const float x0 = xi[j]*id;

        dsti->qs[j] = roundf(x0);
    }
}

static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
    float * cdstf = (float *)(cdsti);

@@ -106,139 +49,6 @@ static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
    }
}

static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    block_q4_0 * dsti = (block_q4_0 *) cdsti;

    float amax = 0.0f;
    float vmax = 0.0f;

    for (int j = 0; j < QK4_0; ++j) {
        const float v = xi[j];
        if (amax < fabsf(v)) {
            amax = fabsf(v);
            vmax = v;
        }
    }

    const float d  = vmax / -8;
    const float id = d ? 1.0f/d : 0.0f;

    dsti->d = d;

    for (int j = 0; j < QK4_0/2; ++j) {
        const float x0 = xi[0       + j]*id;
        const float x1 = xi[QK4_0/2 + j]*id;

        const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
        const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));

        dsti->qs[j]  = xi0;
        dsti->qs[j] |= xi1 << 4;
    }
}

static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    block_q4_1 * dsti = (block_q4_1 *) cdsti;

    float vmin = FLT_MAX;
    float vmax = -FLT_MAX;

    for (int j = 0; j < QK4_1; ++j) {
        const float v = xi[j];

        if (v < vmin) vmin = v;
        if (v > vmax) vmax = v;
    }

    const float d  = (vmax - vmin) / ((1 << 4) - 1);
    const float id = d ? 1.0f/d : 0.0f;

    dsti->dm.x = d;
    dsti->dm.y = vmin;

    for (int j = 0; j < QK4_1/2; ++j) {
        const float x0 = (xi[0       + j] - vmin)*id;
        const float x1 = (xi[QK4_1/2 + j] - vmin)*id;

        const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
        const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));

        dsti->qs[j]  = xi0;
        dsti->qs[j] |= xi1 << 4;
    }
}

static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    block_q5_0 * dsti = (block_q5_0 *) cdsti;

    float amax = 0.0f;
    float vmax = 0.0f;

    for (int j = 0; j < QK5_0; ++j) {
        const float v = xi[j];
        if (amax < fabsf(v)) {
            amax = fabsf(v);
            vmax = v;
        }
    }

    const float d  = vmax / -16;
    const float id = d ? 1.0f/d : 0.0f;

    dsti->d = d;

    uint32_t qh = 0;
    for (int j = 0; j < QK5_0/2; ++j) {
        const float x0 = xi[0       + j]*id;
        const float x1 = xi[QK5_0/2 + j]*id;

        const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
        const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));

        dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
    }
    memcpy(dsti->qh, &qh, sizeof(qh));
}

static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    block_q5_1 * dsti = (block_q5_1 *) cdsti;

    float min = xi[0];
    float max = xi[0];

    for (int j = 1; j < QK5_1; ++j) {
        const float v = xi[j];
        min = v < min ? v : min;
        max = v > max ? v : max;
    }

    const float d  = (max - min) / 31;
    const float id = d ? 1.0f/d : 0.0f;

    dsti->dm.x = d;
    dsti->dm.y = min;

    uint32_t qh = 0;
    for (int j = 0; j < QK5_1/2; ++j) {
        const float x0 = (xi[0       + j] - min)*id;
        const float x1 = (xi[QK5_1/2 + j] - min)*id;

        const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
        const uint8_t xi1 = (uint8_t)(x1 + 0.5f);

        dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
    }
    memcpy(dsti->qh, &qh, sizeof(qh));
}

template<dequantize_kernel_t dequant, int qk>
static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
    float * cdstf = (float *)(cdsti);
@@ -252,53 +62,6 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
    }
}

static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
    if (x <= val[0]) return 0;
    if (x >= val[n-1]) return n-1;
    int ml = 0, mu = n-1;
    while (mu-ml > 1) {
        int mav = (ml+mu)/2;
        if (x < val[mav]) mu = mav; else ml = mav;
    }
    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}

static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    block_iq4_nl * dsti = (block_iq4_nl *) cdsti;

    float amax = 0.0f;
    float vmax = 0.0f;

    for (int j = 0; j < QK4_NL; ++j) {
        const float v = xi[j];
        if (amax < fabsf(v)) {
            amax = fabsf(v);
            vmax = v;
        }
    }

    float d = vmax / kvalues_iq4nl[0];
    const float id = d ? 1.0f/d : 0.0f;

    float sumqx = 0, sumq2 = 0;
    for (int j = 0; j < QK4_NL/2; ++j) {
        const float x0 = xi[0        + j]*id;
        const float x1 = xi[QK4_NL/2 + j]*id;
        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
        dsti->qs[j] = xi0 | (xi1 << 4);
        const float v0 = kvalues_iq4nl[xi0];
        const float v1 = kvalues_iq4nl[xi1];
        const float w0 = xi[0        + j]*xi[0        + j];
        const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j];
        sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j];
        sumq2 += w0*v0*v0 + w1*v1*v1;
    }

    dsti->d = sumq2 > 0 ? sumqx/sumq2 : d;
}

template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -358,7 +121,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int
// Copy destination pointers to GPU to be available when pointer indirection is in use

void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
    if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
        CUDA_CHECK(cudaStreamSynchronize(stream));
        if (cuda_graph->dest_ptrs_d != nullptr) {
@@ -376,43 +139,14 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
#endif
}

static void ggml_cpy_f16_f32_cuda(
template<typename src_t, typename dst_t>
static void ggml_cpy_flt_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {

    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
    cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
}

static void ggml_cpy_f32_f32_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {

    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
    cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
}

static void ggml_cpy_f32_bf16_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {

    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
    cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
}

static void ggml_cpy_f32_f16_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {

    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
    cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
    cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
}
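
With the templated launcher above, each of the per-type entry points removed below reduces to a one-line instantiation. For example, the former F32->F16 copy is now launched as (arguments exactly as in ggml_cuda_cpy further down):

    ggml_cpy_flt_cuda<float, half>(src0_ddc, src1_ddc, ne,
        ne00, ne01, ne02, nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, nb10, nb11, nb12, nb13,
        main_stream, dest_ptrs_d, graph_cpynode_index);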
@@ -544,16 +278,6 @@ static void ggml_cpy_f32_iq4_nl_cuda(
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
}

static void ggml_cpy_f16_f16_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {

    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
}

void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
    const int64_t ne = ggml_nelements(src0);
    GGML_ASSERT(ne == ggml_nelements(src1));
@@ -590,7 +314,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg

    char ** dest_ptrs_d = nullptr;
    int graph_cpynode_index = -1;
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
        dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
        graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
@@ -600,20 +324,20 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
#endif
    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
        GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
#ifdef GGML_USE_MUSA
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
        if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
            CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));
        } else
#endif // GGML_USE_MUSA
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
        {
            CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
        ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
        ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -640,14 +364,22 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
        ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
        ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
        ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
        ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else {
        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                ggml_type_name(src0->type), ggml_type_name(src1->type));
    }
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
        ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
    }
@@ -667,11 +399,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
        return nullptr;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
        return (void*) cpy_f32_f16<cpy_1_f32_f32>;
        return (void*) cpy_flt<cpy_1_flt<float, float>>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
        return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
        return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
        return (void*) cpy_flt<cpy_1_flt<float, half>>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
        return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -695,9 +427,17 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
        return (void*) cpy_flt<cpy_1_flt<half, half>>;
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
        return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
        return (void*) cpy_f32_f16<cpy_1_f16_f32>;
        return (void*) cpy_flt<cpy_1_flt<half, float>>;
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
    } else {
        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                ggml_type_name(src0->type), ggml_type_name(src1->type));
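
ggml_cuda_cpy_fn exposes which kernel a given copy would launch without launching it. One plausible use, sketched with hypothetical caller state (prev_cpy_fn and graph_needs_recapture are illustrative names, not part of this diff), is checking whether a captured CUDA graph is still valid:

    void * ptr = ggml_cuda_cpy_fn(src0, src1);
    if (ptr != prev_cpy_fn) {       // kernel changed since the last capture
        prev_cpy_fn = ptr;
        graph_needs_recapture = true;
    }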
@@ -23,31 +23,13 @@ typedef void (* fattn_kernel_t)(
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
        const int ne00,
        const int ne01,
        const int ne02,
        const int ne03,
        const int ne10,
        const int ne11,
        const int ne12,
        const int ne13,
        const int ne31,
        const int ne32,
        const int nb31,
        const int nb32,
        const int nb01,
        const int nb02,
        const int nb03,
        const int nb11,
        const int nb12,
        const int nb13,
        const int nb21,
        const int nb22,
        const int nb23,
        const int ne0,
        const int ne1,
        const int ne2,
        const int ne3);
        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
        const int32_t nb01, const int32_t nb02, const int32_t nb03,
        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
        const int32_t nb11, const int32_t nb12, const int64_t nb13,
        const int32_t nb21, const int32_t nb22, const int64_t nb23,
        const int32_t ne31, const int32_t ne32, const int32_t ne33,
        const int32_t nb31, const int32_t nb32, const int64_t nb33);

typedef half (*vec_dot_KQ_f16_t)(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
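
Only the outermost byte strides (nb13, nb23, nb33) widen to int64_t in the new signature; the per-row strides stay 32-bit. A hypothetical KV-cache size showing why the widening matters:

    // head dim 128, 262144 cached tokens, 32 KV heads, F16 K cache:
    //   nb11 = 128*2       = 256 bytes             (fits easily in int32_t)
    //   nb12 = nb11*262144 = 67108864 bytes
    //   nb13 = nb12*32     = 2147483648 bytes = 2^31, just past INT32_MAX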
@@ -521,7 +503,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
template<int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
    constexpr int ncols = ncols1*ncols2;

    const int bidx0 = blockIdx.x;
@@ -535,8 +517,8 @@ static __global__ void flash_attn_stream_k_fixup(
    const int iter_k = ne11 / FATTN_KQ_STRIDE;
    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;

    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
    const int kbc0      = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
    const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;

    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -545,14 +527,15 @@ static __global__ void flash_attn_stream_k_fixup(
        return;
    }

    const int channel = kbc0 / (iter_k*iter_j);
    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
    const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
    const int head     = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
    const int jt       = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.

    if (jt*ncols1 + j >= ne01) {
        return;
    }

    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;

    // Load the partial result that needs a fixup:
    float dst_val = 0.0f;
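
The fixup now decodes its flat block index into (sequence, head, jt) rather than a single channel. A worked example of the decomposition above (sizes chosen only for illustration):

    // iter_k = 4, iter_j = 2, ne02/ncols2 = 3 head groups, ne03 = 2 sequences
    // -> 4*2*3 = 24 blocks per sequence. For kbc0 = 30:
    //   sequence = 30 / 24               = 1
    //   head     = (30 - 24*1) / (4*2)   = 0
    //   jt       = (30 - 24*1 - 8*0) / 4 = 1
    //   30 % iter_k = 2 is the position along the K dimension within the tile.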
@@ -571,7 +554,7 @@ static __global__ void flash_attn_stream_k_fixup(
    int bidx = bidx0 - 1;
    int kbc_stop = kbc0;
    while(true) {
        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
        const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
        if (kbc == kbc_stop) { // Did not have any data.
            bidx--;
            kbc_stop = kbc;
@@ -617,16 +600,31 @@ static __global__ void flash_attn_combine_results(
        const float2 * __restrict__ VKQ_meta,
        float * __restrict__ dst,
        const int parallel_blocks) {
    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
    VKQ_meta  += parallel_blocks   * gridDim.z*blockIdx.x;
    dst       += D                 * gridDim.z*blockIdx.x;
    // Dimension 0: threadIdx.x
    // Dimension 1: blockIdx.x
    // Dimension 2: blockIdx.y
    // Dimension 3: blockIdx.z
    // Memory layout is permuted with [0, 2, 1, 3]

    const int ne01 = gridDim.x;
    const int ne02 = gridDim.y;

    const int col      = blockIdx.x;
    const int head     = blockIdx.y;
    const int sequence = blockIdx.z;

    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;

    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
    VKQ_meta  += j_dst_unrolled * parallel_blocks;
    dst       += j_dst_unrolled * D;

    const int tid = threadIdx.x;
    __builtin_assume(tid < D);

    extern __shared__ float2 meta[];
    for (int i = tid; i < 2*parallel_blocks; i += D) {
        ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
        ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
    }

    __syncthreads();
@@ -644,11 +642,11 @@ static __global__ void flash_attn_combine_results(
        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
        *((uint32_t *) &KQ_max_scale) &= ftz_mask;

        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
        VKQ_denominator += KQ_max_scale * meta[l].y;
    }

    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
    dst[tid] = VKQ_numerator / VKQ_denominator;
}
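
A quick numeric check of the flattened index above (values are illustrative): with ne01 = 8 columns and ne02 = 4 heads, the element at sequence 1, col 3, head 2 lands at j_dst_unrolled = (1*8 + 3)*4 + 2 = 46. The head index varies fastest of the three, which is exactly the permuted [0, 2, 1, 3] layout the comments describe.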
[[noreturn]]
@@ -705,8 +703,6 @@ void launch_fattn(

    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");

    GGML_ASSERT(Q->ne[3] == 1);

    ggml_cuda_pool & pool = ctx.pool();
    cudaStream_t main_stream = ctx.stream();
    const int id = ggml_cuda_get_device();
@@ -729,33 +725,58 @@ void launch_fattn(
    size_t nb23 = V ? V->nb[3] : nb13;

    if (need_f16_K && K->type != GGML_TYPE_F16) {
        GGML_ASSERT(ggml_is_contiguously_allocated(K));
        K_f16.alloc(ggml_nelements(K));
        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
        to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
        K_data = (char *) K_f16.ptr;

        const size_t bs = ggml_blck_size(K->type);
        const size_t ts = ggml_type_size(K->type);

        nb11 = nb11*bs*sizeof(half)/ts;
        nb12 = nb12*bs*sizeof(half)/ts;
        nb13 = nb13*bs*sizeof(half)/ts;
        K_f16.alloc(ggml_nelements(K));
        if (ggml_is_contiguously_allocated(K)) {
            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
            to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);

            nb11 = nb11*bs*sizeof(half)/ts;
            nb12 = nb12*bs*sizeof(half)/ts;
            nb13 = nb13*bs*sizeof(half)/ts;
        } else {
            GGML_ASSERT(K->nb[0] == ts);
            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
            const int64_t s01 = nb11 / ts;
            const int64_t s02 = nb12 / ts;
            const int64_t s03 = nb13 / ts;
            to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);

            nb11 = K->ne[0] * sizeof(half);
            nb12 = K->ne[1] * nb11;
            nb13 = K->ne[2] * nb12;
        }
        K_data = (char *) K_f16.ptr;
    }

    if (V && need_f16_V && V->type != GGML_TYPE_F16) {
        GGML_ASSERT(ggml_is_contiguously_allocated(V));
        V_f16.alloc(ggml_nelements(V));
        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
        to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
        V_data = (char *) V_f16.ptr;

        const size_t bs = ggml_blck_size(V->type);
        const size_t ts = ggml_type_size(V->type);

        nb21 = nb21*bs*sizeof(half)/ts;
        nb22 = nb22*bs*sizeof(half)/ts;
        nb23 = nb23*bs*sizeof(half)/ts;
        V_f16.alloc(ggml_nelements(V));
        if (ggml_is_contiguously_allocated(V)) {
            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
            to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
            V_data = (char *) V_f16.ptr;

            nb21 = nb21*bs*sizeof(half)/ts;
            nb22 = nb22*bs*sizeof(half)/ts;
            nb23 = nb23*bs*sizeof(half)/ts;
        } else {
            GGML_ASSERT(V->nb[0] == ts);
            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
            const int64_t s01 = nb21 / ts;
            const int64_t s02 = nb22 / ts;
            const int64_t s03 = nb23 / ts;
            to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);

            nb21 = V->ne[0] * sizeof(half);
            nb22 = V->ne[1] * nb21;
            nb23 = V->ne[2] * nb22;
        }
        V_data = (char *) V_f16.ptr;
    }

    int parallel_blocks = 1;
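
The nb = nb*bs*sizeof(half)/ts rescaling converts byte strides from quantized blocks to halves. Worked numbers for Q8_0 (bs = QK8_0 = 32 values per block, ts = sizeof(block_q8_0) = 34 bytes: a 2-byte fp16 scale plus 32 int8 codes):

    // A contiguous row of 128 values = 4 blocks:
    //   nb11 (quantized) = 4*34        = 136 bytes
    //   nb11 (as half)   = 136*32*2/34 = 256 bytes = 128*sizeof(half)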
@@ -851,14 +872,11 @@ void launch_fattn(
        mask ? ((const char *) mask->data) : nullptr,
        !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
        scale, max_bias, m0, m1, n_head_log2, logit_softcap,
        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
        Q->nb[1], Q->nb[2], Q->nb[3],
        nb11, nb12, nb13,
        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
        K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
        nb21, nb22, nb23,
        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
    );
    CUDA_CHECK(cudaGetLastError());

@@ -869,11 +887,11 @@ void launch_fattn(

            flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
        }
    } else if (parallel_blocks > 1) {
        const dim3 block_dim_combine(DV, 1, 1);
        const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
        const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
        const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);

        flash_attn_combine_results<DV>

@@ -408,7 +408,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
        const int stride_K,
        const int stride_V,
        const int stride_mask,
        const int jt,
        half2       * const __restrict__ tile_Q,
        half2       * const __restrict__ tile_K,
        half2       * const __restrict__ tile_V,
@@ -455,7 +454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
            cp_async_wait_all();
            __syncthreads();
            flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
                (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
                (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
        } else {
            constexpr bool use_cp_async = nstages == 1;
            if (ncols2 > 1 || mask_h2) {
@@ -471,7 +470,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
        if (nstages <= 1) {
            constexpr bool use_cp_async = nstages == 1;
            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
                (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
                (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
            if (use_cp_async) {
                cp_async_wait_all();
            }
@@ -715,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                    (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
            }
            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
                (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
                (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
        }
    }
|
||||
if (nstages <= 1 && i0_start < reusable_cutoff) {
|
||||
constexpr bool use_cp_async = nstages == 1;
|
||||
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
||||
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
||||
(V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
||||
if (use_cp_async) {
|
||||
cp_async_wait_all();
|
||||
}
|
||||
@@ -771,8 +770,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
|
||||
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
|
||||
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
|
||||
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
|
||||
GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K);
|
||||
GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
|
||||
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
|
||||
GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
|
||||
@@ -920,7 +918,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
|
||||
}
|
||||
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
||||
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
|
||||
(K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
|
||||
}
|
||||
|
||||
// Iterate over ne11 == previous tokens:
|
||||
@@ -928,13 +926,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
constexpr bool last_iter = false;
|
||||
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
||||
}
|
||||
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
||||
constexpr bool last_iter = true;
|
||||
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
|
||||
}
|
||||
|
||||
// With multi-stage loading there is no __syncthreads at the end of the iter,
|
||||
@@ -1214,31 +1212,13 @@ static __global__ void flash_attn_ext_f16(
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
        const int ne00,
        const int ne01,
        const int ne02,
        const int ne03,
        const int ne10,
        const int ne11,
        const int ne12,
        const int ne13,
        const int ne31,
        const int ne32,
        const int nb31,
        const int nb32,
        const int nb01,
        const int nb02,
        const int nb03,
        const int nb11,
        const int nb12,
        const int nb13,
        const int nb21,
        const int nb22,
        const int nb23,
        const int ne0,
        const int ne1,
        const int ne2,
        const int ne3) {
        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
        const int32_t nb01, const int32_t nb02, const int32_t nb03,
        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
        const int32_t nb11, const int32_t nb12, const int64_t nb13,
        const int32_t nb21, const int32_t nb22, const int64_t nb23,
        const int32_t ne31, const int32_t ne32, const int32_t ne33,
        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)

    // Skip unused kernel variants for faster compilation:
@@ -1274,8 +1254,8 @@ static __global__ void flash_attn_ext_f16(
    constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.

    // kbc == k block continuous, current index in continuous ijk space.
    int       kbc      = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
    const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
    int       kbc      = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
    const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;

    // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
    // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1285,18 +1265,19 @@ static __global__ void flash_attn_ext_f16(
    int kb0_start = kbc % iter_k;
    int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
    while (kbc < kbc_stop && kb0_stop == iter_k) {
        const int channel = kbc / (iter_k*iter_j);
        const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
        const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
        const int head     = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
        const int jt       = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.

        const float2 * Q_f2  = (const float2 *) (Q + nb02* channel*ncols2);
        const half2  * K_h2  = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
        const float2 * Q_f2  = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
        const half2  * K_h2  = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
        const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
            (const half2  *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
        float2       * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
            (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
        float2       * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);

        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));

        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;

        const int kb0_start_kernel = kb0_start * kb_niter;
        const int kb0_stop_kernel  = kb0_stop  * kb_niter;
@@ -1325,18 +1306,19 @@ static __global__ void flash_attn_ext_f16(
        return;
    }

    const int channel = kbc / (iter_k*iter_j);
    const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
    const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
    const int head     = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
    const int jt       = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.

    const float2 * Q_f2  = (const float2 *) (Q + nb02* channel*ncols2);
    const half2  * K_h2  = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
    const float2 * Q_f2  = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
    const half2  * K_h2  = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
    const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
        (const half2  *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
    float2       * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
        (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
    float2       * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);

    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));

    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;

    const int kb0_start_kernel = kb0_start * kb_niter;
    const int kb0_stop_kernel  = kb0_stop  * kb_niter;
@@ -1348,15 +1330,16 @@ static __global__ void flash_attn_ext_f16(
        ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
    GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
    GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
    GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
    GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
    GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
    GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
    GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
    NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
}
@@ -21,31 +21,13 @@ static __global__ void flash_attn_tile_ext_f16(
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
        const int ne00,
        const int ne01,
        const int ne02,
        const int ne03,
        const int ne10,
        const int ne11,
        const int ne12,
        const int ne13,
        const int ne31,
        const int ne32,
        const int nb31,
        const int nb32,
        const int nb01,
        const int nb02,
        const int nb03,
        const int nb11,
        const int nb12,
        const int nb13,
        const int nb21,
        const int nb22,
        const int nb23,
        const int ne0,
        const int ne1,
        const int ne2,
        const int ne3) {
        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
        const int32_t nb01, const int32_t nb02, const int32_t nb03,
        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
        const int32_t nb11, const int32_t nb12, const int64_t nb13,
        const int32_t nb21, const int32_t nb22, const int64_t nb23,
        const int32_t ne31, const int32_t ne32, const int32_t ne33,
        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)

    // Skip unused kernel variants for faster compilation:
@@ -62,15 +44,17 @@ static __global__ void flash_attn_tile_ext_f16(

    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

    const int sequence = blockIdx.z / ne02;
    const int head = blockIdx.z - sequence*ne02;
    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.z              + nb01*ic0);
    const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.z / gqa_ratio));
    const half2  * V_h2  = (const half2  *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
    const half   * maskh = (const half   *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
    const float2 * Q_f2  = (const float2 *) (Q + nb03* sequence         + nb02* head              + nb01*ic0);
    const half2  * K_h2  = (const half2  *) (K + nb13* sequence         + nb12*(head / gqa_ratio));
    const half2  * V_h2  = (const half2  *) (V + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);

    const int stride_KV2 = nb11 / sizeof(half2);

    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
    const half  slopeh = __float2half(slopef);

    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
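
gqa_ratio folds several query heads onto one K/V head. A numeric illustration (the head counts are hypothetical):

    // ne02 = 32 query heads, ne12 = 8 KV heads -> gqa_ratio = 4.
    // Query heads 0..3 read KV head 0, 4..7 read KV head 1, ...:
    //   kv_head = head / gqa_ratio;   // e.g. head 13 -> KV head 3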
@@ -123,7 +107,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
|
||||
KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,7 +201,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
|
||||
KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,6 +239,8 @@ static __global__ void flash_attn_tile_ext_f16(
__syncthreads();
}

float2 * dst2 = (float2 *) dst;

#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -266,21 +252,21 @@ static __global__ void flash_attn_tile_ext_f16(
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
kqsum_j = warp_reduce_sum((float)kqsum_j);

#pragma unroll
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
const int i0 = i00 + 2*threadIdx.x;
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;

half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
#pragma unroll
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
const int i0 = i00 + threadIdx.x;

half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
if (gridDim.y == 1) {
dst_val /= __half2half2(kqsum_j);
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
}

if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else
@@ -290,12 +276,11 @@ static __global__ void flash_attn_tile_ext_f16(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
GGML_UNUSED(nb23);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}
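
The old epilogue addressed dst with two separate terms (j_dst plus a blockIdx.z offset); j_dst_unrolled folds sequence, Q row, head and split-K block into one linear index so dst and dst_meta share a single layout. A sketch of the index algebra (host-side restatement, names as in the kernels):

// Multiplying out reproduces the kernel expression
// ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y,
// i.e. row-major nesting over (sequence, Q row, head, split-K block).
static long long j_dst_unrolled(long long sequence, long long row /* ic0 + j_VKQ */,
                                long long head, long long yblk,
                                long long ne01, long long ne02, long long gridDim_y) {
    return ((sequence*ne01 + row)*ne02 + head)*gridDim_y + yblk;
}
// Each index owns D floats (or D/2 float2) of output; when gridDim.y > 1 the
// per-block partial results are combined afterwards via dst_meta[j_dst_unrolled].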
@@ -21,31 +21,13 @@ static __global__ void flash_attn_tile_ext_f32(
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int ne32,
const int nb31,
const int nb32,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#ifdef FLASH_ATTN_AVAILABLE

// Skip unused kernel variants for faster compilation:
@@ -55,17 +37,16 @@ static __global__ void flash_attn_tile_ext_f32(
#endif // FP16_MMA_AVAILABLE
if (use_logit_softcap && !(D == 128 || D == 256)) {
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
NO_DEVICE_CODE;
return;
}
@@ -74,15 +55,17 @@ static __global__ void flash_attn_tile_ext_f32(

const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);

const int stride_KV2 = nb11 / sizeof(half2);

const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);

static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");

@@ -131,7 +114,7 @@ static __global__ void flash_attn_tile_ext_f32(

#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
}
@@ -227,8 +210,9 @@ static __global__ void flash_attn_tile_ext_f32(
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;

KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
KV_tmp2[k*(D/2) + i].x = __low2float(tmp);
KV_tmp2[k*(D/2) + i].y = __high2float(tmp);
}
}

@@ -265,6 +249,8 @@ static __global__ void flash_attn_tile_ext_f32(
__syncthreads();
}

float2 * dst2 = (float2 *) dst;

#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -276,37 +262,36 @@ static __global__ void flash_attn_tile_ext_f32(
float kqsum_j = kqsum[j_VKQ_0/nwarps];
kqsum_j = warp_reduce_sum(kqsum_j);

#pragma unroll
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
const int i0 = i00 + 2*threadIdx.x;
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;

float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
#pragma unroll
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
const int i0 = i00 + threadIdx.x;

float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
if (gridDim.y == 1) {
dst_val.x /= kqsum_j;
dst_val.y /= kqsum_j;
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
}

if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32);
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
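
The slope argument changes from the flat blockIdx.z to the per-sequence head index, which matters once a batch holds more than one sequence. For reference, a paraphrase of what ggml's get_alibi_slope helper computes (the helper itself is outside this diff, so treat this as a sketch rather than the exact source):

#include <cmath>

// Heads below n_head_log2 use base m0 with exponent h+1; the remaining heads
// use base m1 with odd exponents, per the ALiBi paper's handling of
// non-power-of-two head counts.
static float alibi_slope_sketch(float max_bias, unsigned h, unsigned n_head_log2, float m0, float m1) {
    if (max_bias <= 0.0f) {
        return 1.0f; // no positional bias requested
    }
    return h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
}
// Passing the flat blockIdx.z here would hand sequence s the head index
// h + s*ne02, i.e. an ever steeper slope for later sequences; passing the
// per-sequence head index fixes that.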
@@ -18,31 +18,13 @@ static __global__ void flash_attn_vec_ext_f16(
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int ne32,
const int nb31,
const int nb32,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)

// Skip unused kernel variants for faster compilation:
@@ -65,14 +47,16 @@ static __global__ void flash_attn_vec_ext_f16(

const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
Q += nb02* blockIdx.z + nb01*ic0;
K += nb12*(blockIdx.z / gqa_ratio);
V += nb22*(blockIdx.z / gqa_ratio);
Q += nb03*sequence + nb02* head + nb01*ic0;
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);

const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);

const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);

static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -187,13 +171,19 @@ static __global__ void flash_attn_vec_ext_f16(

half2 VKQ[ncols] = {{0.0f, 0.0f}};

for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
K += blockIdx.y*D * nb11;
V += blockIdx.y*D * nb21;
maskh += blockIdx.y*D;
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
// Increment pointers after each loop:
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {

// Calculate KQ tile and keep track of new maximum KQ values:

if (mask) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
}

__syncthreads();
@@ -240,7 +230,7 @@ static __global__ void flash_attn_vec_ext_f16(

#pragma unroll
for (int j = 0; j < ncols; ++j) {
half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum((float)sum);

if (use_logit_softcap) {
@@ -296,8 +286,8 @@ static __global__ void flash_attn_vec_ext_f16(
}

half2 V_k;
reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid);
reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
@@ -330,26 +320,24 @@ static __global__ void flash_attn_vec_ext_f16(
if (gridDim.y == 1) {
dst_val /= kqsum[j_VKQ];
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
}

if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}
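
As read from the hunk above, the rewritten loop applies the split-K block's starting offset to K/V/maskh once before the loop and moves the per-iteration advance into the for-statement's increment list, so the body indexes relative to the current tile (k_VKQ_0 drops out of the K/V/mask subscripts). A standalone sketch of the pattern (illustrative, not the kernel):

// block/nblocks play the roles of blockIdx.y/gridDim.y, tile that of D.
static float sum_strided(const float * data, int n, int block, int nblocks, int tile) {
    float acc = 0.0f;
    data += (long long) block*tile;                  // one-time offset, like K += blockIdx.y*D*nb11
    for (int i0 = block*tile; i0 < n; i0 += nblocks*tile,
         data += (long long) nblocks*tile) {         // advance after each loop, in the increment list
        for (int j = 0; j < tile && i0 + j < n; ++j) {
            acc += data[j];                          // tile-relative indexing; i0 no longer appears here
        }
    }
    return acc;
}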
@@ -18,31 +18,13 @@ static __global__ void flash_attn_vec_ext_f32(
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int ne32,
const int nb31,
const int nb32,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#ifdef FLASH_ATTN_AVAILABLE

// Skip unused kernel variants for faster compilation:
@@ -53,12 +35,11 @@ static __global__ void flash_attn_vec_ext_f32(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
GGML_UNUSED(nb23);
NO_DEVICE_CODE;
return;
}
@@ -77,14 +58,16 @@ static __global__ void flash_attn_vec_ext_f32(

const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
Q += nb02* blockIdx.z + nb01*ic0;
K += nb12*(blockIdx.z / gqa_ratio);
V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
Q += nb03*sequence + nb02* head + nb01*ic0;
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);

const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);

const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);

static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
constexpr int nwarps = D / WARP_SIZE;
@@ -194,13 +177,19 @@ static __global__ void flash_attn_vec_ext_f32(

float VKQ[ncols] = {0.0f};

for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
K += blockIdx.y*D * nb11;
V += blockIdx.y*D * nb21;
maskh += blockIdx.y*D;
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
// Increment pointers after each loop:
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {

// Calculate KQ tile and keep track of new maximum KQ values:

if (mask) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
}

__syncthreads();
@@ -242,7 +231,7 @@ static __global__ void flash_attn_vec_ext_f32(

#pragma unroll
for (int j = 0; j < ncols; ++j) {
float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum(sum);

if (use_logit_softcap) {
@@ -293,7 +282,7 @@ static __global__ void flash_attn_vec_ext_f32(
break;
}

const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
const float V_ki = dequantize_1_v(V + k*nb21, tid);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j] += V_ki*KQ[j*D + k];
@@ -326,12 +315,11 @@ static __global__ void flash_attn_vec_ext_f32(
if (gridDim.y == 1) {
dst_val /= kqsum[j_VKQ];
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
}

if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -340,12 +328,11 @@ static __global__ void flash_attn_vec_ext_f32(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32);
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
@@ -37,31 +37,13 @@ static __global__ void flash_attn_ext_f16(
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int ne32,
const int nb31,
const int nb32,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -95,17 +77,19 @@ static __global__ void flash_attn_ext_f16(
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);

const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const half2 * mask2 = (const half2 *) maskh;

const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);

const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
const half2 slope2 = make_half2(slopef, slopef);

@@ -193,7 +177,7 @@ static __global__ void flash_attn_ext_f16(
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
frag_a_K K_a;
wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@@ -340,7 +324,7 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;

frag_a_V v_a;
wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -400,7 +384,6 @@ static __global__ void flash_attn_ext_f16(
if (ic0 + j_VKQ >= ne01) {
return;
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;

float KQ_rowsum_j;
if (std::is_same<KQ_acc_t, float>::value) {
@@ -409,6 +392,8 @@ static __global__ void flash_attn_ext_f16(
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
}

const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;

#pragma unroll
for (int i0 = 0; i0 < D; i0 += warp_size) {
const int i = i0 + threadIdx.x;
@@ -419,7 +404,7 @@ static __global__ void flash_attn_ext_f16(
if (gridDim.y == 1) {
dst_val /= KQ_rowsum_j;
}
dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
dst[j_dst_unrolled*D + i] = dst_val;
}

if (gridDim.y == 1 || threadIdx.x != 0) {
@@ -433,7 +418,7 @@ static __global__ void flash_attn_ext_f16(
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
}
dst_meta_val.y = KQ_rowsum_j;
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
dst_meta[j_dst_unrolled] = dst_meta_val;
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -442,10 +427,10 @@ static __global__ void flash_attn_ext_f16(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31);
GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
}
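
In all four kernels the mask slice is now chosen with nb33*(sequence % ne33) rather than nb32*(blockIdx.z % ne32); the modulo lets a mask with sequence extent ne33 == 1 be shared by every sequence in the batch, while a larger ne33 selects a per-sequence mask. A one-function sketch of the rule:

#include <cstddef>

// Pick the mask slice for a sequence. ne33 is the mask's extent in the
// sequence dimension and nb33 its byte stride; ne33 == 1 broadcasts one mask.
static const char * mask_for_sequence(const char * mask, int sequence, int ne33, size_t nb33) {
    return mask ? mask + size_t(sequence % ne33) * nb33 : nullptr;
}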
@@ -280,22 +280,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);

if (GGML_CUDA_CC_IS_AMD(cc)) {
#if defined(GGML_HIP_ROCWMMA_FATTN)
if (fp16_mma_available(cc)) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
return;
}
#endif // defined(GGML_HIP_ROCWMMA_FATTN)

// On AMD the tile kernels perform poorly, use the vec kernel instead:
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
}
if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
return;
}
#endif // defined(GGML_HIP_ROCWMMA_FATTN)

if (!fast_fp16_available(cc)) {
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
@@ -43,6 +43,7 @@
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml-cuda/set-rows.cuh"
#include "ggml.h"

#include <algorithm>
@@ -54,6 +55,7 @@
#include <cstddef>
#include <cstdint>
#include <float.h>
#include <initializer_list>
#include <limits>
#include <map>
#include <memory>
@@ -2230,6 +2232,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GET_ROWS_BACK:
ggml_cuda_op_get_rows_back(ctx, dst);
break;
case GGML_OP_SET_ROWS:
ggml_cuda_op_set_rows(ctx, dst);
break;
case GGML_OP_DUP:
ggml_cuda_dup(ctx, dst);
break;
@@ -2299,6 +2304,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_EXP:
ggml_cuda_op_exp(ctx, dst);
break;
case GGML_UNARY_OP_ELU:
ggml_cuda_op_elu(ctx, dst);
break;
default:
return false;
}
@@ -2583,6 +2591,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();

const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";

for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];

@@ -2604,9 +2615,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
#endif
}

if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
// disable CUDA graphs for batch size > 1 for now.
// Changes in batch size or context size can cause changes to the grid size of some kernels.
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true)) {
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
// by means of matching node names. See
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
// https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
@@ -2752,6 +2766,39 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
}
#endif
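
The widened ADD condition above is hard to scan. Restated as a hedged sketch (the helper below is hypothetical; the tree keeps the check inline): a batched ADD disables CUDA graph capture unless its operand names identify it as Gemma3n's project_per_layer_input addition.

#include <string>
#include "ggml.h" // for ggml_tensor / GGML_OP_ADD

// Hypothetical restatement: true if this node should turn CUDA graphs off.
static bool add_disables_cuda_graph(const ggml_tensor * node) {
    if (node->op != GGML_OP_ADD || !node->src[1] || node->src[1]->ne[1] <= 1) {
        return false; // not a batched ADD, no objection to graph capture
    }
    const bool gemma3n_exempt =
        (node->src[0] && std::string(node->src[0]->name) == "inp_per_layer_selected") ||
        std::string(node->src[1]->name) == "per_layer_proj";
    return !gemma3n_exempt;
}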
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
return false;
}

if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
const ggml_tensor *mul = cgraph->nodes[node_idx+1];

GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);

//rms norm only supports F32
if (mul->src[0]->type != GGML_TYPE_F32 ||
mul->src[1]->type != GGML_TYPE_F32 ||
mul->type != GGML_TYPE_F32) {
return false;
}

//if rms norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
return false;
}
//rms_norm kernel assumes contiguous rows
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
return false;
}
}

return true;
}

static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
// flag used to determine whether it is an integrated_gpu
@@ -2761,6 +2808,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
// With the use of CUDA graphs, the execution will be performed by the graph launch.
if (!use_cuda_graph || cuda_graph_update_required) {

for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];

@@ -2768,6 +2816,12 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
continue;
}

static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;
}
#ifndef NDEBUG
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
for (int j = 0; j < GGML_MAX_SRC; j++) {
@@ -3112,6 +3166,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_ELU:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@@ -3216,17 +3271,21 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
{
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
} break;
case GGML_OP_SET_ROWS:
{
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_I64;
} break;
case GGML_OP_CPY:
{
ggml_type src0_type = op->src[0]->type;
ggml_type src1_type = op->src[1]->type;
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
if ((src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_F16) &&
(src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F16)
) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
@@ -3262,12 +3321,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
return true;
}
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
return true;
}
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
return true;
}
@@ -3335,8 +3388,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_SSM_SCAN: {
if (op->src[3]->ne[0] == 1) {
// Mamba2
// (kernel only supports d_state == 128 && d_head % 16 == 0)
return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
// (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0)
return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0;
} else {
// Mamba
// (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
@@ -3348,7 +3401,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return op->src[0]->ne[1] % 128 == 0;
}
case GGML_OP_CONT:
return op->src[0]->type != GGML_TYPE_BF16;
return true;
case GGML_OP_DIAG_MASK_INF:
return true;
case GGML_OP_SOFT_MAX:
@@ -3398,12 +3451,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (op->src[0]->ne[0] == 192) {
return false;
}
// TODO: support broadcast
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
if (op->src[0]->ne[3] != 1) {
return false;
}
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
return false;
}
@@ -3416,6 +3463,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
return true;
}
if (op->src[3] && op->src[3]->ne[2] != 1) {
return false;
}
return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
}
@@ -10,7 +10,7 @@ static __global__ void im2col_kernel(
return;
}

const int64_t ksize = OW * (KH > 1 ? KW : 1);
const int64_t ksize = OW * KH;
const int64_t kx = i / ksize;
const int64_t kd = kx * ksize;
const int64_t ky = (i - kd) / OW;
@@ -12,7 +12,8 @@
// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
// All matrix tiles have ne physical 32 bit elements per warp.
//
// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes, unaligned pointers are considered undefined behavior.

#include "common.cuh"

@@ -66,7 +67,44 @@ namespace ggml_cuda_mma {
struct tile {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr int ne = I * J / WARP_SIZE;

#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
static constexpr int ne = I * J / 64;
T x[ne] = {0};

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
return threadIdx.x % 16;
} else if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else if constexpr (I == 32 && J == 4) {
return threadIdx.x % 32;
} else if constexpr (I == 16 && J == 16) {
return 4 * (threadIdx.x / 16) + l;
} else if constexpr (I == 32 && J == 32) {
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
return (2 * ((threadIdx.x / 16) % 2) + l);
} else if constexpr (I == 16 && J == 8) {
return 2 * (threadIdx.x / 16) + l;
} else if constexpr (I == 32 && J == 4) {
return 2 * (threadIdx.x / 32) + l;
} else if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else if constexpr (I == 32 && J == 32) {
return threadIdx.x % 32;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
#else
static constexpr int ne = I * J / 32;
T x[ne] = {0};

static __device__ __forceinline__ int get_i(const int l) {
@@ -94,6 +132,7 @@ namespace ggml_cuda_mma {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
};

template <int I_, int J_>
@@ -148,10 +187,23 @@ namespace ggml_cuda_mma {

template <int I, int J, typename T>
static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
#if defined(AMD_MFMA_AVAILABLE)
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
}
} else {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
}
#else
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
}
#endif // defined(AMD_MFMA_AVAILABLE)
}

template <typename T>
@@ -186,7 +238,7 @@ namespace ggml_cuda_mma {
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
#if defined(NEW_MMA_AVAILABLE)
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
@@ -393,4 +445,60 @@ namespace ggml_cuda_mma {
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}

static __device__ __forceinline__ void mma(
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
#if defined(AMD_MFMA_AVAILABLE)
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
int32x4_t * acc = (int32x4_t *) D.x;
#if defined(CDNA3)
acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
((int64_t *) B.x)[0],
acc[0],
0, 0, 0);
#elif defined(CDNA2) || defined(CDNA)
acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
B.x[0],
acc[0],
0, 0, 0);
acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],
B.x[1],
acc[0],
0, 0, 0);
#endif // defined(CDNA3)
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
}

static __device__ __forceinline__ void mma(
tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
#if defined(AMD_MFMA_AVAILABLE)
using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
int32x16_t * acc = (int32x16_t *) D.x;
#if defined(CDNA3)
acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
((int64_t *) B.x)[0],
acc[0],
0, 0, 0);
#elif defined(CDNA2) || defined(CDNA)
acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
B.x[0],
acc[0],
0, 0, 0);
acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],
B.x[1],
acc[0],
0, 0, 0);
#endif // defined(CDNA3)
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
}
}
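
The new HIP branch hard-codes per-tile lane mappings. A quick host-side sanity check (a sketch, not part of the tree) that the tile<16, 16> mapping above covers each of the 256 elements exactly once across a 64-lane wavefront:

#include <cassert>

// ne = 16*16/64 = 4; get_i(l) = 4*(lane/16) + l, get_j(l) = lane % 16.
// Every (i, j) pair must be produced by exactly one (lane, l) combination.
int main() {
    bool seen[16][16] = {};
    for (int lane = 0; lane < 64; ++lane) {
        for (int l = 0; l < 4; ++l) {
            const int i = 4*(lane/16) + l;
            const int j = lane % 16;
            assert(!seen[i][j]);
            seen[i][j] = true;
        }
    }
    return 0; // all 256 elements covered exactly once
}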
@@ -109,7 +109,8 @@ void ggml_cuda_mul_mat_q(
const int64_t s03 = src0->nb[3] / ts_src0;
const int64_t s3 = dst->nb[3] / ts_dst;

const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)));

if (!ids) {
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
@@ -250,8 +251,9 @@ void ggml_cuda_op_mul_mat_q(
// The stream-k decomposition is only faster for recent NVIDIA GPUs.
// Also its fixup needs to allocate a temporary buffer in the memory pool.
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) &&
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)))
&& src1_ncols == ne11;
const mmq_args args = {
src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
@@ -304,7 +306,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return false;
}

if (new_mma_available(cc)) {
if (new_mma_available(cc) || amd_mfma_available(cc)) {
return true;
}
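
The same stream-k predicate now appears verbatim at both call sites above. A hedged consolidation for readability (helper name hypothetical; the tree keeps it inline): stream-k is used on NVIDIA from Volta up when compiled in, and on AMD CDNA3, with the per-op path additionally requiring src1_ncols == ne11 so the fixup buffer is not shared across parallel CUDA streams.

// Hypothetical helper; the macros and functions are the real ones from these files.
static bool mmq_use_stream_k(const int cc, const bool whole_src1 /* src1_ncols == ne11 */) {
    const bool arch_ok =
        (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
        (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc));
    return arch_ok && whole_src1;
}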
File diff suppressed because it is too large
@@ -104,10 +104,12 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
}
}

template <int block_size>
template <int block_size, bool do_multiply = false>
static __global__ void rms_norm_f32(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) {
const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;

@@ -119,6 +121,13 @@ static __global__ void rms_norm_f32(
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;

if constexpr (do_multiply) {
const int mul_row = row % mul_nrows;
const int mul_channel = channel % mul_nchannels;
const int mul_sample = sample % mul_nsamples;
mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
}

float tmp = 0.0f; // partial sum for thread in warp

for (int col = tid; col < ncols; col += block_size) {
@@ -145,7 +154,12 @@ static __global__ void rms_norm_f32(
const float scale = rsqrtf(mean + eps);

for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * x[col];
if constexpr (do_multiply) {
const int mul_col = col % mul_ncols;
dst[col] = scale * x[col] * mul[mul_col];
} else {
dst[col] = scale * x[col];
}
}
}

@@ -310,10 +324,30 @@ static void rms_norm_f32_cuda(
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}

static void rms_norm_mul_f32_cuda(
const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (mul == nullptr) {
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
return;
}
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
}
}

@@ -407,6 +441,59 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}

void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
float eps = 0.0f;

memcpy(&eps, dst->op_params, sizeof(float));

const float * src0_d = (const float *) rms_norm_src->data;
const float * mul_d = nullptr;
const ggml_tensor * mul_src = nullptr;

if (mul_tensor->src[0] == dst) {
mul_d = (float *) mul_tensor->src[1]->data;
mul_src = mul_tensor->src[1];
} else if(mul_tensor->src[1] == dst) {
mul_d = (float *) mul_tensor->src[0]->data;
mul_src = mul_tensor->src[0];
} else {
GGML_ASSERT(false);
}

float * dst_d = (float *) mul_tensor->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
GGML_ASSERT(eps >= 0.0f);

const int64_t ne00 = rms_norm_src->ne[0];
const int64_t ne01 = rms_norm_src->ne[1];
const int64_t ne02 = rms_norm_src->ne[2];
const int64_t ne03 = rms_norm_src->ne[3];

const size_t ts0 = ggml_type_size(rms_norm_src->type);
GGML_ASSERT(rms_norm_src->nb[0] == ts0);
const int64_t s01 = rms_norm_src->nb[1] / ts0;
const int64_t s02 = rms_norm_src->nb[2] / ts0;
const int64_t s03 = rms_norm_src->nb[3] / ts0;

const size_t ts_mul = ggml_type_size(mul_src->type);
GGML_ASSERT(mul_src->nb[0] == ts_mul);
const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
const int64_t mul_s03 = mul_src->nb[3] / ts_mul;

const int mul_ncols = mul_src->ne[0];
const int mul_nrows = mul_src->ne[1];
const int mul_nchannels = mul_src->ne[2];
const int mul_nsamples = mul_src->ne[3];

rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
}

void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * grad = dst->src[0]; // gradients
const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass

@@ -6,6 +6,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)

void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);

void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -1,18 +1,18 @@
#include "scale.cuh"

-static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
+static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

-    dst[i] = scale * x[i];
+    dst[i] = scale * x[i] + bias;
}

-static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
+static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
-    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
+    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k);
}

void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -25,7 +25,9 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    float scale;
-    memcpy(&scale, dst->op_params, sizeof(float));
+    float bias;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&bias,  (float *) dst->op_params + 1, sizeof(float));

-    scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
+    scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream);
}
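Worth noting: after this change GGML_OP_SCALE carries two floats in op_params, so the op computes dst = scale*src + bias. A hedged host-side sketch of the packing the memcpy calls above expect (layout inferred from this diff, not a documented API contract):

// float params[2] = { 2.0f, 0.5f };               // params[0] = scale, params[1] = bias
// memcpy(tensor->op_params, params, sizeof(params));
// the kernel then reads them back as (float *) op_params + 0 and + 1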
ggml/src/ggml-cuda/set-rows.cu (new file, 275 lines)
@@ -0,0 +1,275 @@
#include "set-rows.cuh"
#include "cpy-utils.cuh"

typedef void (*set_rows_kernel_t)(const char * src, char * dst);

template<typename src_t, typename dst_t>
__device__ __forceinline__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
    convert_flt(src_f, dst_f);
}

// Generic quantized set_rows kernel template
template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
static __global__ void k_set_rows_quant(
        const float * __restrict__ src0, const int64_t * __restrict__ src1, block_type * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const int64_t s01, const int64_t s02, const int64_t s03,
        const int64_t s10, const int64_t s11, const int64_t s12,
        const int64_t s1, const int64_t s2, const int64_t s3) {

    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
    const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;

    if (i >= ne_total) {
        return;
    }

    const int64_t i_base = i * qk;
    const int64_t i03 = i_base / (ne00 * ne01 * ne02);
    const int64_t i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
    const int64_t i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
    const int64_t i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;

    const int64_t i12 = i03 % ne12;
    const int64_t i11 = i02 % ne11;
    const int64_t i10 = i01;

    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);

    const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
    block_type * dst_row_ptr = dst + (dst_row*s1 + i02*s2 + i03*s3) / sizeof(block_type);

    const float * src_block = src0_row + i00;
    block_type * dst_block  = dst_row_ptr + i00 / qk;

    quantize_func(src_block, dst_block);

    GGML_UNUSED(ne10);
    GGML_UNUSED(ne13);
}

// Template dispatch function for quantized set_rows
template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
static void set_rows_cuda_quant(
        const float * src0_d, const int64_t * src1_d, block_type * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        cudaStream_t stream) {

    GGML_ASSERT(ne00 % qk == 0);
    const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
    const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
    const dim3 grid_size(num_blocks);

    const int64_t s01 = nb01/sizeof(float);
    const int64_t s02 = nb02/sizeof(float);
    const int64_t s03 = nb03/sizeof(float);
    const int64_t s10 = nb10/sizeof(int64_t);
    const int64_t s11 = nb11/sizeof(int64_t);
    const int64_t s12 = nb12/sizeof(int64_t);
    const int64_t s1  = nb1;
    const int64_t s2  = nb2;
    const int64_t s3  = nb3;

    if (ne_total > 0) {
        k_set_rows_quant<block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
            src0_d, src1_d, dst_d,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            s01, s02, s03,
            s10, s11, s12,
            s1, s2, s3);
    }
}

template<typename src_t, typename dst_t>
static __global__ void k_set_rows(
        const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const int64_t s01, const int64_t s02, const int64_t s03,
        const int64_t s10, const int64_t s11, const int64_t s12,
        const int64_t s1, const int64_t s2, const int64_t s3) {

    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;

    if (i >= ne_total) {
        return;
    }

    const int64_t i03 = i / (ne00 * ne01 * ne02);
    const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
    const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
    const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;

    const int64_t i12 = i03 % ne12;
    const int64_t i11 = i02 % ne11;
    const int64_t i10 = i01;

    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);

    const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
    dst_t * dst_row_ptr    = dst + dst_row*s1 + i02*s2 + i03*s3;

    const src_t * src_elem = src0_row + i00;
    dst_t * dst_elem       = dst_row_ptr + i00;
    set_rows_1(src_elem, dst_elem);

    GGML_UNUSED(ne10);
    GGML_UNUSED(ne13);
}

template<typename src_t, typename dst_t>
static void set_rows_cuda(
        const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        cudaStream_t stream) {

    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
    const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
    const dim3 grid_size(num_blocks);

    const int64_t s01 = nb01/sizeof(src_t);
    const int64_t s02 = nb02/sizeof(src_t);
    const int64_t s03 = nb03/sizeof(src_t);
    const int64_t s10 = nb10/sizeof(int64_t);
    const int64_t s11 = nb11/sizeof(int64_t);
    const int64_t s12 = nb12/sizeof(int64_t);
    const int64_t s1  = nb1/sizeof(dst_t);
    const int64_t s2  = nb2/sizeof(dst_t);
    const int64_t s3  = nb3/sizeof(dst_t);

    if (ne_total > 0) {
        k_set_rows<<<grid_size, block_size, 0, stream>>>(
            src0_d, src1_d, dst_d,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            s01, s02, s03,
            s10, s11, s12,
            s1, s2, s3);
    }
}

void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_I64);

    GGML_TENSOR_BINARY_OP_LOCALS

    const float   * src0_d = (const float   *)src0->data;
    const int64_t * src1_d = (const int64_t *)src1->data;

    cudaStream_t stream = ctx.stream();

    if (dst->type == GGML_TYPE_F32) {
        set_rows_cuda(
            src0_d, src1_d, (float*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_F16) {
        set_rows_cuda(
            src0_d, src1_d, (half*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_BF16) {
        set_rows_cuda(
            src0_d, src1_d, (nv_bfloat16*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q4_0) {
        set_rows_cuda_quant<block_q4_0, QK4_0, quantize_f32_q4_0_block>(
            src0_d, src1_d, (block_q4_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q4_1) {
        set_rows_cuda_quant<block_q4_1, QK4_1, quantize_f32_q4_1_block>(
            src0_d, src1_d, (block_q4_1*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q5_0) {
        set_rows_cuda_quant<block_q5_0, QK5_0, quantize_f32_q5_0_block>(
            src0_d, src1_d, (block_q5_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q5_1) {
        set_rows_cuda_quant<block_q5_1, QK5_1, quantize_f32_q5_1_block>(
            src0_d, src1_d, (block_q5_1*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q8_0) {
        set_rows_cuda_quant<block_q8_0, QK8_0, quantize_f32_q8_0_block>(
            src0_d, src1_d, (block_q8_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_IQ4_NL) {
        set_rows_cuda_quant<block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
            src0_d, src1_d, (block_iq4_nl*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else {
        GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
    }
}
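Both kernels above linearize the 4D tensor and invert the flattened index with successive divisions. A small worked example of the same arithmetic (values hypothetical, chosen for illustration):

// For ne00=4, ne01=3, ne02=2, ne03=1 and flattened index i = 17:
//   i03 = 17 / (4*3*2)       = 0
//   i02 = (17 - 0)  / (4*3)  = 1
//   i01 = (17 - 12) / 4      = 1
//   i00 =  17 - 12 - 4       = 1
// i.e. element (i00,i01,i02,i03) = (1,1,1,0), consistent with
// i = i00 + 4*i01 + 12*i02 + 24*i03 = 1 + 4 + 12 + 0 = 17.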
ggml/src/ggml-cuda/set-rows.cuh (new file, 7 lines)
@@ -0,0 +1,7 @@
#pragma once

#include "common.cuh"

#define CUDA_SET_ROWS_BLOCK_SIZE 256

void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -107,8 +107,11 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
        if (nc == 4) {
            ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
                                                                     dst, dst_nb0, dst_nb1, dst_nb2, n_t);
+        } else if (nc == 3) {
+            ssm_conv_f32<threads, 3><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+                                                                     dst, dst_nb0, dst_nb1, dst_nb2, n_t);
        } else {
-            GGML_ABORT("Only support kernel size = 4 now.");
+            GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
        }
    } else {
        if (nc == 4) {
@@ -116,8 +119,13 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
            dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
            ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
                src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
+        } else if (nc == 3) {
+            const int64_t split_n_t = 32;
+            dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
+            ssm_conv_long_token_f32<threads, 3, split_n_t><<<blocks, threads, 0, stream>>>(
+                src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
        } else {
-            GGML_ABORT("Only support kernel size = 4 right now.");
+            GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
        }
    }
}
@@ -201,11 +201,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
        const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
        const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
        cudaStream_t stream) {
-    const int threads = 128;
    // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
    if (src3_nb1 == sizeof(float)) {
        // Mamba-2
        if (d_state == 128) {
+            const int threads = 128;
            GGML_ASSERT(d_state % threads == 0);
            // NOTE: can be any power of two between 4 and 64
            const int splitH = 16;
@@ -215,10 +215,21 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
                src0, src1, src2, src3, src4, src5, src6, dst,
                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
+        } else if (d_state == 256) { // Falcon-H1
+            const int threads = 256;
+            // NOTE: can be any power of two between 8 and 64
+            const int splitH = 16;
+            GGML_ASSERT(head_dim % splitH == 0);
+            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
+            ssm_scan_f32_group<16, 256><<<blocks, threads, 0, stream>>>(
+                src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
+                src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
        } else {
-            GGML_ABORT("doesn't support d_state!=128.");
+            GGML_ABORT("doesn't support d_state!=(128 or 256).");
        }
    } else {
+        const int threads = 128;
        // Mamba-1
        GGML_ASSERT(n_head % threads == 0);
        GGML_ASSERT(head_dim == 1);
@@ -83,6 +83,10 @@ static __device__ __forceinline__ float op_log(float x) {
    return logf(x);
}

+static __device__ __forceinline__ float op_elu(float x) {
+    return (x > 0.f) ? x : expm1f(x);
+}
+
template <float (*op)(float), typename T>
static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -196,6 +200,9 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary<op_log>(ctx, dst);
}

+void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_elu>(ctx, dst);
+}
/* gated ops */

template <float (*op)(float), typename T>
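One detail worth calling out: op_elu uses expm1f rather than expf(x) - 1.0f, which avoids catastrophic cancellation near zero. An illustrative comparison:

// ELU(x) = x          for x > 0
//        = e^x - 1    for x <= 0
// For x = -1e-8f:
//   expf(x) - 1.0f  ->  0.0f   (expf(x) rounds to 1.0f, the subtraction cancels)
//   expm1f(x)       -> ~-1e-8f (correct to full float precision)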
@@ -59,6 +59,8 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/vendors/hip.h (vendored, 33 lines changed)
@@ -10,9 +10,6 @@
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__

-#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
-#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
-#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
@@ -30,7 +27,6 @@
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
-#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
@@ -42,7 +38,6 @@
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t
-#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
@@ -144,6 +139,20 @@
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED

+#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 70000000
+#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
+#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
+#define cublasComputeType_t hipblasComputeType_t
+#define cudaDataType_t hipDataType
+#else
+#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
+#define cublasComputeType_t hipblasDatatype_t
+#define cudaDataType_t hipblasDatatype_t
+#endif
+
#define __CUDA_ARCH__ 1300

#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
@@ -151,7 +160,19 @@
#endif

#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
-#define CDNA
+#define CDNA // For the entire family
+#endif
+
+#if defined(__gfx942__)
+#define CDNA3
+#endif
+
+#if defined(__gfx90a__)
+#define CDNA2
+#endif
+
+#if defined(__gfx908__)
+#define CDNA1
#endif

#if defined(__GFX12__)
ggml/src/ggml-cuda/vendors/musa.h (vendored, 4 lines changed)
@@ -13,7 +13,7 @@
#define CUBLAS_OP_N MUBLAS_OP_N
#define CUBLAS_OP_T MUBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
-#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
+#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH
#define CUDA_R_16F MUSA_R_16F
#define CUDA_R_16BF MUSA_R_16BF
#define CUDA_R_32F MUSA_R_32F
@@ -29,7 +29,7 @@
#define cublasSgemm mublasSgemm
#define cublasStatus_t mublasStatus_t
#define cublasOperation_t mublasOperation_t
-#define cublasGetStatusString mublasStatus_to_string
+#define cublasGetStatusString mublasGetStatusString
#define cudaDataType_t musaDataType_t
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
@@ -73,6 +73,22 @@ static inline int ggml_up(int n, int m) {
    return (n + m - 1) & ~(m - 1);
}

+// TODO: move to ggml.h?
+static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
+    if (a->type != b->type) {
+        return false;
+    }
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (a->ne[i] != b->ne[i]) {
+            return false;
+        }
+        if (a->nb[i] != b->nb[i]) {
+            return false;
+        }
+    }
+    return true;
+}
+
//
// logging
//
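ggml_are_same_layout is a strict check: same type, same extents ne[] and same strides nb[] in every GGML_MAX_DIMS dimension. A hedged sketch of the kind of call site it serves (tensor names hypothetical; the Metal fusion code later in this diff uses it exactly this way):

// if (ggml_are_same_layout(b0, b1)) {
//     // shapes and strides agree, so one set of kernel arguments can describe
//     // both tensors -- e.g. when fusing consecutive ADDs sharing a bias layout
// }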
@@ -126,6 +126,7 @@ typedef struct {
    uint64_t nb2;
    uint64_t nb3;
    uint64_t offs;
+    uint64_t o1[8];
} ggml_metal_kargs_bin;

typedef struct {
@@ -240,7 +241,7 @@ typedef struct {
    float max_bias;
    float m0;
    float m1;
-    uint16_t n_head_log2;
+    int32_t n_head_log2;
    float logit_softcap;
} ggml_metal_kargs_flash_attn_ext;

@@ -377,8 +378,16 @@ typedef struct {
typedef struct {
    int32_t ne00;
    int32_t ne00_4;
-    uint64_t nb01;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
    float eps;
+    int32_t nef1[3];
+    int32_t nef2[3];
+    int32_t nef3[3];
+    uint64_t nbf1[3];
+    uint64_t nbf2[3];
+    uint64_t nbf3[3];
} ggml_metal_kargs_rms_norm;

typedef struct {
@@ -484,7 +493,7 @@ typedef struct {
    float max_bias;
    float m0;
    float m1;
-    uint32_t n_head_log2;
+    int32_t n_head_log2;
} ggml_metal_kargs_soft_max;

typedef struct {
@@ -519,6 +528,7 @@ typedef struct {
    int64_t n_group;
    int64_t n_seq_tokens;
    int64_t n_seqs;
+    int64_t s_off;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
@@ -55,6 +55,12 @@ static struct ggml_backend_metal_device_context {
    bool has_residency_sets;
    bool has_bfloat;
    bool use_bfloat;
+    bool use_fusion;
+
+    int debug_fusion;
+
+    // how many times a given op was fused
+    uint64_t fuse_cnt[GGML_OP_COUNT];

    size_t max_size;

@@ -69,6 +75,9 @@ static struct ggml_backend_metal_device_context {
    /*.has_residency_sets =*/ false,
    /*.has_bfloat         =*/ false,
    /*.use_bfloat         =*/ false,
+    /*.use_fusion         =*/ true,
+    /*.debug_fusion       =*/ 0,
+    /*.fuse_cnt           =*/ { 0 },
    /*.max_size           =*/ 0,
    /*.name               =*/ "",
};
@@ -83,16 +92,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev

    if (ctx->mtl_device == nil) {
        ctx->mtl_device = MTLCreateSystemDefaultDevice();
    }

    if (ctx->mtl_device) {
        ctx->has_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
        ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];

        ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];

#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
-        ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL;
+        ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
#endif

        ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
@@ -103,6 +110,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
#else
        ctx->use_bfloat = false;
#endif
+        ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
+
+        {
+            const char * val = getenv("GGML_METAL_FUSION_DEBUG");
+            ctx->debug_fusion = val ? atoi(val) : 0;
+        }
+
+        memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));

        ctx->max_size = ctx->mtl_device.maxBufferLength;
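From the getenv calls above, two environment variables control the new fusion path (behavior read off this diff):

// GGML_METAL_FUSION_DISABLE  -- set (to anything) to turn op fusion off entirely
// GGML_METAL_FUSION_DEBUG    -- integer level; > 0 prints per-op fusion counts
//                               at device release, > 1 also logs each fusion event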
@@ -122,6 +137,18 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
    ctx->mtl_device_ref_count--;

    if (ctx->mtl_device_ref_count == 0) {
+        if (ctx->debug_fusion > 0) {
+            fprintf(stderr, "%s: fusion stats:\n", __func__);
+            for (int i = 0; i < GGML_OP_COUNT; i++) {
+                if (ctx->fuse_cnt[i] == 0) {
+                    continue;
+                }
+
+                // note: cannot use ggml_log here
+                fprintf(stderr, "%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
+            }
+        }
+
        if (ctx->mtl_lock) {
            [ctx->mtl_lock release];
            ctx->mtl_lock = nil;
@@ -147,13 +174,27 @@ struct ggml_metal_kernel {

enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_ADD,
-    GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
    GGML_METAL_KERNEL_TYPE_SUB,
-    GGML_METAL_KERNEL_TYPE_SUB_ROW,
+    GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
    GGML_METAL_KERNEL_TYPE_MUL,
-    GGML_METAL_KERNEL_TYPE_MUL_ROW,
+    GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
    GGML_METAL_KERNEL_TYPE_DIV,
-    GGML_METAL_KERNEL_TYPE_DIV_ROW,
+    GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
    GGML_METAL_KERNEL_TYPE_REPEAT_F32,
    GGML_METAL_KERNEL_TYPE_REPEAT_F16,
    GGML_METAL_KERNEL_TYPE_REPEAT_I32,
@@ -173,6 +214,12 @@ enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_SILU,
    GGML_METAL_KERNEL_TYPE_SILU_4,
    GGML_METAL_KERNEL_TYPE_ELU,
+    GGML_METAL_KERNEL_TYPE_ABS,
+    GGML_METAL_KERNEL_TYPE_SGN,
+    GGML_METAL_KERNEL_TYPE_STEP,
+    GGML_METAL_KERNEL_TYPE_HARDSWISH,
+    GGML_METAL_KERNEL_TYPE_HARDSIGMOID,
+    GGML_METAL_KERNEL_TYPE_EXP,
    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -212,6 +259,8 @@ enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
    GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
    GGML_METAL_KERNEL_TYPE_RMS_NORM,
+    GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
+    GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
    GGML_METAL_KERNEL_TYPE_L2_NORM,
    GGML_METAL_KERNEL_TYPE_GROUP_NORM,
    GGML_METAL_KERNEL_TYPE_NORM,
@@ -1129,13 +1178,27 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
        // simd_sum and simd_max requires MTLGPUFamilyApple7

        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                add,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,            add_row,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,         add_fuse_2,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,         add_fuse_3,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,         add_fuse_4,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,         add_fuse_5,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,         add_fuse_6,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,         add_fuse_7,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,         add_fuse_8,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,         add_row_c4,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,  add_row_c4_fuse_2,  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,  add_row_c4_fuse_3,  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,  add_row_c4_fuse_4,  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,  add_row_c4_fuse_5,  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,  add_row_c4_fuse_6,  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,  add_row_c4_fuse_7,  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,  add_row_c4_fuse_8,  true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB,                sub,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW,            sub_row,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,         sub_row_c4,         true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,                mul,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,            mul_row,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,         mul_row_c4,         true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                div,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,            div_row,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,         div_row_c4,         true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32,         repeat_f32,         true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16,         repeat_f16,         true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32,         repeat_i32,         true);
@@ -1155,6 +1218,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,               silu,               true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,             silu_4,             true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU,                elu,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS,                abs,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN,                sgn,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP,               step,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH,          hardswish,          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID,        hardsigmoid,        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP,                exp,                true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,       soft_max_f16,       has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,     soft_max_f16_4,     has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,       soft_max_f32,       has_simdgroup_reduction);
@@ -1194,6 +1263,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,      set_rows_q5_1,      true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,    set_rows_iq4_nl,    true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,           rms_norm,           has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,       rms_norm_mul,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,   rms_norm_mul_add,   has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM,            l2_norm,            has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,         group_norm,         has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,               norm,               true);
@@ -1688,6 +1759,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
            case GGML_UNARY_OP_SILU:
            case GGML_UNARY_OP_ELU:
            case GGML_UNARY_OP_NEG:
+            case GGML_UNARY_OP_ABS:
+            case GGML_UNARY_OP_SGN:
+            case GGML_UNARY_OP_STEP:
+            case GGML_UNARY_OP_HARDSWISH:
+            case GGML_UNARY_OP_HARDSIGMOID:
+            case GGML_UNARY_OP_EXP:
                return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
            default:
                return false;
@@ -1875,9 +1952,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
    }
}

-static bool ggml_metal_encode_node(
+static int ggml_metal_encode_node(
        ggml_backend_t backend,
        int idx,
+        int idx_end,
        id<MTLComputeCommandEncoder> encoder,
        struct ggml_metal_mem_pool * mem_pool) {
    struct ggml_backend_metal_context * ctx = backend->context;
@@ -1885,7 +1963,10 @@ static bool ggml_metal_encode_node(

    struct ggml_cgraph * gf = ctx->gf;

-    struct ggml_tensor * node = ggml_graph_node(gf, idx);
+    enum ggml_op ops[8];
+
+    struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx;
+    struct ggml_tensor *  node  = nodes[0];

    //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));

@@ -1895,7 +1976,7 @@ static bool ggml_metal_encode_node(
    struct ggml_tensor * dst = node;

    if (ggml_is_empty(dst)) {
-        return true;
+        return 1;
    }

    switch (dst->op) {
@@ -1906,7 +1987,7 @@ static bool ggml_metal_encode_node(
        case GGML_OP_PERMUTE:
            {
                // noop -> next node
-            } return true;
+            } return 1;
        default:
            {
            } break;
@@ -1973,6 +2054,8 @@ static bool ggml_metal_encode_node(
    id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
    id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;

+    int n_fuse = 1;
+
#if 0
    GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
    if (src0) {
@@ -2044,37 +2127,15 @@ static bool ggml_metal_encode_node(
                GGML_ASSERT(src0t == GGML_TYPE_F32);
                GGML_ASSERT(src1t == GGML_TYPE_F32);

                GGML_ASSERT(ggml_is_contiguous_rows(src0));
                GGML_ASSERT(ggml_is_contiguous_rows(src1));

                const size_t offs = 0;

                bool bcast_row = false;

                id<MTLComputePipelineState> pipeline = nil;

-                if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
-                    GGML_ASSERT(ggml_is_contiguous(src0));
-
-                    // src1 is a row
-                    GGML_ASSERT(ne11 == 1);
-
-                    switch (dst->op) {
-                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
-                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
-                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
-                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
-                        default: GGML_ABORT("fatal error");
-                    }
-
-                    bcast_row = true;
-                } else {
-                    switch (dst->op) {
-                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
-                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
-                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
-                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
-                        default: GGML_ABORT("fatal error");
-                    }
-                }
-
                ggml_metal_kargs_bin args = {
                    /*.ne00 =*/ ne00,
                    /*.ne01 =*/ ne01,
@@ -2101,12 +2162,119 @@ static bool ggml_metal_encode_node(
                    /*.nb2  =*/ nb2,
                    /*.nb3  =*/ nb3,
                    /*.offs =*/ offs,
+                    /*.o1   =*/ { offs_src1 },
                };

+                // c[0] = add(a,    b[0])
+                // c[1] = add(c[0], b[1])
+                // c[2] = add(c[1], b[2])
+                // ...
+                if (ctx_dev->use_fusion) {
+                    ops[0] = GGML_OP_ADD;
+                    ops[1] = GGML_OP_ADD;
+                    ops[2] = GGML_OP_ADD;
+                    ops[3] = GGML_OP_ADD;
+                    ops[4] = GGML_OP_ADD;
+                    ops[5] = GGML_OP_ADD;
+                    ops[6] = GGML_OP_ADD;
+                    ops[7] = GGML_OP_ADD;
+
+                    size_t offs_fuse;
+                    id<MTLBuffer> id_fuse;
+
+                    // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes
+                    //       across splits. idx_end indicates the last node in the current split
+                    for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
+                        if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
+                            break;
+                        }
+
+                        // b[0] === b[1] === ...
+                        if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) {
+                            break;
+                        }
+
+                        // only fuse nodes if src1 is in the same Metal buffer
+                        id_fuse = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse);
+                        if (id_fuse != id_src1) {
+                            break;
+                        }
+
+                        ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
+
+                        args.o1[n_fuse + 1] = offs_fuse;
+                    }
+
+                    ++n_fuse;
+
+                    if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
+                        GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
+                    }
+                }
+
+                if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+
+                    // src1 is a row
+                    GGML_ASSERT(ne11 == 1);
+
+                    switch (dst->op) {
+                        case GGML_OP_ADD:
+                            {
+                                switch (n_fuse) {
+                                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4       ].pipeline; break;
+                                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break;
+                                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break;
+                                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break;
+                                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break;
+                                    case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break;
+                                    case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break;
+                                    case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break;
+                                    default: GGML_ABORT("fatal error");
+                                }
+                            } break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break;
+                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break;
+                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    }
+
+                    bcast_row = true;
+                } else {
+                    switch (dst->op) {
+                        case GGML_OP_ADD:
+                            {
+                                switch (n_fuse) {
+                                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD       ].pipeline; break;
+                                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
+                                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break;
+                                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
+                                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break;
+                                    case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
+                                    case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break;
+                                    case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
+                                    default: GGML_ABORT("fatal error");
+                                }
+                            } break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
+                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
+                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    }
+                }
+
+                if (n_fuse > 1) {
+                    id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+                }
+
                [encoder setComputePipelineState:pipeline];
                [encoder setBytes:&args length:sizeof(args) atIndex:0];
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_src1 offset:0 atIndex:2];
                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];

                if (bcast_row) {
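To summarize the mechanics above: the encoder looks ahead from the current node, and n_fuse ends up as the number of consecutive ADD nodes folded into one dispatch (1 = no fusion, up to 8). A condensed sketch of the invariant, using the names from the code above:

// nodes[0]         -- the node being encoded (dst of the first ADD)
// nodes[k]->src[0] -- must equal nodes[k-1]: each ADD consumes the previous result
// args.o1[k]       -- byte offset of the k-th fused src1 inside the shared id_src1 buffer
// after the loop the kernel computes c = (((a + b0) + b1) + ...) + b{n_fuse-1}
// with the final result written to nodes[n_fuse - 1]'s buffer.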
@@ -2114,7 +2282,11 @@ static bool ggml_metal_encode_node(

                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                } else {
-                    const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+                    int nth = 32;
+
+                    while (16*nth < ne0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                        nth *= 2;
+                    }

                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                }
@@ -2239,12 +2411,13 @@ static bool ggml_metal_encode_node(
                /*.nb2  =*/ pnb2,
                /*.nb3  =*/ pnb3,
                /*.offs =*/ offs,
+                /*.o1   =*/ { offs_src1 },
            };

            [encoder setComputePipelineState:pipeline];
            [encoder setBytes:&args length:sizeof(args) atIndex:0];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+            [encoder setBuffer:id_src1 offset:0 atIndex:2];
            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];

            const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
@@ -2256,7 +2429,9 @@ static bool ggml_metal_encode_node(
            GGML_ASSERT(ggml_is_contiguous(src0));

            float scale;
-            memcpy(&scale, dst->op_params, sizeof(scale));
+            float bias;
+            memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float));
+            memcpy(&bias,  ((const int32_t *) dst->op_params) + 1, sizeof(float));

            int64_t n = ggml_nelements(dst);

@@ -2273,6 +2448,7 @@ static bool ggml_metal_encode_node(
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
            [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+            [encoder setBytes:&bias  length:sizeof(bias)  atIndex:3];

            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
        } break;
@@ -2436,6 +2612,78 @@ static bool ggml_metal_encode_node(

                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                } break;
+            case GGML_UNARY_OP_ABS:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_UNARY_OP_SGN:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_UNARY_OP_STEP:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_UNARY_OP_HARDSWISH:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_UNARY_OP_HARDSIGMOID:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_UNARY_OP_EXP:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
            default:
                {
                    GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
@@ -2671,7 +2919,7 @@ static bool ggml_metal_encode_node(
                id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
                if (!h_src0) {
                    GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
-                    return false;
+                    return 0;
                }

                offs_src0 = 0;
@@ -2893,6 +3141,7 @@ static bool ggml_metal_encode_node(
                /*.n_group      =*/ n_group,
                /*.n_seq_tokens =*/ n_seq_tokens,
                /*.n_seqs       =*/ n_seqs,
+                /*.s_off        =*/ ggml_nelements(src1) * sizeof(float),
                /*.nb01         =*/ nb01,
                /*.nb02         =*/ nb02,
                /*.nb03         =*/ nb03,
@@ -2921,12 +3170,22 @@ static bool ggml_metal_encode_node(
            [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
            [encoder setBytes:&args length:sizeof(args) atIndex:8];

+            // One shared memory bucket for each simd group in the threadgroup
+            // NOTE: Metal kernels require the buffer size to be multiple of 16 bytes
+            // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+            if (d_state >= 32) {
+                GGML_ASSERT((int64_t)(d_state / 32) <= 32);
+                const int64_t shmem_size = 32;
+                GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
+                [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
+            }
+
            if (ne30 == 1) {
                // Mamba-2
-                [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
            } else {
                GGML_ASSERT(d_inner == 1);
-                [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
            }
        } break;
    case GGML_OP_RWKV_WKV6:
@@ -3547,7 +3806,7 @@ static bool ggml_metal_encode_node(
                id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
                if (!h_src1) {
                    GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
-                    return false;
+                    return 0;
                }

                const int64_t neh0 = ne0;
@@ -3563,7 +3822,7 @@ static bool ggml_metal_encode_node(
                id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
                if (!h_dst) {
                    GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
-                    return false;
+                    return 0;
                }

                // tokens per expert
@@ -3571,7 +3830,7 @@ static bool ggml_metal_encode_node(
                id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
                if (!h_tpe) {
                    GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
-                    return false;
+                    return 0;
                }

                // id map
@@ -3580,7 +3839,7 @@ static bool ggml_metal_encode_node(
                id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
                if (!h_ids) {
                    GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
-                    return false;
+                    return 0;
                }

                {
@@ -4012,12 +4271,95 @@ static bool ggml_metal_encode_node(
        case GGML_OP_RMS_NORM:
            {
                GGML_ASSERT(ne00 % 4 == 0);
-                GGML_ASSERT(ggml_is_contiguous_1(src0));
+                GGML_ASSERT(ggml_is_contiguous_rows(src0));

                float eps;
                memcpy(&eps, dst->op_params, sizeof(float));

-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+                ggml_metal_kargs_rms_norm args = {
+                    /*.ne00   =*/ ne00,
+                    /*.ne00_4 =*/ ne00/4,
+                    /*.nb1    =*/ nb1,
+                    /*.nb2    =*/ nb2,
+                    /*.nb3    =*/ nb3,
+                    /*.eps    =*/ eps,
+                    /*.nef1   =*/ { ne01 },
+                    /*.nef2   =*/ { ne02 },
+                    /*.nef3   =*/ { ne03 },
+                    /*.nbf1   =*/ { nb01 },
+                    /*.nbf2   =*/ { nb02 },
+                    /*.nbf3   =*/ { nb03 },
+                };
+
+                size_t offs_fuse[2] = { 0, 0 };
+                id<MTLBuffer> id_fuse[2] = { id_src0, id_src0 };
+
+                // d[0] = rms_norm(a)
+                // d[1] = mul(d[0], b)
+                // d[2] = add(d[1], c)
+                if (ctx_dev->use_fusion) {
+                    ops[0] = GGML_OP_RMS_NORM;
+                    ops[1] = GGML_OP_MUL;
+                    ops[2] = GGML_OP_ADD;
+
+                    for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
+                        if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) {
+                            break;
+                        }
+
+                        if (!ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse + 1]->type != GGML_TYPE_F32) {
+                            break;
+                        }
+
+                        ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
+
+                        id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]);
+
+                        args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1];
+                        args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2];
+                        args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3];
+
+                        args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1];
+                        args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2];
+                        args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3];
+                    }
+
+                    ++n_fuse;
+
+                    if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
+                        if (n_fuse == 2) {
+                            GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
+                        }
+                        if (n_fuse == 3) {
+                            GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
+                        }
+                    }
+                }
+
+                if (n_fuse > 1) {
+                    id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+                }
+
+                id<MTLComputePipelineState> pipeline;
+
+                switch (n_fuse) {
+                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM        ].pipeline; break;
+                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL    ].pipeline; break;
+                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
+                    default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
+                }

                int nth = 32; // SIMD width

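The same look-ahead idea, specialized: RMS_NORM can absorb a following MUL, and then an ADD, which is exactly the pre-norm "normalize, scale by weight, add bias" idiom. In graph-building terms (a sketch; cur/w/b are illustrative tensor names):

// cur = ggml_rms_norm(ctx, cur, eps);   // n_fuse == 1 stops here -> RMS_NORM kernel
// cur = ggml_mul(ctx, cur, w);          // n_fuse == 2 -> RMS_NORM_MUL kernel
// cur = ggml_add(ctx, cur, b);          // n_fuse == 3 -> RMS_NORM_MUL_ADD kernel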
@@ -4028,23 +4370,16 @@ static bool ggml_metal_encode_node(
                nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
                nth = MIN(nth, ne00/4);

-                ggml_metal_kargs_rms_norm args = {
-                    /*.ne00   =*/ ne00,
-                    /*.ne00_4 =*/ ne00/4,
-                    /*.nb01   =*/ nb01,
-                    /*.eps    =*/ eps,
-                };
-
                [encoder setComputePipelineState:pipeline];
-                [encoder setBytes:&args length:sizeof(args) atIndex:0];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0    offset:offs_src0    atIndex:1];
+                [encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2];
+                [encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3];
+                [encoder setBuffer:id_dst     offset:offs_dst     atIndex:4];

                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

-                const int64_t nrows = ggml_nrows(src0);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
            } break;
        case GGML_OP_L2_NORM:
            {
@@ -5439,7 +5774,7 @@ static bool ggml_metal_encode_node(
            }
        }

-    return true;
+    return n_fuse;
}

static enum ggml_status ggml_metal_graph_compute(
@@ -5945,20 +6280,26 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
                struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
                ggml_metal_mem_pool_reset(mem_pool);

-                for (int idx = node_start; idx < node_end; ++idx) {
+                for (int idx = node_start; idx < node_end;) {
                    if (should_capture) {
                        [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
                    }

-                    const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
+                    const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
+                    if (idx + res > node_end) {
+                        GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
+                                "https://github.com/ggml-org/llama.cpp/pull/14849");
+                    }

                    if (should_capture) {
                        [encoder popDebugGroup];
                    }

-                    if (!res) {
+                    if (res == 0) {
                        break;
                    }
+
+                    idx += res;
                }

                [encoder endEncoding];
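Net effect of this loop change: ggml_metal_encode_node now returns how many graph nodes it consumed, and the driver advances idx by that amount instead of a fixed ++idx. Sketch of the contract implied by the code above:

// res == 0  -> encoding failed (e.g. mem-pool allocation), stop this command buffer
// res >= 1  -> res consecutive nodes were encoded as one dispatch; idx += res
// idx + res must stay within node_end, otherwise fusion crossed an encoder split
//             (asserted above as a bug in the fusion logic)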
@@ -832,7 +832,8 @@ enum ggml_sort_order {
 // general-purpose kernel for addition, subtraction, multiplication and division of two tensors
 // pros: works for non-contiguous tensors, supports broadcast across all dims
 // cons: not very efficient
-kernel void kernel_add(
+template <int F>
+kernel void kernel_add_fuse_impl(
         constant ggml_metal_kargs_bin & args,
         device const char * src0,
         device const char * src1,
@@ -848,16 +849,39 @@ kernel void kernel_add(
     const int i12 = i02%args.ne12;
     const int i11 = i01%args.ne11;

-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
+    device       float * dst_ptr  = (device       float *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs);
+
+    device const float * src1_ptr[F];
+    for (short j = 0; j < F; ++j) {
+        src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
+    }

     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
         const int i10 = i0%args.ne10;
-        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10));
+
+        float res = src0_ptr[i0];
+
+#pragma unroll
+        for (short j = 0; j < F; ++j) {
+            res += src1_ptr[j][i10];
+        }
+
+        dst_ptr[i0] = res;
     }
 }

+typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
+
+template [[host_name("kernel_add")]]        kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
+template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
+template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
+template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
+template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
+template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
+template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
+template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;

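Editorial note: each fused variant folds F consecutive adds that share the same src0 row into one pass, with the per-source byte offsets carried in args.o1[]. A scalar C model of the inner loop (the function name fused_add and the o1 array here are illustrative only, not part of the backend API):

    #include <stddef.h>

    /* Scalar model: dst[i0] = src0[i0] + sum_j src1_j[i10], where each src1_j
     * starts at byte offset o1[j] inside the shared src1 buffer. */
    static float fused_add(const float *src0, const char *src1,
                           const size_t *o1, int F, int i0, int i10) {
        float res = src0[i0];
        for (int j = 0; j < F; ++j) {
            res += ((const float *)(src1 + o1[j]))[i10];
        }
        return res;
    }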
 kernel void kernel_sub(
         constant ggml_metal_kargs_bin & args,
         device const char * src0,
@@ -875,7 +899,7 @@ kernel void kernel_sub(
     const int i11 = i01%args.ne11;

     device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
     device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;

     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
@@ -900,9 +924,9 @@ kernel void kernel_mul(
     const int i12 = i02%args.ne12;
     const int i11 = i01%args.ne11;

-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1;
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;

     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
         const int i10 = i0%args.ne10;
@@ -926,9 +950,9 @@ kernel void kernel_div(
     const int i12 = i02%args.ne12;
     const int i11 = i01%args.ne11;

-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1;
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;

     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
         const int i10 = i0%args.ne10;
@@ -970,60 +994,161 @@ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat

 // assumption: src1 is a row
 // broadcast src1 into src0
-kernel void kernel_add_row(
+template <short F>
+kernel void kernel_add_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {

     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] + src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 * dst_row  = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res += src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }

-kernel void kernel_sub_row(
+typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
+
+template [[host_name("kernel_add_row_c4")]]        kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
+template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
+template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
+template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
+template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
+template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
+template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
+template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
+
+template <short F>
+kernel void kernel_sub_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {

     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] - src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 * dst_row  = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res -= src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }

-kernel void kernel_mul_row(
+typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
+
+template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
+
+template <short F>
+kernel void kernel_mul_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {

     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] * src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 * dst_row  = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res *= src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }

-kernel void kernel_div_row(
+typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
+
+template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
+
+template <short F>
+kernel void kernel_div_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {

     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] / src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 * dst_row  = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res /= src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }

+typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
+
+template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;

 kernel void kernel_scale(
         device const float * src0,
         device       float * dst,
         constant     float & scale,
+        constant     float & bias,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * scale;
+    dst[tpig] = src0[tpig] * scale + bias;
 }

 kernel void kernel_scale_4(
         device const float4 * src0,
         device       float4 * dst,
         constant      float & scale,
+        constant      float & bias,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * scale;
+    dst[tpig] = src0[tpig] * scale + bias;
 }

 kernel void kernel_clamp(
@@ -1197,6 +1322,51 @@ kernel void kernel_neg(
     dst[tpig] = -src0[tpig];
 }

+kernel void kernel_abs(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = fabs(src0[tpig]);
+}
+
+kernel void kernel_sgn(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = (x > 0.0f) ? 1.0f : ((x < 0.0f) ? -1.0f : 0.0f);
+}
+
+kernel void kernel_step(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] > 0.0f ? 1.0f : 0.0f;
+}
+
+kernel void kernel_hardswish(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
+}
+
+kernel void kernel_hardsigmoid(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
+}
+
+kernel void kernel_exp(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = exp(src0[tpig]);
+}
+
 kernel void kernel_reglu(
         device const char * src0,
         device const char * src1,
@@ -1653,10 +1823,16 @@ kernel void kernel_ssm_scan_f32(
         device const void * src5,
         device const void * src6,
         device       float * dst,
+        threadgroup float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3 ntg[[threads_per_threadgroup]]) {
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgptg[[simdgroups_per_threadgroup]],
+        uint3  tgpg[[threadgroups_per_grid]]) {
+
+    const int64_t i0 = tpitg.x;
+    const int64_t i1 = 0;
     const int64_t ir = tgpig.x; // current head
     const int64_t i3 = tgpig.y; // current seq
@@ -1671,41 +1847,88 @@ kernel void kernel_ssm_scan_f32(
     const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;

-    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+    const int64_t s_off = args.s_off;

     device const int32_t * ids = (device const int32_t *) src6;

-    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-    device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
+    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
+    const int64_t i = i0 + i1*nc;
+    float s0 = s0_buff[i];
+    float s  = s_buff[i];
+
+    device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31);
+    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+    device const float * B_block  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+    device const float * C_block  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device       float * y_block  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);

     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
-        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+        device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
+        device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52); // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}

         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
-        float sumf = 0.0f;

-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            const int64_t i = i0 + i1*nc;
-            const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
-        }
+        const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
+        s = state;
+
+        // Parallel sum: This relies on the fact that this kernel will be
+        // dispatched with each threadgroup having (d_state, 1, 1) threads which
+        // are subdivided into SIMD groups of size `sgptg`. The goal is to
+        // compute y = sum({state * C[i] for i in range(d_state)}).
+        // To parallelize this effectively, we first use simd_sum over each SIMD
+        // group to compute the sum of each SIMD group, then place the result in
+        // the SIMD group's indexed bucket in the shared memory. We then sum
+        // over the individual group sums to compute the final sum.
+
+        // Computed for each thread
+        float sumf = state * C[i0];
+
+        // Sum the threads in the simd group => simd sum
+        sumf = simd_sum(sumf);
+
+        if (sgptg > 1) {
+
+            // Once per simd group, place the group sum into the shared buffer
+            if (tiisg == 0) {
+                shared[sgitg] = sumf;
+            }
+
+            // Wait for all threads in the threadgroup to reach this point. This
+            // ensures that all elements of the shared buffer are populated with the
+            // sum of the individual simd groups.
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            // For simd group 0 at indices < num simd groups, extract the shared
+            // simd sum
+            sumf = 0.0f;
+            if (sgitg == 0) {
+                if (tiisg < sgptg) {
+                    sumf = shared[tiisg];
+                }
+                sumf = simd_sum(sumf);
+                if (tiisg == 0) {
+                    y[0] = sumf;
+                }
+            }
+        } else if (tiisg == 0) {
+            y[0] = sumf;
+        }

-        y[0] = sumf;
+        // recurse
+        s0 = s;
     }

+    // Assign the final state to the output buffer
+    s_buff[i] = s;
 }

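Editorial note: the comment block in the kernel above describes a two-level reduction — simd_sum within each SIMD group, then a second simd_sum over the per-group partials staged in threadgroup memory. A serial C model of the same shape (lane and group counts are illustrative only):

    #include <stdio.h>

    #define SIMD_WIDTH 32
    #define N_GROUPS    4

    /* Serial model of the kernel's two-level reduction: level 1 sums each
     * "SIMD group", level 2 sums the per-group partials. */
    int main(void) {
        float lanes[SIMD_WIDTH * N_GROUPS];
        for (int i = 0; i < SIMD_WIDTH * N_GROUPS; ++i) lanes[i] = 1.0f;

        float shared[N_GROUPS]; /* stand-in for the threadgroup buffer */
        for (int g = 0; g < N_GROUPS; ++g) {   /* simd_sum per group */
            float partial = 0.0f;
            for (int l = 0; l < SIMD_WIDTH; ++l) partial += lanes[g*SIMD_WIDTH + l];
            shared[g] = partial;
        }

        float sumf = 0.0f;                     /* second-level simd_sum */
        for (int g = 0; g < N_GROUPS; ++g) sumf += shared[g];

        printf("y = %f\n", sumf); /* 128.0 */
        return 0;
    }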
 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
-// TODO: optimize (e.g. by parallelizing over d_state)
 kernel void kernel_ssm_scan_f32_group(
         device const void * src0,
         device const void * src1,
@@ -1715,10 +1938,16 @@ kernel void kernel_ssm_scan_f32_group(
         device const void * src5,
         device const void * src6,
         device       float * dst,
+        threadgroup float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3 ntg[[threads_per_threadgroup]]) {
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgptg[[simdgroups_per_threadgroup]],
+        uint3  tgpg[[threadgroups_per_grid]]) {
+
+    const int64_t i0 = tpitg.x;
+    const int64_t i1 = tgpig.x;
     const int64_t ir = tgpig.y; // current head
     const int64_t i3 = tgpig.z; // current seq
@@ -1733,38 +1962,81 @@ kernel void kernel_ssm_scan_f32_group(
     const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;

-    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+    const int64_t s_off = args.s_off;

     device const int32_t * ids = (device const int32_t *) src6;

-    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-    device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
+    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
+    const int64_t i = i0 + i1*nc;
+    float s0 = s0_buff[i];
+    float s  = s_buff[i];
+
+    device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
+    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+    device const float * B_block  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+    device const float * C_block  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device       float * y_block  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);

     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
-        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+        device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
+        device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52); // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}

         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
         const float dA = exp(dt_soft_plus * A[0]);
-        float sumf = 0.0f;

-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            const int64_t i = i0 + i1*nc;
-            const float state = (s0[i] * dA) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
-        }
+        const float state = (s0 * dA) + (B[i0] * x_dt);
+        s = state;
+
+        // Parallel sum: This relies on the fact that this kernel will be
+        // dispatched with each threadgroup having (d_state, 1, 1) threads which
+        // are subdivided into SIMD groups of size `sgptg`. The goal is to
+        // compute y = sum({state * C[i] for i in range(d_state)}).
+        // To parallelize this effectively, we first use simd_sum over each SIMD
+        // group to compute the sum of each SIMD group, then place the result in
+        // the SIMD group's indexed bucket in the shared memory. We then sum
+        // over the individual group sums to compute the final sum.
+
+        // Computed for each thread
+        float sumf = state * C[i0];
+
+        // Sum the threads in the simd group => simd sum
+        sumf = simd_sum(sumf);
+
+        // Once per simd group, place the group sum into the shared buffer
+        if (tiisg == 0) {
+            shared[sgitg] = sumf;
+        }

-        y[0] = sumf;
+        // Wait for all threads in the threadgroup to reach this point. This
+        // ensures that all elements of the shared buffer are populated with the
+        // sum of the individual simd groups.
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // For simd group 0 at indices < num simd groups, extract the shared
+        // simd sum
+        sumf = 0.0f;
+        if (sgitg == 0) {
+            if (tiisg < sgptg) {
+                sumf = shared[tiisg];
+            }
+            sumf = simd_sum(sumf);
+            if (tiisg == 0) {
+                y[0] = sumf;
+            }
+        }
+
+        // recurse
+        s0 = s;
     }

+    // Assign the final state to the output buffer
+    s_buff[i] = s;
 }

 kernel void kernel_rwkv_wkv6_f32(
@@ -2069,26 +2341,39 @@ kernel void kernel_norm(
     }
 }

-kernel void kernel_rms_norm(
+// F == 1 : rms_norm (no fuse)
+// F == 2 : rms_norm + mul
+// F == 3 : rms_norm + mul + add
+template <short F>
+kernel void kernel_rms_norm_fuse_impl(
         constant ggml_metal_kargs_rms_norm & args,
         device const char * src0,
+        device const char * src1_0,
+        device const char * src1_1,
         device       char * dst,
         threadgroup float * shmem_f32 [[threadgroup(0)]],
-        uint   tgpig[[threadgroup_position_in_grid]],
-        ushort tpitg[[thread_position_in_threadgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort   ntg[[threads_per_threadgroup]]) {
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
     if (sgitg == 0) {
         shmem_f32[tiisg] = 0.0f;
     }

-    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+    const int i01 = tgpig.x;
+    const int i02 = tgpig.y;
+    const int i03 = tgpig.z;
+
+    device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
+
+    device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
+    device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);

     float sumf = 0.0f;

     // parallel sum
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
         sumf += dot(x[i00], x[i00]);
     }
     sumf = simd_sum(sumf);
@@ -2107,12 +2392,26 @@ kernel void kernel_rms_norm(
     const float mean  = sumf/args.ne00;
     const float scale = 1.0f/sqrt(mean + args.eps);

-    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
-        y[i00] = x[i00] * scale;
+    device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
+    for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
+        if (F == 1) {
+            y[i00] = (x[i00]*scale);
+        }
+        if (F == 2) {
+            y[i00] = (x[i00]*scale)*f0[i00];
+        }
+        if (F == 3) {
+            y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
+        }
     }
 }

+typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t;
+
+template [[host_name("kernel_rms_norm")]]         kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>;
+template [[host_name("kernel_rms_norm_mul")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>;
+template [[host_name("kernel_rms_norm_mul_add")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>;

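Editorial note: as a math summary of the three instantiations above (nothing beyond what the kernel computes) — per row x, with multiplier g = src1_0 and addend b = src1_1 broadcast according to the nef/nbf strides:

\[
\hat{x} = \frac{x}{\sqrt{\tfrac{1}{n}\sum_{i} x_i^2 + \epsilon}}, \qquad
y =
\begin{cases}
\hat{x} & F = 1 \\
\hat{x} \odot g & F = 2 \\
\hat{x} \odot g + b & F = 3
\end{cases}
\]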
 kernel void kernel_l2_norm(
         constant ggml_metal_kargs_l2_norm & args,
         device const char * src0,

@@ -34,8 +34,12 @@ if (MUSAToolkit_FOUND)
     list(APPEND GGML_SOURCES_MUSA ${SRCS})
     file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_MUSA ${SRCS})
-    file(GLOB   SRCS "../ggml-musa/*.cu")
-    list(APPEND GGML_SOURCES_MUSA ${SRCS})
+
+    if (GGML_MUSA_MUDNN_COPY)
+        file(GLOB   SRCS "../ggml-musa/*.cu")
+        list(APPEND GGML_SOURCES_MUSA ${SRCS})
+        add_compile_definitions(GGML_MUSA_MUDNN_COPY)
+    endif()

     if (GGML_CUDA_FA_ALL_QUANTS)
         file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
@@ -72,6 +76,10 @@ if (MUSAToolkit_FOUND)
     add_compile_definitions(GGML_USE_MUSA)
     add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})

+    if (GGML_MUSA_GRAPHS)
+        add_compile_definitions(GGML_MUSA_GRAPHS)
+    endif()
+
     if (GGML_CUDA_FORCE_MMQ)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()
@@ -97,10 +105,16 @@ if (MUSAToolkit_FOUND)
     endif()

     if (GGML_STATIC)
-        # TODO: mudnn has not provided static libraries yet
         target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static)
+        # TODO: mudnn has not provided static libraries yet
+        # if (GGML_MUSA_MUDNN_COPY)
+        #     target_link_libraries(ggml-musa PRIVATE mudnn_static)
+        # endif()
     else()
-        target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas mudnn)
+        target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas)
+        if (GGML_MUSA_MUDNN_COPY)
+            target_link_libraries(ggml-musa PRIVATE mudnn)
+        endif()
     endif()

     if (GGML_CUDA_NO_VMM)

@@ -88,6 +88,7 @@ set(GGML_OPENCL_KERNELS
     rms_norm
     rope
     scale
+    set_rows
     sigmoid
     silu
     softmax_4_f32
@@ -103,6 +104,9 @@ set(GGML_OPENCL_KERNELS
     tanh
     pad
     repeat
+    mul_mat_f16_f32
+    conv2d
+    conv2d_f16_f32
 )

 foreach (K ${GGML_OPENCL_KERNELS})

@@ -333,6 +333,7 @@ struct ggml_backend_opencl_context {
     size_t max_alloc_size;
     bool fp16_support;
     bool has_vector_subgroup_broadcast;
+    bool disable_fusion;
     ggml_cl_compiler_version adreno_cl_compiler_version;

     int adreno_wave_size;
@@ -351,6 +352,7 @@ struct ggml_backend_opencl_context {
     cl_program program_gemv_noshuffle_general;
     cl_program program_gemv_noshuffle;
     cl_program program_get_rows;
+    cl_program program_set_rows;
     cl_program program_glu;
     cl_program program_im2col_f16;
     cl_program program_im2col_f32;
@@ -367,6 +369,7 @@ struct ggml_backend_opencl_context {
     cl_program program_mul_mv_f16_f32;
     cl_program program_mul_mv_f32_f32;
     cl_program program_mul;
+    cl_program program_mul_mat_f16_f32_tiled;
     cl_program program_div;
     cl_program program_sub;
     cl_program program_norm;
@@ -388,6 +391,9 @@ struct ggml_backend_opencl_context {
     cl_program program_tanh;
     cl_program program_upscale;
     cl_program program_concat;
+    cl_program program_conv_2d_f16;
+    cl_program program_conv_2d_f32;
+    cl_program program_conv_2d_f16_f32;
     cl_program program_tsembd;
     cl_program program_mul_mv_id_q4_0_f32_8x_flat;

@@ -406,12 +412,13 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
         kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
     cl_kernel kernel_norm;
-    cl_kernel kernel_rms_norm;
+    cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
     cl_kernel kernel_group_norm;
     cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
     cl_kernel kernel_soft_max, kernel_soft_max_4;
     cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
     cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
+    cl_kernel kernel_set_rows_f32, kernel_set_rows_f16;
     cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
     cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16;
     cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32;
@@ -420,6 +427,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul_mat_f16_f32_1row;
     cl_kernel kernel_mul_mat_f16_f32;
     cl_kernel kernel_mul_mat_f16_f32_l4;
+    cl_kernel kernel_mul_mat_f16_f32_tiled;
     cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
     cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
     cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
@@ -437,6 +445,9 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_upscale_bilinear;
     cl_kernel kernel_concat_f32_contiguous;
     cl_kernel kernel_concat_f32_non_contiguous;
+    cl_kernel kernel_conv_2d_f16;
+    cl_kernel kernel_conv_2d_f32;
+    cl_kernel kernel_conv_2d_f16_f32;
     cl_kernel kernel_timestep_embedding;
     cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;

@@ -529,6 +540,16 @@ struct ggml_backend_opencl_context {
         fclose(ftrace);
     }

+    size_t get_kernel_workgroup_size(cl_kernel kernel) const {
+        size_t workgroup_size = 0;
+        size_t ret_size = 0;
+        CL_CHECK(
+            clGetKernelWorkGroupInfo(kernel, device, CL_KERNEL_WORK_GROUP_SIZE,
+                sizeof(size_t), &workgroup_size, &ret_size));
+        GGML_ASSERT(sizeof(size_t) == ret_size);
+        return workgroup_size;
+    }
+
     void enqueue_ndrange_kernel(cl_kernel kernel, cl_uint work_dim, size_t *global_work_size, size_t *local_work_size, const ggml_tensor * tensor) {
 #ifdef GGML_OPENCL_PROFILING
         cl_event evt;

@@ -1003,6 +1024,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }

+    // mul_mat_f16_f32_tiled
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mat_f16_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mat_f16_f32.cl");
+#endif
+        backend_ctx->program_mul_mat_f16_f32_tiled =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_tiled = clCreateKernel(backend_ctx->program_mul_mat_f16_f32_tiled, "mul_mat_f16_f32", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
     // mul
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1064,7 +1101,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_rms_norm =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);

-        CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_rms_norm     = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_rms_norm_mul = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm_mul", &err), err));
         GGML_LOG_CONT(".");
     }

@@ -1431,6 +1469,64 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         }
     }

+    // set_rows
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "set_rows.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("set_rows.cl");
+#endif
+        backend_ctx->program_set_rows =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_set_rows_f32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_set_rows_f16 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f16", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // conv2d
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "conv2d.cl.h"
+        };
+        const std::string kernel_src_f16_f32 {
+            #include "conv2d_f16_f32.cl.h"
+        };
+#else
+        const std::string kernel_src         = read_file("conv2d.cl");
+        const std::string kernel_src_f16_f32 = read_file("conv2d_f16_f32.cl");
+#endif
+        if (!kernel_src.empty()) {
+            backend_ctx->program_conv_2d_f16 =
+                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), (std::string(compile_opts) + " -DUSE_FP16=1").c_str());
+            CL_CHECK((backend_ctx->kernel_conv_2d_f16 = clCreateKernel(backend_ctx->program_conv_2d_f16, "kernel_conv_2d", &err), err));
+            GGML_LOG_CONT(".");
+            backend_ctx->program_conv_2d_f32 =
+                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+            CL_CHECK((backend_ctx->kernel_conv_2d_f32 = clCreateKernel(backend_ctx->program_conv_2d_f32, "kernel_conv_2d", &err), err));
+            GGML_LOG_CONT(".");
+        } else {
+            GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n");
+            backend_ctx->program_conv_2d_f16 = nullptr;
+            backend_ctx->kernel_conv_2d_f16  = nullptr;
+            backend_ctx->program_conv_2d_f32 = nullptr;
+            backend_ctx->kernel_conv_2d_f32  = nullptr;
+        }
+        if (!kernel_src_f16_f32.empty()) {
+            backend_ctx->program_conv_2d_f16_f32 =
+                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16_f32.c_str(), compile_opts);
+            CL_CHECK((backend_ctx->kernel_conv_2d_f16_f32 = clCreateKernel(backend_ctx->program_conv_2d_f16_f32, "kernel_conv_2d", &err), err));
+            GGML_LOG_CONT(".");
+        } else {
+            GGML_LOG_WARN("ggml_opencl: conv2d_f16_f32 kernel source not found or empty. This op will not be available.\n");
+            backend_ctx->program_conv_2d_f16_f32 = nullptr;
+            backend_ctx->kernel_conv_2d_f16_f32  = nullptr;
+        }
+    }
+
     // mul_mv_id_q4_0_f32_8x_flat
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2016,6 +2112,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
     CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err));
 #endif // GGML_OPENCL_USE_ADRENO_KERNELS

+    backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr;
+
     dev_ctx->backend_ctx = backend_ctx.release();
     return dev_ctx->backend_ctx;
 }
@@ -2185,7 +2283,45 @@ static void sync_with_other_backends(ggml_backend_t backend) {
     sync_with_other_backends(backend_ctx);
 }

+static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul      = cgraph->nodes[node_idx+1];
+
+        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+
+        // rms_norm only supports f32
+        if (mul->src[0]->type != GGML_TYPE_F32 ||
+            mul->src[1]->type != GGML_TYPE_F32 ||
+            mul->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        // if rms_norm is the B operand, then we don't handle broadcast
+        if (rms_norm == mul->src[1] &&
+            !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+            return false;
+        }
+
+        // rms_norm assumes contiguous rows
+        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
+
 static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];

@@ -2198,6 +2334,12 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
             continue;
         }

+        if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+            ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
+            i++;
+            continue;
+        }
+
         bool ok = ggml_cl_compute_forward(backend, node);
         if (!ok) {
             GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
@@ -2233,8 +2375,18 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             {
-                // TODO: add support
-                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
-                return false;
-            } break;
+#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
+                if (op->src[0]->type != GGML_TYPE_F32) {
+                    return false;
+                }
+                switch (op->type) {
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_F32:
+                        return true;
+                    default:
+                        return false;
+                }
+            }
         case GGML_OP_CPY:
         case GGML_OP_DUP:
         case GGML_OP_CONT:
@@ -2304,6 +2456,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                    op->src[0]->ne[3] == 1 && op->ne[3] == 1;
         case GGML_OP_UPSCALE:
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
+        case GGML_OP_CONV_2D:
+            return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
+                   (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
+                   (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
         case GGML_OP_CONCAT:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         case GGML_OP_TIMESTEP_EMBEDDING:

@@ -3374,6 +3530,111 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }

+static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    // ne0 = ne00
+    // ne2 = ne02
+    // ne3 = ne03
+
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+
+    const cl_ulong nb10 = src1->nb[0];
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+
+    const int ne0 = dst->ne[0];
+
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];
+
+    const int nblk0 = ne0/ggml_blck_size(dst->type);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_kernel kernel;
+
+    switch (dst->type) {
+        case GGML_TYPE_F32:
+            kernel = backend_ctx->kernel_set_rows_f32;
+            break;
+        case GGML_TYPE_F16:
+            kernel = backend_ctx->kernel_set_rows_f16;
+            break;
+        default:
+            GGML_ABORT("not implemented");
+    }
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne01));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne11));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb10));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &nblk0));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb3));
+
+    int nth0 = 64;
+    if (backend_ctx->gpu_family == INTEL) {
+        nth0 = 32;
+    } else if (backend_ctx->gpu_family == ADRENO) {
+        nth0 = 64;
+    }
+
+    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    while (nth0 < nblk0 && nth0 < max_workgroup_size) {
+        nth0 *= 2;
+    }
+
+    int rows_per_workgroup = 1;
+    if (nth0 > nblk0) {
+        rows_per_workgroup = nth0 / nblk0;
+        nth0 = nblk0;
+    }
+
+    size_t global_work_size[] = {
+        (size_t)(ne01 + rows_per_workgroup - 1)/rows_per_workgroup*nth0,
+        (size_t)ne02*rows_per_workgroup,
+        (size_t)ne03};
+    size_t local_work_size[] = {(size_t)nth0, (size_t)rows_per_workgroup, 1};
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+}

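Editorial note: the sizing logic above first grows nth0 until it covers one row of blocks (or hits the kernel's work-group limit), then packs several rows into one work-group when a row is narrower than nth0. A worked model in C (the numbers are made-up examples):

    #include <stdio.h>

    int main(void) {
        int nblk0 = 24;               /* blocks per row (example value)   */
        int nth0  = 64;               /* starting work-group width        */
        int max_workgroup_size = 256; /* kernel limit (example value)     */

        while (nth0 < nblk0 && nth0 < max_workgroup_size) {
            nth0 *= 2;
        }

        int rows_per_workgroup = 1;
        if (nth0 > nblk0) {
            rows_per_workgroup = nth0 / nblk0; /* 64/24 = 2 rows per group */
            nth0 = nblk0;                      /* 24 threads per row       */
        }

        printf("nth0=%d rows_per_workgroup=%d\n", nth0, rows_per_workgroup);
        return 0;
    }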
 static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
@@ -4242,6 +4503,117 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }

+static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor) {
+    GGML_ASSERT(mul_tensor);
+    GGML_ASSERT(rms_norm_tensor);
+
+    // src0 is the src of rms_norm, src1 is the other src of mul (one being rms_norm)
+    const ggml_tensor * src0 = rms_norm_tensor->src[0];
+    const ggml_tensor * src1;
+    if (mul_tensor->src[0] == rms_norm_tensor) {
+        src1 = mul_tensor->src[1];
+    } else if (mul_tensor->src[1] == rms_norm_tensor) {
+        src1 = mul_tensor->src[0];
+    } else {
+        GGML_ASSERT(false && "Invalid args for rms_norm and mul");
+    }
+    const ggml_tensor * dst = mul_tensor;
+
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    float eps;
+    memcpy(&eps, rms_norm_tensor->op_params, sizeof(float));
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb13 = src1->nb[3];
+
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];
+
+    GGML_ASSERT(ne00 % 4 == 0);
+
+    size_t sgs;
+    if (backend_ctx->gpu_family == ADRENO) {
+        sgs = 64;
+    } else if (backend_ctx->gpu_family == INTEL) {
+        sgs = 32;
+    } else {
+        GGML_ASSERT(false && "Unsupported GPU");
+    }
+
+    cl_kernel kernel = backend_ctx->kernel_rms_norm_mul;
+
+    int nth = sgs;
+    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    while (nth < ne00 && nth < max_workgroup_size) {
+        nth *= 2;
+    }
+    nth = MIN(nth, max_workgroup_size);
+    nth = MIN(nth, ne00);
+
+    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t local_work_size[]  = {(size_t)nth, 1, 1};
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne11));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float),    &eps));
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL));
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+}

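Editorial note: the fused launch above sizes one work-group per row — nth starts at the subgroup size and doubles until it covers ne00 or hits the kernel's work-group limit, and the final argument reserves one float of local memory per subgroup. A worked check in C with illustrative numbers:

    #include <stdio.h>

    int main(void) {
        const int sgs  = 32;              /* subgroup size (illustrative) */
        const int ne00 = 4096;            /* row width (illustrative)     */
        const int max_workgroup_size = 256;

        int nth = sgs;
        while (nth < ne00 && nth < max_workgroup_size) {
            nth *= 2;                     /* 32 -> 64 -> 128 -> 256 */
        }
        if (nth > max_workgroup_size) nth = max_workgroup_size;
        if (nth > ne00)               nth = ne00;

        printf("nth=%d local floats=%d\n", nth, nth/sgs); /* 256, 8 */
        return 0;
    }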
 static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
@@ -4784,6 +5156,134 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
 }

+static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    const int M = src0->ne[1];
+    const int N = src1->ne[1];
+    const int K = src0->ne[0];
+
+    cl_kernel kernel = backend_ctx->kernel_mul_mat_f16_f32_tiled;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(int),      &M));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int),      &N));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),      &K));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd));
+
+    // Tiling parameters. These need to be tuned for optimal performance.
+    // They must match the #defines in the kernel mul_mat_f16_f32.cl.
+    //
+    // OPWM / OPWN: Output tile size per Work-Group. A work-group computes a tile of size OPWM x OPWN.
+    // TPWM / TPWN: Threads per Work-group. This is the work-group size.
+    // OPTM / OPTN: Output elements per Thread. Each thread computes OPTM x OPTN elements.
+    //
+    // The following relationships must hold:
+    //   OPWM = TPWM * OPTM
+    //   OPWN = TPWN * OPTN
+    //
+    const int OPWM = 64;
+    const int OPWN = 64;
+    const int TPWM = 16;
+    const int TPWN = 8;
+
+    size_t local_work_size[2] = { TPWM, TPWN };
+    size_t global_work_size[2] = {
+        (size_t) ((M + OPWM - 1) / OPWM) * TPWM,
+        (size_t) ((N + OPWN - 1) / OPWN) * TPWN,
+    };
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
+}

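Editorial note: given the constraints OPWM = TPWM*OPTM and OPWN = TPWN*OPTN, the dispatch above launches ceil(M/OPWM) x ceil(N/OPWN) work-groups of TPWM x TPWN threads. A worked check in C with the diff's constants (the M and N values are examples):

    #include <stdio.h>

    int main(void) {
        const int OPWM = 64, OPWN = 64;  /* output tile per work-group  */
        const int TPWM = 16, TPWN = 8;   /* threads per work-group      */
        /* each thread computes OPTM x OPTN outputs: */
        const int OPTM = OPWM / TPWM;    /* 4 */
        const int OPTN = OPWN / TPWN;    /* 8 */

        const int M = 100, N = 50;       /* example matrix sizes */
        const size_t gws0 = (size_t)((M + OPWM - 1) / OPWM) * TPWM; /* 2*16 = 32 */
        const size_t gws1 = (size_t)((N + OPWN - 1) / OPWN) * TPWN; /* 1*8  = 8  */

        printf("OPTM=%d OPTN=%d global=(%zu,%zu)\n", OPTM, OPTN, gws0, gws1);
        return 0;
    }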
+static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_TENSOR_BINARY_OP_LOCALS;
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    const cl_uint Cout = ne03; const cl_uint Cin = ne02; const cl_uint N = ne13;
+    const cl_uint KW = ne00; const cl_uint KH = ne01; const cl_uint W = ne10; const cl_uint H = ne11; const cl_uint OW = ne0; const cl_uint OH = ne1;
+
+    const cl_uint s0 = dst->op_params[0]; const cl_uint s1 = dst->op_params[1];
+    const cl_uint p0 = dst->op_params[2]; const cl_uint p1 = dst->op_params[3];
+    const cl_uint d0 = dst->op_params[4]; const cl_uint d1 = dst->op_params[5];
+
+    const cl_uint cl_nb01 = nb01/ggml_type_size(src0->type); const cl_uint cl_nb02 = nb02/ggml_type_size(src0->type); const cl_uint cl_nb03 = nb03/ggml_type_size(src0->type);
+    const cl_uint cl_nb11 = nb11/ggml_type_size(src1->type); const cl_uint cl_nb12 = nb12/ggml_type_size(src1->type); const cl_uint cl_nb13 = nb13/ggml_type_size(src1->type);
+    const cl_uint cl_nb1 = nb1/ggml_type_size(dst->type); const cl_uint cl_nb2 = nb2/ggml_type_size(dst->type); const cl_uint cl_nb3 = nb3/ggml_type_size(dst->type);
+
+    const int64_t NPQ = (int64_t)N * OW * OH;
+
+    const uint32_t BS_K = 64;
+    const uint32_t BS_NPQ = 64;
+    const uint32_t BS_CRS = 16;
+    const uint32_t VEC_SIZE = 4;
+
+    const uint32_t TS_K = 4;
+    const uint32_t TS_NPQ = 8;
+
+    const uint32_t WG_K = BS_K / TS_K;
+    const uint32_t WG_NPQ = BS_NPQ / TS_NPQ;
+
+    auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (block_size + work_size - 1) / block_size; };
+    const uint32_t NB_K = splitWork(Cout, BS_K);
+    const uint32_t NB_NPQ = splitWork(NPQ, BS_NPQ);
+
+    cl_kernel kernel;
+    size_t shmem_size;
+
+    if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        kernel = backend_ctx->kernel_conv_2d_f16;
+        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_half4));
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        kernel = backend_ctx->kernel_conv_2d_f32;
+        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_float) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+        kernel = backend_ctx->kernel_conv_2d_f16_f32;
+        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
+    } else {
+        GGML_ASSERT(false && "Unsupported data type combination for conv2d");
+    }
+
+    cl_uint idx = 0;
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, idx++, shmem_size, NULL));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cout)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cin)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &N));
|
||||
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cout)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cin)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &N));
|
||||
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KH)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &W)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H));
|
||||
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OH));
|
||||
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p1));
|
||||
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d1));
|
||||
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb01)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb02)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb11)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb12)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb13));
|
||||
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb2)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb3));
|
||||
|
||||
size_t global_work_size[] = { (size_t)NB_K * WG_K, (size_t)NB_NPQ * WG_NPQ, 1 };
|
||||
size_t local_work_size[] = { (size_t)WG_K, (size_t)WG_NPQ, 1 };
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
|
||||
}
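
// Illustrative aside, not part of the original diff: the conv2d kernels are a
// tiled implicit GEMM over rows K = Cout, columns NPQ = N*OH*OW, and reduction
// length CRS = Cin*KH*KW. A quick check of the F16 local-memory budget implied
// by the block sizes above:
//   A tile: BS_K * BS_CRS halves              -> 64 * 16 * 2 = 2048 bytes
//   B tile: BS_CRS * (BS_NPQ/4) half4 values  -> 16 * 16 * 8 = 2048 bytes
// i.e. 4 KiB of local memory per work-group, dispatched on a
// ceil(Cout/BS_K) x ceil(NPQ/BS_NPQ) grid of WG_K x WG_NPQ = 16 x 8 threads.
static_assert(64 * 16 * sizeof(cl_half) + 16 * (64 / 4) * sizeof(cl_half4) == 4096,
              "conv2d F16 tiles are expected to use 4 KiB of local memory");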

static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(src0);
    GGML_ASSERT(src0->extra);
@@ -4797,6 +5297,18 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co

    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

    if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 &&
        src0->ne[1] > 32 && // M > 32
        src1->ne[1] > 32 && // N > 32
        src0->ne[0] > 32 && // K > 32
        src0->ne[2] == 1 && src0->ne[3] == 1 &&
        src1->ne[2] == 1 && src1->ne[3] == 1 &&
        ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
        backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) {
        ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst);
        return;
    }
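
    // Illustrative aside, not part of the original diff: the guards above
    // route only plain 2-D, contiguous F16 x F32 GEMMs with M, N and K all
    // larger than 32 to the tiled kernel; batched (ne[2]/ne[3] > 1),
    // non-contiguous, or small shapes fall through to the generic paths
    // below, where the per-tile overhead would not pay off.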

    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
@@ -5587,7 +6099,9 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

    float scale;
    float bias;
    memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(float));
    memcpy(&bias,  ((int32_t *) dst->op_params) + 1, sizeof(float));

    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
@@ -5602,6 +6116,7 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale));
    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &bias));

    int n = ggml_nelements(dst)/4;
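
    // Illustrative aside, not part of the original diff: GGML packs small
    // per-op scalars into dst->op_params as raw 32-bit slots, so scale and
    // bias are two consecutive floats read back via memcpy at 32-bit offsets
    // 0 and 1. The n = ggml_nelements(dst)/4 above counts float4 vectors,
    // which suggests the scale kernel processes 4 elements per work-item.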

@@ -6385,6 +6900,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
            }
            func = ggml_cl_get_rows;
            break;
        case GGML_OP_SET_ROWS:
            if (!any_on_device) {
                return false;
            }
            func = ggml_cl_set_rows;
            break;
        case GGML_OP_CPY:
            if (!any_on_device) {
                return false;
@@ -6517,6 +7038,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
            }
            ggml_cl_upscale(backend, tensor->src[0], tensor);
            return true;
        case GGML_OP_CONV_2D:
            if (!any_on_device) {
                return false;
            }
            func = ggml_cl_conv_2d;
            break;
        case GGML_OP_CONCAT:
            if (!any_on_device) {
                return false;

ggml/src/ggml-opencl/kernels/conv2d.cl (new file, 185 lines)
@@ -0,0 +1,185 @@
#ifdef USE_FP16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define T_FLOAT  half
#define T_FLOAT4 half4
#define VSTORE_T_FLOAT4(data, offset, p) vstore_half4_rte(data, offset, p)
#else
#define T_FLOAT  float
#define T_FLOAT4 float4
#define VSTORE_T_FLOAT4(data, offset, p) vstore4(data, offset, p)
#endif

#if defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#else
#define REQD_SUBGROUP_SIZE_128
#endif

#define T_ACCUM  float4
#define VEC_SIZE 4

#define BS_K   64
#define BS_NPQ 64
#define BS_CRS 16

#define TS_K   4
#define TS_NPQ 8

#define WG_K   (BS_K / TS_K)
#define WG_NPQ (BS_NPQ / TS_NPQ)

#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)
#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)

static inline uint splitWork(uint work_size, uint block_size){
    return (work_size + block_size - 1) / block_size;
}

REQD_SUBGROUP_SIZE_128
kernel void kernel_conv_2d(
    global void* p_knl,
    ulong off_knl,
    global void* p_src,
    ulong off_src,
    global void* p_dst,
    ulong off_dst,
    local void* shared,
    uint Cout, uint Cin, uint N,
    uint KW, uint KH, uint W, uint H, uint OW, uint OH,
    uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,
    uint nb01, uint nb02, uint nb03,
    uint nb11, uint nb12, uint nb13,
    uint nb1, uint nb2, uint nb3
) {
    global T_FLOAT* knl_data = (global T_FLOAT*) ((global char*)p_knl + off_knl);
    global T_FLOAT* src_data = (global T_FLOAT*) ((global char*)p_src + off_src);
    global T_FLOAT* dst_data = (global T_FLOAT*) ((global char*)p_dst + off_dst);

    const uint K   = Cout;
    const uint CRS = Cin*KH*KW;
    const uint NPQ = N*OH*OW;

    const uint lid_k   = get_local_id(0);
    const uint lid_npq = get_local_id(1);
    const uint tid     = lid_npq * WG_K + lid_k;

    const uint B_idx_K   = get_group_id(0);
    const uint B_idx_NPQ = get_group_id(1);

    const uint offset_k   = B_idx_K * BS_K;
    const uint offset_npq = B_idx_NPQ * BS_NPQ;

    local T_FLOAT*  Ash = (local T_FLOAT*)shared;
    local T_FLOAT4* Bsh = (local T_FLOAT4*) &Ash[BS_K * BS_CRS];

    T_ACCUM regC[TS_K][TS_NPQ_VEC];
    for (int i = 0; i < TS_K; ++i) {
        for (int j = 0; j < TS_NPQ_VEC; ++j) {
            regC[i][j] = (T_ACCUM)(0.0f);
        }
    }

    const uint NB_CRS = splitWork(CRS, BS_CRS);

    for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {
        const uint offset_crs = B_idx_CRS * BS_CRS;

        for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) {
            const uint k_l   = i / BS_CRS;
            const uint crs_l = i % BS_CRS;
            const uint k_g   = offset_k + k_l;
            const uint crs_g = offset_crs + crs_l;

            if (k_g < K && crs_g < CRS) {
                const uint Cin_idx = crs_g / (KW*KH);
                const uint KH_idx  = (crs_g - Cin_idx*KW*KH) / KW;
                const uint KW_idx  = crs_g - Cin_idx*KW*KH - KH_idx*KW;
                const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;
                Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx];
            } else {
                Ash[k_l * BS_CRS + crs_l] = (T_FLOAT)0.0f;
            }
        }

        for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) {
            const uint crs_l     = i / BS_NPQ_VEC;
            const uint npq_l_vec = i % BS_NPQ_VEC;
            const uint crs_g     = offset_crs + crs_l;

            T_FLOAT4 val = (T_FLOAT4)(0.0f);
            if (crs_g < CRS) {
                const uint Cin_idx = crs_g / (KW * KH);
                const uint KH_idx  = (crs_g - Cin_idx * KW * KH) / KW;
                const uint KW_idx  = crs_g - Cin_idx * KW * KH - KH_idx * KW;
                for (int v = 0; v < VEC_SIZE; ++v) {
                    const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v;
                    if (npq_g < NPQ) {
                        const uint N_idx  = npq_g / (OH * OW);
                        const uint pq_idx = npq_g % (OH * OW);
                        const uint OH_idx = pq_idx / OW;
                        const uint OW_idx = pq_idx % OW;
                        const int  H_idx  = (int)(OH_idx * s1 + KH_idx * d1 - p1);
                        const int  W_idx  = (int)(OW_idx * s0 + KW_idx * d0 - p0);

                        if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {
                            const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;
                            ((T_FLOAT*)&val)[v] = src_data[src_idx];
                        }
                    }
                }
            }
            Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val;
        }

        barrier(CLK_LOCAL_MEM_FENCE);

#pragma unroll
        for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {
            T_FLOAT regA[TS_K];
            for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
                regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l];
            }

            for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
                T_FLOAT4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg];
                for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
                    regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), convert_float4(regB), regC[k_l_reg][npq_l_vec_reg]);
                }
            }
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
        const uint k_g = offset_k + lid_k * TS_K + k_l_reg;
        if (k_g >= K) continue;

        for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
            const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE;

            const uint N_idx  = npq_g_base / (OH * OW);
            const uint pq_idx = npq_g_base % (OH * OW);
            const uint OH_idx = pq_idx / OW;
            const uint OW_idx = pq_idx % OW;

            if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {
                const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;
                VSTORE_T_FLOAT4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]);
            } else {
                T_ACCUM res = regC[k_l_reg][npq_l_vec_reg];
                for (int v = 0; v < VEC_SIZE; ++v) {
                    const uint npq_g = npq_g_base + v;
                    if (npq_g < NPQ) {
                        const uint N_idx_s   = npq_g / (OH*OW);
                        const uint pq_idx_s  = npq_g % (OH*OW);
                        const uint OH_idx_s  = pq_idx_s / OW;
                        const uint OW_idx_s  = pq_idx_s % OW;
                        const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;
                        dst_data[dst_idx_s] = (T_FLOAT)(((float*)&res)[v]);
                    }
                }
            }
        }
    }
}
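
// Illustrative aside, not part of the new file: the tile loads above linearize
// the reduction index as crs = Cin_idx*(KW*KH) + KH_idx*KW + KW_idx. A minimal
// sketch of that encoding with a hypothetical helper name:
static inline uint crs_pack(uint cin, uint kh, uint kw, uint KW, uint KH) {
    return cin*(KW*KH) + kh*KW + kw;
}
// e.g. for KW = 3, KH = 3: crs_pack(2, 1, 2, 3, 3) = 2*9 + 1*3 + 2 = 23, and
// the kernel recovers Cin_idx = 23/9 = 2, KH_idx = (23 - 18)/3 = 1,
// KW_idx = 23 - 18 - 3 = 2, matching the decode in the A-tile load.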

ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl (new file, 176 lines)
@@ -0,0 +1,176 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

#if defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#else
#define REQD_SUBGROUP_SIZE_128
#endif

#define T_ACCUM  float4
#define VEC_SIZE 4

#define BS_K   64
#define BS_NPQ 64
#define BS_CRS 16

#define TS_K   4
#define TS_NPQ 8

#define WG_K   (BS_K / TS_K)
#define WG_NPQ (BS_NPQ / TS_NPQ)

#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)
#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)

static inline uint splitWork(uint work_size, uint block_size){
    return (work_size + block_size - 1) / block_size;
}

REQD_SUBGROUP_SIZE_128
kernel void kernel_conv_2d(
    global void* p_knl,
    ulong off_knl,
    global void* p_src,
    ulong off_src,
    global void* p_dst,
    ulong off_dst,
    local void* shared,
    uint Cout, uint Cin, uint N,
    uint KW, uint KH, uint W, uint H, uint OW, uint OH,
    uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,
    uint nb01, uint nb02, uint nb03,
    uint nb11, uint nb12, uint nb13,
    uint nb1, uint nb2, uint nb3
) {
    global half*  knl_data = (global half*) ((global char*)p_knl + off_knl);
    global float* src_data = (global float*)((global char*)p_src + off_src);
    global float* dst_data = (global float*)((global char*)p_dst + off_dst);

    const uint K   = Cout;
    const uint CRS = Cin*KH*KW;
    const uint NPQ = N*OH*OW;

    const uint lid_k   = get_local_id(0);
    const uint lid_npq = get_local_id(1);
    const uint tid     = lid_npq * WG_K + lid_k;

    const uint B_idx_K   = get_group_id(0);
    const uint B_idx_NPQ = get_group_id(1);

    const uint offset_k   = B_idx_K * BS_K;
    const uint offset_npq = B_idx_NPQ * BS_NPQ;

    local half*   Ash = (local half*)shared;
    local float4* Bsh = (local float4*) &Ash[BS_K * BS_CRS];

    T_ACCUM regC[TS_K][TS_NPQ_VEC];
    for (int i = 0; i < TS_K; ++i) {
        for (int j = 0; j < TS_NPQ_VEC; ++j) {
            regC[i][j] = (T_ACCUM)(0.0f);
        }
    }

    const uint NB_CRS = splitWork(CRS, BS_CRS);

    for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {
        const uint offset_crs = B_idx_CRS * BS_CRS;

        for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) {
            const uint k_l   = i / BS_CRS;
            const uint crs_l = i % BS_CRS;
            const uint k_g   = offset_k + k_l;
            const uint crs_g = offset_crs + crs_l;

            if (k_g < K && crs_g < CRS) {
                const uint Cin_idx = crs_g / (KW*KH);
                const uint KH_idx  = (crs_g - Cin_idx*KW*KH) / KW;
                const uint KW_idx  = crs_g - Cin_idx*KW*KH - KH_idx*KW;
                const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;
                Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx];
            } else {
                Ash[k_l * BS_CRS + crs_l] = (half)0.0f;
            }
        }

        for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) {
            const uint crs_l     = i / BS_NPQ_VEC;
            const uint npq_l_vec = i % BS_NPQ_VEC;
            const uint crs_g     = offset_crs + crs_l;

            float4 val = (float4)(0.0f);
            if (crs_g < CRS) {
                const uint Cin_idx = crs_g / (KW * KH);
                const uint KH_idx  = (crs_g - Cin_idx * KW * KH) / KW;
                const uint KW_idx  = crs_g - Cin_idx * KW * KH - KH_idx * KW;
                for (int v = 0; v < VEC_SIZE; ++v) {
                    const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v;
                    if (npq_g < NPQ) {
                        const uint N_idx  = npq_g / (OH * OW);
                        const uint pq_idx = npq_g % (OH * OW);
                        const uint OH_idx = pq_idx / OW;
                        const uint OW_idx = pq_idx % OW;
                        const int  H_idx  = (int)(OH_idx * s1 + KH_idx * d1 - p1);
                        const int  W_idx  = (int)(OW_idx * s0 + KW_idx * d0 - p0);

                        if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {
                            const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;
                            ((float*)&val)[v] = src_data[src_idx];
                        }
                    }
                }
            }
            Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val;
        }

        barrier(CLK_LOCAL_MEM_FENCE);

#pragma unroll
        for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {
            half regA[TS_K];
            for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
                regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l];
            }

            for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
                float4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg];
                for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
                    regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), regB, regC[k_l_reg][npq_l_vec_reg]);
                }
            }
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
        const uint k_g = offset_k + lid_k * TS_K + k_l_reg;
        if (k_g >= K) continue;

        for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
            const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE;

            const uint N_idx  = npq_g_base / (OH * OW);
            const uint pq_idx = npq_g_base % (OH * OW);
            const uint OH_idx = pq_idx / OW;
            const uint OW_idx = pq_idx % OW;

            if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {
                const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;
                vstore4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]);
            } else {
                T_ACCUM res = regC[k_l_reg][npq_l_vec_reg];
                for (int v = 0; v < VEC_SIZE; ++v) {
                    const uint npq_g = npq_g_base + v;
                    if (npq_g < NPQ) {
                        const uint N_idx_s   = npq_g / (OH*OW);
                        const uint pq_idx_s  = npq_g % (OH*OW);
                        const uint OH_idx_s  = pq_idx_s / OW;
                        const uint OW_idx_s  = pq_idx_s % OW;
                        const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;
                        dst_data[dst_idx_s] = ((float*)&res)[v];
                    }
                }
            }
        }
    }
}
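
// Illustrative aside, not part of the new file: this variant keeps the same
// structure as conv2d.cl but fixes the types instead of switching on USE_FP16:
// kernel weights stay half (as stored for F16 models), activations and
// outputs stay float, and the accumulator is float4 in both files, so the only
// half -> float conversion is on regA just before the mad().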

@@ -31,7 +31,7 @@ kernel void kernel_im2col_f16(
    src1 = (global float*)((global char*)src1 + offset1);
    dst = (global half*)((global char*)dst + offsetd);

    long ksize = OW * KH;
    long kx = i / ksize;
    long kd = kx * ksize;
    long ky = (i - kd) / OW;

@@ -31,7 +31,7 @@ kernel void kernel_im2col_f32(
    src1 = (global float*)((global char*)src1 + offset1);
    dst = (global float*)((global char*)dst + offsetd);

    long ksize = OW * KH;
    long kx = i / ksize;
    long kd = kx * ksize;
    long ky = (i - kd) / OW;
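
// Illustrative aside, not part of the diff: the replaced expression
// `OW * (KH > 1 ? KW : 1)` mixed KW into a height-sized stride, which appears
// to be what broke im2col whenever KW != KH. With ksize = OW * KH, an index i
// decodes cleanly as kx = i / (OW*KH), then ky = (i - kx*OW*KH) / OW; e.g. for
// OW = 5, KH = 3, i = 22: kx = 22/15 = 1, ky = (22 - 15)/5 = 1.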

ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl (new file, 130 lines)
@@ -0,0 +1,130 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

#if defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#else
#define REQD_SUBGROUP_SIZE_128
#endif

#define OPWM 64
#define OPWN 64
#define CPWK 8
#define OPTM 4
#define OPTN 8

#define WG_M (OPWM / OPTM)
#define WG_N (OPWN / OPTN)
#define VEC_K (CPWK / 4)

REQD_SUBGROUP_SIZE_128
__kernel void mul_mat_f16_f32(
        const int M, const int N, const int K,
        __global const void* A_void, ulong A_offset,
        __global const void* B_void, ulong B_offset,
        __global void* C_void, ulong C_offset) {

    __global const half*  A = (__global const half*)((__global const char*)A_void + A_offset);
    __global const float* B = (__global const float*)((__global const char*)B_void + B_offset);
    __global       float* C = (__global float*)((__global char*)C_void + C_offset);

    const int lidm = get_local_id(0);
    const int lidn = get_local_id(1);
    const int lid  = lidn * WG_M + lidm;

    const int offsetM = get_group_id(0) * OPWM;
    const int offsetN = get_group_id(1) * OPWN;

    __local half4  Alocal[OPWM][VEC_K];
    __local float4 Blocal[OPWN][VEC_K];

    float sum[OPTM][OPTN];

    for (int wm = 0; wm < OPTM; wm++) {
        for (int wn = 0; wn < OPTN; wn++) {
            sum[wm][wn] = 0.0f;
        }
    }

    const int numTiles = (K + CPWK - 1) / CPWK;

    const int load_row_a   = lid % OPWM;
    const int load_vec_k_a = lid / OPWM;
    const int global_row_a = offsetM + load_row_a;

    const int load_row_b   = lid % OPWN;
    const int load_vec_k_b = lid / OPWN;
    const int global_row_b = offsetN + load_row_b;

    for (int t = 0; t < numTiles; t++) {
        const int k_start = t * CPWK;
        const int k_vec_start_a = k_start + load_vec_k_a * 4;
        const int k_vec_start_b = k_start + load_vec_k_b * 4;

        if (global_row_a < M && k_vec_start_a < K) {
            if (k_vec_start_a + 3 < K) {
                Alocal[load_row_a][load_vec_k_a] = vload4(0, A + global_row_a * K + k_vec_start_a);
            } else {
                half4 tempA = (half4)(0.0h);
                if (k_vec_start_a < K)     tempA.s0 = A[global_row_a * K + k_vec_start_a];
                if (k_vec_start_a + 1 < K) tempA.s1 = A[global_row_a * K + k_vec_start_a + 1];
                if (k_vec_start_a + 2 < K) tempA.s2 = A[global_row_a * K + k_vec_start_a + 2];
                Alocal[load_row_a][load_vec_k_a] = tempA;
            }
        } else {
            Alocal[load_row_a][load_vec_k_a] = (half4)(0.0h);
        }

        if (global_row_b < N && k_vec_start_b < K) {
            if (k_vec_start_b + 3 < K) {
                Blocal[load_row_b][load_vec_k_b] = vload4(0, B + global_row_b * K + k_vec_start_b);
            } else {
                float4 tempB = (float4)(0.0f);
                if (k_vec_start_b < K)     tempB.s0 = B[global_row_b * K + k_vec_start_b];
                if (k_vec_start_b + 1 < K) tempB.s1 = B[global_row_b * K + k_vec_start_b + 1];
                if (k_vec_start_b + 2 < K) tempB.s2 = B[global_row_b * K + k_vec_start_b + 2];
                Blocal[load_row_b][load_vec_k_b] = tempB;
            }
        } else {
            Blocal[load_row_b][load_vec_k_b] = (float4)(0.0f);
        }

        barrier(CLK_LOCAL_MEM_FENCE);

#pragma unroll
        for (int k_vec = 0; k_vec < VEC_K; k_vec++) {
            float4 a_fvecs[OPTM];
            int current_row_a = lidm;
            for (int wm = 0; wm < OPTM; wm++) {
                a_fvecs[wm] = convert_float4(Alocal[current_row_a][k_vec]);
                current_row_a += WG_M;
            }

            float4 b_fvecs[OPTN];
            int current_row_b = lidn;
            for (int wn = 0; wn < OPTN; wn++) {
                b_fvecs[wn] = Blocal[current_row_b][k_vec];
                current_row_b += WG_N;
            }

            for (int wm = 0; wm < OPTM; wm++) {
                for (int wn = 0; wn < OPTN; wn++) {
                    sum[wm][wn] += dot(a_fvecs[wm], b_fvecs[wn]);
                }
            }
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    for (int wm = 0; wm < OPTM; wm++) {
        int globalRow = offsetM + lidm + wm * WG_M;
        if (globalRow < M) {
            for (int wn = 0; wn < OPTN; wn++) {
                int globalCol = offsetN + lidn + wn * WG_N;
                if (globalCol < N) {
                    C[globalCol * M + globalRow] = sum[wm][wn];
                }
            }
        }
    }
}
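
// Illustrative aside, not part of the new file: A (half) and B (float) are
// both read row-major with K contiguous, so each thread computes
// C[n][m] = dot(row m of A, row n of B) and stores it at
// C[globalCol * M + globalRow], i.e. with M fastest-varying. That matches
// ggml's mul_mat convention, where dst->ne[0] equals src0->ne[1] = M.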

Some files were not shown because too many files have changed in this diff.