mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-23 16:37:33 +03:00
Compare commits
37 Commits
b5581
...
compilade/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
62a9f34bae | ||
|
|
dd6495ddc9 | ||
|
|
f470bc36be | ||
|
|
8f47e25f56 | ||
|
|
201b31dc2e | ||
|
|
e21d2d4ae2 | ||
|
|
dc0623fddb | ||
|
|
87d34b381d | ||
|
|
b460d16ae8 | ||
|
|
91a8ee6a6f | ||
|
|
056eb74534 | ||
|
|
247e5c6e44 | ||
|
|
5787b5da57 | ||
|
|
228f34c9ce | ||
|
|
0974ad7a7c | ||
|
|
745aa5319b | ||
|
|
487a5e0401 | ||
|
|
d17a809ef0 | ||
|
|
1caae7fc6c | ||
|
|
669c13e0f6 | ||
|
|
146b88e8b3 | ||
|
|
7f37b6cf1e | ||
|
|
3a077146a4 | ||
|
|
d01d112abb | ||
|
|
9f47fa5792 | ||
|
|
9e31bec4fd | ||
|
|
5a8ae3053c | ||
|
|
0d3984424f | ||
|
|
3e63a58ef7 | ||
|
|
2589ad3704 | ||
|
|
482548716f | ||
|
|
3ac67535c8 | ||
|
|
0b4be4c435 | ||
|
|
e0e806f52e | ||
|
|
7e00e60ef8 | ||
|
|
ea1431b0fa | ||
|
|
3129639449 |
7
.github/labeler.yml
vendored
7
.github/labeler.yml
vendored
@@ -86,3 +86,10 @@ nix:
|
||||
embedding:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: examples/embedding/
|
||||
|
||||
Ascend NPU:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- ggml/include/ggml-cann.h
|
||||
- ggml/src/ggml-cann/**
|
||||
- docs/backend/CANN.md
|
||||
|
||||
113
.github/workflows/build-linux-cross.yml
vendored
113
.github/workflows/build-linux-cross.yml
vendored
@@ -231,3 +231,116 @@ jobs:
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
|
||||
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
debian-13-loongarch64-cpu-cross:
|
||||
runs-on: ubuntu-24.04
|
||||
container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup LoongArch
|
||||
run: |
|
||||
rm -f /etc/apt/sources.list.d/*
|
||||
cat << EOF | tee /etc/apt/sources.list.d/debian-ports.list
|
||||
deb http://snapshot.debian.org/archive/debian/20250515T202920Z/ trixie main
|
||||
EOF
|
||||
( echo 'quiet "true";'; \
|
||||
echo 'APT::Get::Assume-Yes "true";'; \
|
||||
echo 'APT::Install-Recommends "false";'; \
|
||||
echo 'Acquire::Check-Valid-Until "false";'; \
|
||||
echo 'Acquire::Retries "5";'; \
|
||||
) > /etc/apt/apt.conf.d/99snapshot-repos
|
||||
|
||||
apt-get update
|
||||
apt-get install -y ca-certificates debian-ports-archive-keyring cmake git zip
|
||||
dpkg --add-architecture loong64
|
||||
|
||||
# Add arch-specific repositories for non-amd64 architectures
|
||||
cat << EOF | tee /etc/apt/sources.list.d/loong64-ports.list
|
||||
deb [arch=loong64] http://snapshot.debian.org/archive/debian-ports/20250515T194251Z/ sid main
|
||||
EOF
|
||||
|
||||
apt-get update || true ;# Prevent failure due to missing URLs.
|
||||
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
gcc-14-loongarch64-linux-gnu \
|
||||
g++-14-loongarch64-linux-gnu
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cmake -B build -DLLAMA_CURL=OFF \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_OPENMP=OFF \
|
||||
-DLLAMA_BUILD_EXAMPLES=ON \
|
||||
-DLLAMA_BUILD_TOOLS=ON \
|
||||
-DLLAMA_BUILD_TESTS=OFF \
|
||||
-DCMAKE_SYSTEM_NAME=Linux \
|
||||
-DCMAKE_SYSTEM_PROCESSOR=loongarch64 \
|
||||
-DCMAKE_C_COMPILER=loongarch64-linux-gnu-gcc-14 \
|
||||
-DCMAKE_CXX_COMPILER=loongarch64-linux-gnu-g++-14 \
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
|
||||
-DCMAKE_FIND_ROOT_PATH=/usr/lib/loongarch64-linux-gnu \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
|
||||
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
debian-13-loongarch64-vulkan-cross:
|
||||
runs-on: ubuntu-24.04
|
||||
container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup LoongArch
|
||||
run: |
|
||||
rm -f /etc/apt/sources.list.d/*
|
||||
cat << EOF | tee /etc/apt/sources.list.d/debian-ports.list
|
||||
deb http://snapshot.debian.org/archive/debian/20250515T202920Z/ trixie main
|
||||
EOF
|
||||
( echo 'quiet "true";'; \
|
||||
echo 'APT::Get::Assume-Yes "true";'; \
|
||||
echo 'APT::Install-Recommends "false";'; \
|
||||
echo 'Acquire::Check-Valid-Until "false";'; \
|
||||
echo 'Acquire::Retries "5";'; \
|
||||
) > /etc/apt/apt.conf.d/99snapshot-repos
|
||||
|
||||
apt-get update
|
||||
apt-get install -y ca-certificates debian-ports-archive-keyring cmake git zip
|
||||
dpkg --add-architecture loong64
|
||||
|
||||
# Add arch-specific repositories for non-amd64 architectures
|
||||
cat << EOF | tee /etc/apt/sources.list.d/loong64-ports.list
|
||||
deb [arch=loong64] http://snapshot.debian.org/archive/debian-ports/20250515T194251Z/ sid main
|
||||
EOF
|
||||
|
||||
apt-get update || true ;# Prevent failure due to missing URLs.
|
||||
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
glslc \
|
||||
gcc-14-loongarch64-linux-gnu \
|
||||
g++-14-loongarch64-linux-gnu \
|
||||
libvulkan-dev:loong64
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cmake -B build -DLLAMA_CURL=OFF \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_VULKAN=ON \
|
||||
-DGGML_OPENMP=OFF \
|
||||
-DLLAMA_BUILD_EXAMPLES=ON \
|
||||
-DLLAMA_BUILD_TOOLS=ON \
|
||||
-DLLAMA_BUILD_TESTS=OFF \
|
||||
-DCMAKE_SYSTEM_NAME=Linux \
|
||||
-DCMAKE_SYSTEM_PROCESSOR=loongarch64 \
|
||||
-DCMAKE_C_COMPILER=loongarch64-linux-gnu-gcc-14 \
|
||||
-DCMAKE_CXX_COMPILER=loongarch64-linux-gnu-g++-14 \
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
|
||||
-DCMAKE_FIND_ROOT_PATH=/usr/lib/loongarch64-linux-gnu \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
|
||||
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
8
.github/workflows/build.yml
vendored
8
.github/workflows/build.yml
vendored
@@ -839,12 +839,12 @@ jobs:
|
||||
-DGGML_CUDA=ON
|
||||
cmake --build build
|
||||
|
||||
windows-2019-cmake-cuda:
|
||||
runs-on: windows-2019
|
||||
windows-2022-cmake-cuda:
|
||||
runs-on: windows-2022
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
cuda: ['12.4', '11.7']
|
||||
cuda: ['12.4']
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -878,7 +878,7 @@ jobs:
|
||||
env:
|
||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
||||
run: |
|
||||
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
|
||||
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64
|
||||
cmake -S . -B build -G "Ninja Multi-Config" ^
|
||||
-DLLAMA_BUILD_SERVER=ON ^
|
||||
-DGGML_NATIVE=OFF ^
|
||||
|
||||
17
.github/workflows/release.yml
vendored
17
.github/workflows/release.yml
vendored
@@ -131,8 +131,9 @@ jobs:
|
||||
include:
|
||||
- build: 'x64'
|
||||
os: ubuntu-22.04
|
||||
- build: 'arm64'
|
||||
os: ubuntu-22.04-arm
|
||||
# GGML_BACKEND_DL and GGML_CPU_ALL_VARIANTS are not currently supported on arm
|
||||
# - build: 'arm64'
|
||||
# os: ubuntu-22.04-arm
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
@@ -159,6 +160,9 @@ jobs:
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DGGML_BACKEND_DL=ON \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DGGML_CPU_ALL_VARIANTS=ON \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
${{ env.CMAKE_ARGS }}
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
@@ -207,6 +211,9 @@ jobs:
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DGGML_BACKEND_DL=ON \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DGGML_CPU_ALL_VARIANTS=ON \
|
||||
-DGGML_VULKAN=ON \
|
||||
${{ env.CMAKE_ARGS }}
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
@@ -373,11 +380,11 @@ jobs:
|
||||
name: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip
|
||||
|
||||
windows-cuda:
|
||||
runs-on: windows-2019
|
||||
runs-on: windows-2022
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
cuda: ['12.4', '11.7']
|
||||
cuda: ['12.4']
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -405,7 +412,7 @@ jobs:
|
||||
id: cmake_build
|
||||
shell: cmd
|
||||
run: |
|
||||
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
|
||||
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64
|
||||
cmake -S . -B build -G "Ninja Multi-Config" ^
|
||||
-DGGML_BACKEND_DL=ON ^
|
||||
-DGGML_NATIVE=OFF ^
|
||||
|
||||
2
.github/workflows/server.yml
vendored
2
.github/workflows/server.yml
vendored
@@ -180,7 +180,7 @@ jobs:
|
||||
|
||||
|
||||
server-windows:
|
||||
runs-on: windows-2019
|
||||
runs-on: windows-2022
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
||||
@@ -159,6 +159,11 @@ if (NOT TARGET ggml AND NOT LLAMA_USE_SYSTEM_GGML)
|
||||
# ... otherwise assume ggml is added by a parent CMakeLists.txt
|
||||
endif()
|
||||
|
||||
if (MINGW)
|
||||
# Target Windows 8 for PrefetchVirtualMemory
|
||||
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
|
||||
endif()
|
||||
|
||||
#
|
||||
# build the library
|
||||
#
|
||||
|
||||
4
Makefile
4
Makefile
@@ -367,7 +367,7 @@ ifdef LLAMA_SERVER_SSL
|
||||
endif
|
||||
|
||||
ifndef GGML_NO_CPU_AARCH64
|
||||
MK_CPPFLAGS += -DGGML_USE_CPU_AARCH64
|
||||
MK_CPPFLAGS += -DGGML_USE_CPU_REPACK
|
||||
endif
|
||||
|
||||
# warnings
|
||||
@@ -970,7 +970,7 @@ OBJ_GGML = \
|
||||
$(DIR_GGML)/src/ggml-threading.o \
|
||||
$(DIR_GGML)/src/ggml-cpu/ggml-cpu.o \
|
||||
$(DIR_GGML)/src/ggml-cpu/ggml-cpu_cpp.o \
|
||||
$(DIR_GGML)/src/ggml-cpu/ggml-cpu-aarch64.o \
|
||||
$(DIR_GGML)/src/ggml-cpu/repack.o \
|
||||
$(DIR_GGML)/src/ggml-cpu/ggml-cpu-hbm.o \
|
||||
$(DIR_GGML)/src/ggml-cpu/ggml-cpu-quants.o \
|
||||
$(DIR_GGML)/src/ggml-cpu/ggml-cpu-traits.o \
|
||||
|
||||
42
README.md
42
README.md
@@ -3,6 +3,7 @@
|
||||

|
||||
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://github.com/ggml-org/llama.cpp/releases)
|
||||
[](https://github.com/ggml-org/llama.cpp/actions/workflows/server.yml)
|
||||
|
||||
[Roadmap](https://github.com/users/ggerganov/projects/7) / [Project status](https://github.com/ggml-org/llama.cpp/discussions/3471) / [Manifesto](https://github.com/ggml-org/llama.cpp/discussions/205) / [ggml](https://github.com/ggml-org/ggml)
|
||||
@@ -28,6 +29,30 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
|
||||
|
||||
----
|
||||
|
||||
## Quick start
|
||||
|
||||
Getting started with llama.cpp is straightforward. Here are several ways to install it on your machine:
|
||||
|
||||
- Install `llama.cpp` using [brew, nix or winget](docs/install.md)
|
||||
- Run with Docker - see our [Docker documentation](docs/docker.md)
|
||||
- Download pre-built binaries from the [releases page](https://github.com/ggml-org/llama.cpp/releases)
|
||||
- Build from source by cloning this repository - check out [our build guide](docs/build.md)
|
||||
|
||||
Once installed, you'll need a model to work with. Head to the [Obtaining and quantizing models](#obtaining-and-quantizing-models) section to learn more.
|
||||
|
||||
Example command:
|
||||
|
||||
```sh
|
||||
# Use a local model file
|
||||
llama-cli -m my_model.gguf
|
||||
|
||||
# Or download and run a model directly from Hugging Face
|
||||
llama-cli -hf ggml-org/gemma-3-1b-it-GGUF
|
||||
|
||||
# Launch OpenAI-compatible API server
|
||||
llama-server -hf ggml-org/gemma-3-1b-it-GGUF
|
||||
```
|
||||
|
||||
## Description
|
||||
|
||||
The main goal of `llama.cpp` is to enable LLM inference with minimal setup and state-of-the-art performance on a wide
|
||||
@@ -230,6 +255,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## Supported backends
|
||||
|
||||
| Backend | Target devices |
|
||||
@@ -246,16 +272,6 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
| [OpenCL](docs/backend/OPENCL.md) | Adreno GPU |
|
||||
| [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All |
|
||||
|
||||
## Building the project
|
||||
|
||||
The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](include/llama.h).
|
||||
The project also includes many example programs and tools using the `llama` library. The examples range from simple, minimal code snippets to sophisticated sub-projects such as an OpenAI-compatible HTTP server. Possible methods for obtaining the binaries:
|
||||
|
||||
- Clone this repository and build locally, see [how to build](docs/build.md)
|
||||
- On MacOS or Linux, install `llama.cpp` via [brew, flox or nix](docs/install.md)
|
||||
- Use a Docker image, see [documentation for Docker](docs/docker.md)
|
||||
- Download pre-built binaries from [releases](https://github.com/ggml-org/llama.cpp/releases)
|
||||
|
||||
## Obtaining and quantizing models
|
||||
|
||||
The [Hugging Face](https://huggingface.co) platform hosts a [number of LLMs](https://huggingface.co/models?library=gguf&sort=trending) compatible with `llama.cpp`:
|
||||
@@ -263,7 +279,11 @@ The [Hugging Face](https://huggingface.co) platform hosts a [number of LLMs](htt
|
||||
- [Trending](https://huggingface.co/models?library=gguf&sort=trending)
|
||||
- [LLaMA](https://huggingface.co/models?sort=trending&search=llama+gguf)
|
||||
|
||||
You can either manually download the GGUF file or directly use any `llama.cpp`-compatible models from [Hugging Face](https://huggingface.co/) or other model hosting sites, such as [ModelScope](https://modelscope.cn/), by using this CLI argument: `-hf <user>/<model>[:quant]`.
|
||||
You can either manually download the GGUF file or directly use any `llama.cpp`-compatible models from [Hugging Face](https://huggingface.co/) or other model hosting sites, such as [ModelScope](https://modelscope.cn/), by using this CLI argument: `-hf <user>/<model>[:quant]`. For example:
|
||||
|
||||
```sh
|
||||
llama-cli -hf ggml-org/gemma-3-1b-it-GGUF
|
||||
```
|
||||
|
||||
By default, the CLI would download from Hugging Face, you can switch to other options with the environment variable `MODEL_ENDPOINT`. For example, you may opt to downloading model checkpoints from ModelScope or other model sharing communities by setting the environment variable, e.g. `MODEL_ENDPOINT=https://www.modelscope.cn/`.
|
||||
|
||||
|
||||
15
ci/run.sh
15
ci/run.sh
@@ -46,7 +46,20 @@ if [ ! -z ${GG_BUILD_METAL} ]; then
|
||||
fi
|
||||
|
||||
if [ ! -z ${GG_BUILD_CUDA} ]; then
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=native"
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_CUDA=ON"
|
||||
|
||||
if command -v nvidia-smi >/dev/null 2>&1; then
|
||||
CUDA_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits 2>/dev/null | head -1 | tr -d '.')
|
||||
if [[ -n "$CUDA_ARCH" && "$CUDA_ARCH" =~ ^[0-9]+$ ]]; then
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCH}"
|
||||
else
|
||||
echo "Warning: Using fallback CUDA architectures"
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DCMAKE_CUDA_ARCHITECTURES=61;70;75;80;86;89"
|
||||
fi
|
||||
else
|
||||
echo "Error: nvidia-smi not found, cannot build with CUDA"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ ! -z ${GG_BUILD_SYCL} ]; then
|
||||
|
||||
@@ -934,7 +934,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
return iparams;
|
||||
}
|
||||
|
||||
if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
|
||||
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
|
||||
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
|
||||
params.ctx_shift = false;
|
||||
}
|
||||
@@ -1041,7 +1041,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
if (llama_model_has_decoder(model)) {
|
||||
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
|
||||
}
|
||||
llama_kv_self_clear(lctx);
|
||||
llama_memory_clear(llama_get_memory(lctx), true);
|
||||
llama_synchronize(lctx);
|
||||
llama_perf_context_reset(lctx);
|
||||
llama_set_warmup(lctx, false);
|
||||
|
||||
@@ -144,6 +144,8 @@ llama_tokens common_speculative_gen_draft(
|
||||
auto & smpl = spec->smpl;
|
||||
auto & prompt = spec->prompt;
|
||||
|
||||
auto * mem = llama_get_memory(ctx);
|
||||
|
||||
int reuse_i = 0;
|
||||
int reuse_n = 0;
|
||||
|
||||
@@ -173,7 +175,7 @@ llama_tokens common_speculative_gen_draft(
|
||||
result.reserve(params.n_draft);
|
||||
|
||||
if (reuse_n == 0) {
|
||||
llama_kv_self_clear(ctx);
|
||||
llama_memory_clear(mem, false);
|
||||
|
||||
prompt.clear();
|
||||
} else {
|
||||
@@ -192,14 +194,14 @@ llama_tokens common_speculative_gen_draft(
|
||||
}
|
||||
|
||||
if (reuse_i > 0) {
|
||||
llama_kv_self_seq_rm (ctx, 0, 0, reuse_i);
|
||||
llama_kv_self_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
|
||||
llama_memory_seq_rm (mem, 0, 0, reuse_i);
|
||||
llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i);
|
||||
|
||||
prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
|
||||
}
|
||||
|
||||
if (reuse_n < (int) prompt.size()) {
|
||||
llama_kv_self_seq_rm (ctx, 0, reuse_n, -1);
|
||||
llama_memory_seq_rm (mem, 0, reuse_n, -1);
|
||||
|
||||
prompt.erase(prompt.begin() + reuse_n, prompt.end());
|
||||
}
|
||||
|
||||
@@ -3709,8 +3709,7 @@ class BertModel(TextModel):
|
||||
self._try_set_pooling_type()
|
||||
|
||||
if self.cls_out_labels:
|
||||
key_name = gguf.Keys.Classifier.OUTPUT_LABELS.format(arch = gguf.MODEL_ARCH_NAMES[self.model_arch])
|
||||
self.gguf_writer.add_array(key_name, [v for k, v in sorted(self.cls_out_labels.items())])
|
||||
self.gguf_writer.add_classifier_output_labels([v for k, v in sorted(self.cls_out_labels.items())])
|
||||
|
||||
def set_vocab(self):
|
||||
tokens, toktypes, tokpre = self.get_vocab_base()
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
- [DataType Supports](#datatype-supports)
|
||||
- [Docker](#docker)
|
||||
- [Linux](#linux)
|
||||
- [Environment variable setup](#environment-variable-setup)
|
||||
- [TODO](#todo)
|
||||
|
||||
|
||||
@@ -290,5 +291,24 @@ Authors from Peking University: Bizhao Shi (bshi@pku.edu.cn), Yuxin Yang (yxyang
|
||||
|
||||
We would like to thank Tuo Dai, Shanni Li, and all of the project maintainers from Huawei Technologies Co., Ltd for their help during the code development and pull request.
|
||||
|
||||
## Environment variable setup
|
||||
|
||||
### GGML_CANN_ASYNC_MODE
|
||||
|
||||
Enables asynchronous operator submission. Disabled by default.
|
||||
|
||||
### GGML_CANN_MEM_POOL
|
||||
|
||||
Specifies the memory pool management strategy:
|
||||
|
||||
- vmm: Utilizes a virtual memory manager pool. If hardware support for VMM is unavailable, falls back to the legacy (leg) memory pool.
|
||||
|
||||
- prio: Employs a priority queue-based memory pool management.
|
||||
- leg: Uses a fixed-size buffer pool.
|
||||
|
||||
### GGML_CANN_DISABLE_BUF_POOL_CLEAN
|
||||
|
||||
Controls automatic cleanup of the memory pool. This option is only effective when using the prio or leg memory pool strategies.
|
||||
|
||||
## TODO
|
||||
- Support more models and data types.
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
# Build llama.cpp locally
|
||||
|
||||
The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](include/llama.h).
|
||||
|
||||
The project also includes many example programs and tools using the `llama` library. The examples range from simple, minimal code snippets to sophisticated sub-projects such as an OpenAI-compatible HTTP server.
|
||||
|
||||
**To get the Code:**
|
||||
|
||||
```bash
|
||||
|
||||
@@ -1,28 +1,42 @@
|
||||
# Install pre-built version of llama.cpp
|
||||
|
||||
## Homebrew
|
||||
| Install via | Windows | Mac | Linux |
|
||||
|-------------|---------|-----|-------|
|
||||
| Winget | ✅ | | |
|
||||
| Homebrew | | ✅ | ✅ |
|
||||
| MacPorts | | ✅ | |
|
||||
| Nix | | ✅ | ✅ |
|
||||
|
||||
On Mac and Linux, the homebrew package manager can be used via
|
||||
## Winget (Windows)
|
||||
|
||||
```sh
|
||||
winget install llama.cpp
|
||||
```
|
||||
|
||||
The package is automatically updated with new `llama.cpp` releases. More info: https://github.com/ggml-org/llama.cpp/issues/8188
|
||||
|
||||
## Homebrew (Mac and Linux)
|
||||
|
||||
```sh
|
||||
brew install llama.cpp
|
||||
```
|
||||
|
||||
The formula is automatically updated with new `llama.cpp` releases. More info: https://github.com/ggml-org/llama.cpp/discussions/7668
|
||||
|
||||
## MacPorts
|
||||
## MacPorts (Mac)
|
||||
|
||||
```sh
|
||||
sudo port install llama.cpp
|
||||
```
|
||||
see also: https://ports.macports.org/port/llama.cpp/details/
|
||||
|
||||
## Nix
|
||||
See also: https://ports.macports.org/port/llama.cpp/details/
|
||||
|
||||
On Mac and Linux, the Nix package manager can be used via
|
||||
## Nix (Mac and Linux)
|
||||
|
||||
```sh
|
||||
nix profile install nixpkgs#llama-cpp
|
||||
```
|
||||
|
||||
For flake enabled installs.
|
||||
|
||||
Or
|
||||
@@ -34,13 +48,3 @@ nix-env --file '<nixpkgs>' --install --attr llama-cpp
|
||||
For non-flake enabled installs.
|
||||
|
||||
This expression is automatically updated within the [nixpkgs repo](https://github.com/NixOS/nixpkgs/blob/nixos-24.05/pkgs/by-name/ll/llama-cpp/package.nix#L164).
|
||||
|
||||
## Flox
|
||||
|
||||
On Mac and Linux, Flox can be used to install llama.cpp within a Flox environment via
|
||||
|
||||
```sh
|
||||
flox install llama-cpp
|
||||
```
|
||||
|
||||
Flox follows the nixpkgs build of llama.cpp.
|
||||
|
||||
@@ -116,7 +116,7 @@ if llama_decode(context, batch) != 0 {
|
||||
}
|
||||
|
||||
for i in 1 ..< n_parallel {
|
||||
llama_kv_self_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
|
||||
llama_memory_seq_cp(llama_get_memory(context), 0, Int32(i), 0, batch.n_tokens)
|
||||
}
|
||||
|
||||
if n_parallel > 1 {
|
||||
|
||||
@@ -37,7 +37,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
||||
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
|
||||
|
||||
// clear previous kv_cache values (irrelevant for embeddings)
|
||||
llama_kv_self_clear(ctx);
|
||||
llama_memory_clear(llama_get_memory(ctx), true);
|
||||
|
||||
// run model
|
||||
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
|
||||
@@ -236,9 +236,24 @@ int main(int argc, char ** argv) {
|
||||
LOG("\n");
|
||||
}
|
||||
} else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
|
||||
const uint32_t n_cls_out = llama_model_n_cls_out(model);
|
||||
std::vector<std::string> cls_out_labels;
|
||||
|
||||
for (uint32_t i = 0; i < n_cls_out; i++) {
|
||||
const char * label = llama_model_cls_label(model, i);
|
||||
const std::string label_i(label == nullptr ? "" : label);
|
||||
cls_out_labels.emplace_back(label_i.empty() ? std::to_string(i) : label_i);
|
||||
}
|
||||
|
||||
for (int j = 0; j < n_embd_count; j++) {
|
||||
// NOTE: if you change this log - update the tests in ci/run.sh
|
||||
LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
|
||||
for (uint32_t i = 0; i < n_cls_out; i++) {
|
||||
// NOTE: if you change this log - update the tests in ci/run.sh
|
||||
if (n_cls_out == 1) {
|
||||
LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
|
||||
} else {
|
||||
LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// print the first part of the embeddings or for a single prompt, the full embedding
|
||||
|
||||
@@ -45,7 +45,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
||||
}
|
||||
|
||||
// clear previous kv_cache values (irrelevant for embeddings)
|
||||
llama_kv_self_clear(ctx);
|
||||
llama_memory_clear(llama_get_memory(ctx), true);
|
||||
llama_set_embeddings(ctx, true);
|
||||
llama_set_causal_attn(ctx, false);
|
||||
|
||||
@@ -102,7 +102,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
|
||||
|
||||
llama_token eos_token = llama_vocab_eos(vocab);
|
||||
|
||||
llama_kv_self_clear(ctx);
|
||||
llama_memory_clear(llama_get_memory(ctx), true);
|
||||
llama_set_embeddings(ctx, false);
|
||||
llama_set_causal_attn(ctx, true);
|
||||
|
||||
|
||||
@@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
||||
}
|
||||
|
||||
batch->logits[batch->n_tokens - 1] = true;
|
||||
llama_kv_self_clear(context);
|
||||
llama_memory_clear(llama_get_memory(context), false);
|
||||
|
||||
const auto t_pp_start = ggml_time_us();
|
||||
if (llama_decode(context, *batch) != 0) {
|
||||
@@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
||||
|
||||
LOGi("Benchmark text generation (tg)");
|
||||
|
||||
llama_kv_self_clear(context);
|
||||
llama_memory_clear(llama_get_memory(context), false);
|
||||
const auto t_tg_start = ggml_time_us();
|
||||
for (i = 0; i < tg; i++) {
|
||||
|
||||
@@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
||||
|
||||
const auto t_tg_end = ggml_time_us();
|
||||
|
||||
llama_kv_self_clear(context);
|
||||
llama_memory_clear(llama_get_memory(context), false);
|
||||
|
||||
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
|
||||
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
|
||||
@@ -448,5 +448,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
||||
extern "C"
|
||||
JNIEXPORT void JNICALL
|
||||
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
|
||||
llama_kv_self_clear(reinterpret_cast<llama_context *>(context));
|
||||
llama_memory_clear(llama_get_memory(reinterpret_cast<llama_context *>(context)), true);
|
||||
}
|
||||
|
||||
@@ -210,7 +210,7 @@ actor LlamaContext {
|
||||
}
|
||||
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
|
||||
|
||||
llama_kv_self_clear(context)
|
||||
llama_memory_clear(llama_get_memory(context), false)
|
||||
|
||||
let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000;
|
||||
|
||||
@@ -223,7 +223,7 @@ actor LlamaContext {
|
||||
|
||||
// bench text generation
|
||||
|
||||
llama_kv_self_clear(context)
|
||||
llama_memory_clear(llama_get_memory(context), false)
|
||||
|
||||
let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000;
|
||||
|
||||
@@ -242,7 +242,7 @@ actor LlamaContext {
|
||||
|
||||
let t_tg_end = DispatchTime.now().uptimeNanoseconds / 1000;
|
||||
|
||||
llama_kv_self_clear(context)
|
||||
llama_memory_clear(llama_get_memory(context), false)
|
||||
|
||||
let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
|
||||
let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
|
||||
@@ -292,7 +292,7 @@ actor LlamaContext {
|
||||
func clear() {
|
||||
tokens_list.removeAll()
|
||||
temporary_invalid_cchars.removeAll()
|
||||
llama_kv_self_clear(context)
|
||||
llama_memory_clear(llama_get_memory(context), true)
|
||||
}
|
||||
|
||||
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
|
||||
|
||||
@@ -60,6 +60,8 @@ int main(int argc, char ** argv) {
|
||||
llama_model * model = llama_init.model.get();
|
||||
llama_context * ctx = llama_init.context.get();
|
||||
|
||||
auto * mem = llama_get_memory(ctx);
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
// Tokenize the prompt
|
||||
@@ -94,7 +96,7 @@ int main(int argc, char ** argv) {
|
||||
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
|
||||
|
||||
for (int s = 1; s < W + G + 1; ++s) {
|
||||
llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
|
||||
llama_memory_seq_cp(mem, 0, s, -1, -1);
|
||||
}
|
||||
|
||||
const auto t_enc_end = ggml_time_us();
|
||||
@@ -427,17 +429,17 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// KV cache management
|
||||
// if no verification token matched, we simply remove all cells from this batch -> no fragmentation
|
||||
llama_kv_self_seq_rm(ctx, -1, n_past, -1);
|
||||
llama_memory_seq_rm(mem, -1, n_past, -1);
|
||||
|
||||
if (seq_id_best != 0) {
|
||||
// if a verification token matched, we keep the best sequence and remove the rest
|
||||
// this leads to some KV cache fragmentation
|
||||
llama_kv_self_seq_keep(ctx, seq_id_best);
|
||||
llama_kv_self_seq_cp (ctx, seq_id_best, 0, -1, -1);
|
||||
llama_kv_self_seq_rm (ctx, seq_id_best, -1, -1);
|
||||
llama_memory_seq_keep(mem, seq_id_best);
|
||||
llama_memory_seq_cp (mem, seq_id_best, 0, -1, -1);
|
||||
llama_memory_seq_rm (mem, seq_id_best, -1, -1);
|
||||
|
||||
for (int s = 1; s < W + G + 1; ++s) {
|
||||
llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
|
||||
llama_memory_seq_cp(mem, 0, s, -1, -1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,7 +181,7 @@ int main(int argc, char ** argv){
|
||||
|
||||
// KV cache management
|
||||
// clean the cache of draft tokens that weren't accepted
|
||||
llama_kv_self_seq_rm(ctx, 0, n_past, -1);
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), 0, n_past, -1);
|
||||
|
||||
common_batch_clear(batch_tgt);
|
||||
common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
|
||||
|
||||
@@ -194,6 +194,8 @@ int main(int argc, char ** argv) {
|
||||
llama_model * model = llama_init.model.get();
|
||||
llama_context * ctx = llama_init.context.get();
|
||||
|
||||
auto * mem = llama_get_memory(ctx);
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
// load the prompts from an external file if there are any
|
||||
@@ -259,7 +261,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// assign the system KV cache to all parallel sequences
|
||||
for (int32_t i = 1; i <= n_clients; ++i) {
|
||||
llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
|
||||
llama_memory_seq_cp(mem, 0, i, -1, -1);
|
||||
}
|
||||
|
||||
LOG_INF("\n");
|
||||
@@ -286,9 +288,9 @@ int main(int argc, char ** argv) {
|
||||
if (batch.n_tokens == 0) {
|
||||
// all sequences have ended - clear the entire KV cache
|
||||
for (int i = 1; i <= n_clients; ++i) {
|
||||
llama_kv_self_seq_rm(ctx, i, -1, -1);
|
||||
llama_memory_seq_rm(mem, i, -1, -1);
|
||||
// but keep the system prompt
|
||||
llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
|
||||
llama_memory_seq_cp(mem, 0, i, -1, -1);
|
||||
}
|
||||
|
||||
LOG_INF("%s: clearing the KV cache\n", __func__);
|
||||
@@ -447,8 +449,8 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
|
||||
llama_kv_self_seq_rm(ctx, client.id + 1, -1, -1);
|
||||
llama_kv_self_seq_cp(ctx, 0, client.id + 1, -1, -1);
|
||||
llama_memory_seq_rm(mem, client.id + 1, -1, -1);
|
||||
llama_memory_seq_cp(mem, 0, client.id + 1, -1, -1);
|
||||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
||||
|
||||
@@ -126,6 +126,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
int n_past = 0;
|
||||
|
||||
auto * mem = llama_get_memory(ctx);
|
||||
|
||||
// fill the KV cache
|
||||
for (int i = 0; i < n_ctx; i += n_batch) {
|
||||
if (i > 0 && n_grp > 1) {
|
||||
@@ -133,10 +135,10 @@ int main(int argc, char ** argv) {
|
||||
const int ib = i/n_batch - 1;
|
||||
const int bd = n_batch_grp*(n_grp - 1);
|
||||
|
||||
llama_kv_self_seq_add(ctx, 0, n_past - n_batch, n_past, ib*bd);
|
||||
llama_kv_self_seq_div(ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
|
||||
llama_memory_seq_add(mem, 0, n_past - n_batch, n_past, ib*bd);
|
||||
llama_memory_seq_div(mem, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
|
||||
|
||||
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
|
||||
n_past = llama_memory_seq_pos_max(mem, 0) + 1;
|
||||
}
|
||||
|
||||
common_batch_clear(batch);
|
||||
@@ -166,10 +168,10 @@ int main(int argc, char ** argv) {
|
||||
|
||||
LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard);
|
||||
|
||||
llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
|
||||
llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
|
||||
llama_memory_seq_rm (mem, 0, n_keep , n_keep + n_discard);
|
||||
llama_memory_seq_add(mem, 0, n_keep + n_discard, n_ctx, -n_discard);
|
||||
|
||||
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
|
||||
n_past = llama_memory_seq_pos_max(mem, 0) + 1;
|
||||
|
||||
common_batch_clear(batch);
|
||||
|
||||
@@ -195,10 +197,10 @@ int main(int argc, char ** argv) {
|
||||
if (n_discard > 0) {
|
||||
LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
|
||||
|
||||
llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
|
||||
llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
|
||||
llama_memory_seq_rm (mem, 0, n_keep , n_keep + n_discard);
|
||||
llama_memory_seq_add(mem, 0, n_keep + n_discard, n_ctx, -n_discard);
|
||||
|
||||
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
|
||||
n_past = llama_memory_seq_pos_max(mem, 0) + 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
|
||||
|
||||
static void batch_process(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
|
||||
// clear previous kv_cache values (irrelevant for embeddings)
|
||||
llama_kv_self_clear(ctx);
|
||||
llama_memory_clear(llama_get_memory(ctx), false);
|
||||
|
||||
// run model
|
||||
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
|
||||
|
||||
@@ -196,7 +196,7 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
|
||||
|
||||
// erase whole kv
|
||||
llama_kv_self_clear(ctx3);
|
||||
llama_memory_clear(llama_get_memory(ctx3), true);
|
||||
fprintf(stderr, "%s : kv cache cleared\n", __func__);
|
||||
|
||||
// restore kv into seq 1
|
||||
|
||||
@@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
|
||||
auto generate = [&](const std::string & prompt) {
|
||||
std::string response;
|
||||
|
||||
const bool is_first = llama_kv_self_seq_pos_max(ctx, 0) == 0;
|
||||
const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == 0;
|
||||
|
||||
// tokenize the prompt
|
||||
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
|
||||
@@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
|
||||
while (true) {
|
||||
// check if we have enough space in the context to evaluate this batch
|
||||
int n_ctx = llama_n_ctx(ctx);
|
||||
int n_ctx_used = llama_kv_self_seq_pos_max(ctx, 0);
|
||||
int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx), 0);
|
||||
if (n_ctx_used + batch.n_tokens > n_ctx) {
|
||||
printf("\033[0m\n");
|
||||
fprintf(stderr, "context size exceeded\n");
|
||||
|
||||
@@ -217,7 +217,7 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
|
||||
|
||||
llama_kv_self_seq_rm(ctx_tgt, 0, n_past, -1);
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1);
|
||||
}
|
||||
|
||||
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
|
||||
|
||||
@@ -142,6 +142,8 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
auto * mem_tgt = llama_get_memory(ctx_tgt);
|
||||
auto * mem_dft = llama_get_memory(ctx_dft);
|
||||
|
||||
// Tokenize the prompt
|
||||
std::vector<llama_token> inp;
|
||||
@@ -420,14 +422,14 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
|
||||
|
||||
llama_kv_self_seq_keep(ctx_dft, s_keep);
|
||||
llama_kv_self_seq_cp (ctx_dft, s_keep, 0, -1, -1);
|
||||
llama_kv_self_seq_keep(ctx_dft, 0);
|
||||
llama_memory_seq_keep(mem_dft, s_keep);
|
||||
llama_memory_seq_cp (mem_dft, s_keep, 0, -1, -1);
|
||||
llama_memory_seq_keep(mem_dft, 0);
|
||||
|
||||
llama_kv_self_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
|
||||
llama_kv_self_seq_keep(ctx_tgt, s_keep);
|
||||
llama_kv_self_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
|
||||
llama_kv_self_seq_keep(ctx_tgt, 0);
|
||||
llama_memory_seq_rm (mem_tgt, s_keep, n_past_tgt, -1);
|
||||
llama_memory_seq_keep(mem_tgt, s_keep);
|
||||
llama_memory_seq_cp (mem_tgt, s_keep, 0, -1, -1);
|
||||
llama_memory_seq_keep(mem_tgt, 0);
|
||||
}
|
||||
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
@@ -444,7 +446,7 @@ int main(int argc, char ** argv) {
|
||||
common_batch_clear(batch_dft);
|
||||
common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
|
||||
|
||||
llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
||||
llama_memory_seq_rm(mem_dft, 0, n_past_dft, -1);
|
||||
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
|
||||
llama_decode(ctx_dft, batch_dft);
|
||||
|
||||
@@ -503,8 +505,8 @@ int main(int argc, char ** argv) {
|
||||
if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) {
|
||||
LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur);
|
||||
|
||||
llama_kv_self_seq_rm(ctx_dft, n_seq_cur, -1, -1);
|
||||
llama_kv_self_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
|
||||
llama_memory_seq_rm(mem_dft, n_seq_cur, -1, -1);
|
||||
llama_memory_seq_cp(mem_dft, s, n_seq_cur, -1, -1);
|
||||
|
||||
// all previous tokens from this branch are now also part of the new branch
|
||||
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
|
||||
@@ -585,9 +587,9 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// evaluate the target model on the drafted tokens
|
||||
{
|
||||
llama_kv_self_seq_keep(ctx_tgt, 0);
|
||||
llama_memory_seq_keep(mem_tgt, 0);
|
||||
for (int s = 1; s < n_seq_dft; ++s) {
|
||||
llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1);
|
||||
llama_memory_seq_cp(mem_tgt, 0, s, -1, -1);
|
||||
}
|
||||
|
||||
// LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
|
||||
|
||||
@@ -105,7 +105,7 @@ message(DEBUG "GGML_NATIVE_DEFAULT : ${GGML_NATIVE_DEFAULT}")
|
||||
message(DEBUG "INS_ENB : ${INS_ENB}")
|
||||
|
||||
option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
|
||||
option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
|
||||
option(GGML_CPU_REPACK "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
|
||||
option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if applicable" OFF)
|
||||
option(GGML_SSE42 "ggml: enable SSE 4.2" ${INS_ENB})
|
||||
option(GGML_AVX "ggml: enable AVX" ${INS_ENB})
|
||||
@@ -137,7 +137,7 @@ set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
|
||||
set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC")
|
||||
|
||||
|
||||
if (WIN32)
|
||||
if (MINGW)
|
||||
set(GGML_WIN_VER "0x602" CACHE STRING "ggml: Windows version")
|
||||
endif()
|
||||
|
||||
|
||||
@@ -125,7 +125,6 @@ if (NOT MSVC)
|
||||
endif()
|
||||
|
||||
if (MINGW)
|
||||
# Target Windows 8 for PrefetchVirtualMemory
|
||||
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
|
||||
endif()
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@
|
||||
#include <thread>
|
||||
#include <unistd.h>
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
|
||||
#include "../include/ggml-cann.h"
|
||||
#include "../include/ggml.h"
|
||||
@@ -103,6 +104,9 @@ const ggml_cann_device_info& ggml_cann_info();
|
||||
void ggml_cann_set_device(int32_t device);
|
||||
int32_t ggml_cann_get_device();
|
||||
|
||||
std::optional<std::string> get_env(const std::string& name);
|
||||
bool parse_bool(const std::string& value);
|
||||
|
||||
/**
|
||||
* @brief Abstract base class for memory pools used by CANN.
|
||||
*/
|
||||
@@ -354,7 +358,8 @@ struct ggml_backend_cann_context {
|
||||
: device(device), name("CANN" + std::to_string(device)), task_queue(1024, device) {
|
||||
ggml_cann_set_device(device);
|
||||
description = aclrtGetSocName();
|
||||
async_mode = (getenv("GGML_CANN_ASYNC_MODE") != nullptr);
|
||||
|
||||
bool async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
|
||||
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
|
||||
device, async_mode ? "ON" : "OFF");
|
||||
}
|
||||
|
||||
@@ -31,6 +31,8 @@
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <chrono>
|
||||
#include <unordered_set>
|
||||
#include <optional>
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
@@ -93,6 +95,26 @@ int32_t ggml_cann_get_device() {
|
||||
return id;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the value of the specified environment variable (name).
|
||||
* if not empty, return a std::string object
|
||||
*/
|
||||
std::optional<std::string> get_env(const std::string& name) {
|
||||
const char* val = std::getenv(name.c_str());
|
||||
if (!val) return std::nullopt;
|
||||
std::string res = std::string(val);
|
||||
std::transform(res.begin(), res.end(), res.begin(), ::tolower);
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Verify whether the environment variable is a valid value.
|
||||
*/
|
||||
bool parse_bool(const std::string& value) {
|
||||
std::unordered_set<std::string> valid_values = {"on", "1", "yes", "y", "enable", "true"};
|
||||
return valid_values.find(value) != valid_values.end();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Initialize the CANN device information.
|
||||
*
|
||||
@@ -214,7 +236,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
|
||||
* @param device The device ID to associate with this buffer pool.
|
||||
*/
|
||||
explicit ggml_cann_pool_buf_prio(int device) : device(device) {
|
||||
disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr;
|
||||
disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -410,7 +432,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
|
||||
* @param device The device ID to associate with this buffer pool.
|
||||
*/
|
||||
explicit ggml_cann_pool_buf(int device) : device(device) {
|
||||
disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr;
|
||||
disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -731,16 +753,18 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
|
||||
*/
|
||||
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
|
||||
int device) {
|
||||
bool disable_vmm = (getenv("GGML_CANN_DISABLE_VMM_POOL") != nullptr);
|
||||
if (!disable_vmm && ggml_cann_info().devices[device].vmm) {
|
||||
GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device);
|
||||
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
|
||||
}
|
||||
bool enable_buf_prio = (getenv("GGML_CANN_ENABLE_BUF_PRIO_POOL") != nullptr);
|
||||
if (enable_buf_prio) {
|
||||
std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or("");
|
||||
|
||||
if (mem_pool_type == "prio") {
|
||||
GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device);
|
||||
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf_prio(device));
|
||||
}
|
||||
|
||||
if (ggml_cann_info().devices[device].vmm && mem_pool_type != "leg") {
|
||||
GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device);
|
||||
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
|
||||
}
|
||||
|
||||
GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device);
|
||||
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf(device));
|
||||
}
|
||||
|
||||
@@ -1074,6 +1074,10 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
|
||||
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
|
||||
GGML_TABLE_END()
|
||||
|
||||
GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
|
||||
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
|
||||
GGML_TABLE_END()
|
||||
|
||||
#define NGRID_IQ1S 2048
|
||||
#define IQ1S_DELTA 0.125f
|
||||
#define IQ1M_DELTA 0.125f
|
||||
|
||||
@@ -10,14 +10,14 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
list (APPEND GGML_CPU_SOURCES
|
||||
ggml-cpu/ggml-cpu.c
|
||||
ggml-cpu/ggml-cpu.cpp
|
||||
ggml-cpu/ggml-cpu-aarch64.cpp
|
||||
ggml-cpu/ggml-cpu-aarch64.h
|
||||
ggml-cpu/ggml-cpu-hbm.cpp
|
||||
ggml-cpu/ggml-cpu-hbm.h
|
||||
ggml-cpu/ggml-cpu-quants.c
|
||||
ggml-cpu/ggml-cpu-quants.h
|
||||
ggml-cpu/ggml-cpu-traits.cpp
|
||||
ggml-cpu/ggml-cpu-traits.h
|
||||
ggml-cpu/repack.cpp
|
||||
ggml-cpu/repack.h
|
||||
ggml-cpu/hbm.cpp
|
||||
ggml-cpu/hbm.h
|
||||
ggml-cpu/quants.c
|
||||
ggml-cpu/quants.h
|
||||
ggml-cpu/traits.cpp
|
||||
ggml-cpu/traits.h
|
||||
ggml-cpu/amx/amx.cpp
|
||||
ggml-cpu/amx/amx.h
|
||||
ggml-cpu/amx/mmq.cpp
|
||||
@@ -84,6 +84,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
|
||||
if (GGML_SYSTEM_ARCH STREQUAL "ARM")
|
||||
message(STATUS "ARM detected")
|
||||
list(APPEND GGML_CPU_SOURCES
|
||||
ggml-cpu/arch/arm/quants.c
|
||||
ggml-cpu/arch/arm/repack.cpp
|
||||
)
|
||||
|
||||
if (MSVC AND NOT CMAKE_C_COMPILER_ID STREQUAL "Clang")
|
||||
message(FATAL_ERROR "MSVC is not supported for ARM, use clang")
|
||||
else()
|
||||
@@ -167,6 +172,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
endif()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "x86")
|
||||
message(STATUS "x86 detected")
|
||||
list(APPEND GGML_CPU_SOURCES
|
||||
ggml-cpu/arch/x86/quants.c
|
||||
ggml-cpu/arch/x86/repack.cpp
|
||||
)
|
||||
|
||||
if (MSVC)
|
||||
# instruction set detection for MSVC only
|
||||
if (GGML_NATIVE)
|
||||
@@ -302,7 +312,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
# Since multiple variants of the CPU backend may be included in the same
|
||||
# build, using set_source_files_properties() to set the arch flags is not possible
|
||||
set(GGML_CPU_FEATS_NAME ${GGML_CPU_NAME}-feats)
|
||||
add_library(${GGML_CPU_FEATS_NAME} OBJECT ggml-cpu/cpu-feats-x86.cpp)
|
||||
add_library(${GGML_CPU_FEATS_NAME} OBJECT ggml-cpu/arch/x86/cpu-feats.cpp)
|
||||
target_include_directories(${GGML_CPU_FEATS_NAME} PRIVATE . .. ../include)
|
||||
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARCH_DEFINITIONS})
|
||||
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED)
|
||||
@@ -311,6 +321,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
endif()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC")
|
||||
message(STATUS "PowerPC detected")
|
||||
list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/powerpc/quants.c)
|
||||
if (GGML_NATIVE)
|
||||
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
|
||||
file(READ "/proc/cpuinfo" POWER10_M)
|
||||
@@ -338,6 +349,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
endif()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "loongarch64")
|
||||
message(STATUS "loongarch64 detected")
|
||||
list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/loongarch/quants.c)
|
||||
|
||||
list(APPEND ARCH_FLAGS -march=loongarch64)
|
||||
if (GGML_LASX)
|
||||
list(APPEND ARCH_FLAGS -mlasx)
|
||||
@@ -347,6 +360,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
endif()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64")
|
||||
message(STATUS "riscv64 detected")
|
||||
list(APPEND GGML_CPU_SOURCES
|
||||
ggml-cpu/arch/riscv/quants.c
|
||||
ggml-cpu/arch/riscv/repack.cpp
|
||||
)
|
||||
if (GGML_RVV)
|
||||
if (GGML_XTHEADVECTOR)
|
||||
list(APPEND ARCH_FLAGS -march=rv64gc_xtheadvector -mabi=lp64d)
|
||||
@@ -358,6 +375,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
endif()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
|
||||
message(STATUS "s390x detected")
|
||||
list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/s390/quants.c)
|
||||
file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
|
||||
string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
|
||||
|
||||
@@ -381,12 +399,16 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
if (GGML_VXE)
|
||||
list(APPEND ARCH_FLAGS -mvx -mzvector)
|
||||
endif()
|
||||
elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "wasm")
|
||||
message(STATUS "Wasm detected")
|
||||
list (APPEND GGML_CPU_SOURCES ggml-cpu/arch/wasm/quants.c)
|
||||
else()
|
||||
message(STATUS "Unknown architecture")
|
||||
message(WARNING "Unknown CPU architecture. Falling back to generic implementations.")
|
||||
list(APPEND ARCH_FLAGS -DGGML_CPU_GENERIC)
|
||||
endif()
|
||||
|
||||
if (GGML_CPU_AARCH64)
|
||||
target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_AARCH64)
|
||||
if (GGML_CPU_REPACK)
|
||||
target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_REPACK)
|
||||
endif()
|
||||
|
||||
if (GGML_CPU_KLEIDIAI)
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-cpu-traits.h"
|
||||
#include "traits.h"
|
||||
|
||||
#if defined(__gnu_linux__)
|
||||
#include <sys/syscall.h>
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "mmq.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-cpu-impl.h"
|
||||
#include "ggml-cpu-quants.h"
|
||||
#include "quants.h"
|
||||
#include "ggml-quants.h"
|
||||
#include <algorithm>
|
||||
#include <type_traits>
|
||||
|
||||
4113
ggml/src/ggml-cpu/arch/arm/quants.c
Normal file
4113
ggml/src/ggml-cpu/arch/arm/quants.c
Normal file
File diff suppressed because it is too large
Load Diff
2174
ggml/src/ggml-cpu/arch/arm/repack.cpp
Normal file
2174
ggml/src/ggml-cpu/arch/arm/repack.cpp
Normal file
File diff suppressed because it is too large
Load Diff
2638
ggml/src/ggml-cpu/arch/loongarch/quants.c
Normal file
2638
ggml/src/ggml-cpu/arch/loongarch/quants.c
Normal file
File diff suppressed because it is too large
Load Diff
2731
ggml/src/ggml-cpu/arch/powerpc/quants.c
Normal file
2731
ggml/src/ggml-cpu/arch/powerpc/quants.c
Normal file
File diff suppressed because it is too large
Load Diff
2068
ggml/src/ggml-cpu/arch/riscv/quants.c
Normal file
2068
ggml/src/ggml-cpu/arch/riscv/quants.c
Normal file
File diff suppressed because it is too large
Load Diff
396
ggml/src/ggml-cpu/arch/riscv/repack.cpp
Normal file
396
ggml/src/ggml-cpu/arch/riscv/repack.cpp
Normal file
@@ -0,0 +1,396 @@
|
||||
#define GGML_COMMON_IMPL_CPP
|
||||
#define GGML_COMMON_DECL_CPP
|
||||
#include "ggml-common.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-cpu-impl.h"
|
||||
#include "traits.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <cassert>
|
||||
#include <cstdlib> // for qsort
|
||||
#include <cstdio> // for GGML_ASSERT
|
||||
|
||||
#define GGML_CPU_CLANG_WORKAROUND
|
||||
#include "../../repack.h"
|
||||
|
||||
#if defined(__GNUC__)
|
||||
#pragma GCC diagnostic ignored "-Woverlength-strings"
|
||||
#endif
|
||||
|
||||
#define UNUSED GGML_UNUSED
|
||||
|
||||
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
|
||||
assert (n % qk == 0);
|
||||
assert (nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(s);
|
||||
UNUSED(bs);
|
||||
UNUSED(vx);
|
||||
UNUSED(vy);
|
||||
UNUSED(nr);
|
||||
UNUSED(nc);
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined __riscv_v
|
||||
if (__riscv_vlenb() >= QK4_0) {
|
||||
const size_t vl = QK4_0;
|
||||
|
||||
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
|
||||
|
||||
vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
||||
for (int l = 0; l < nb; l++) {
|
||||
const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0];
|
||||
const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8];
|
||||
const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16];
|
||||
const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24];
|
||||
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment constraints
|
||||
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, vl / 4));
|
||||
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, vl / 4));
|
||||
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, vl / 4));
|
||||
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, vl / 4));
|
||||
|
||||
const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
|
||||
const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
|
||||
const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
|
||||
const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);
|
||||
const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);
|
||||
const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);
|
||||
const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);
|
||||
|
||||
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
||||
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
||||
|
||||
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_hi_m));
|
||||
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
||||
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
||||
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
||||
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
||||
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
||||
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
||||
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
||||
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
||||
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
||||
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
||||
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
||||
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
||||
|
||||
// vector version needs Zvfhmin extension
|
||||
const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d);
|
||||
const float b_scales[8] = {
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[0]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[1]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[2]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[3]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[4]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[5]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[6]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[7])
|
||||
};
|
||||
const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
|
||||
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4);
|
||||
sumf = __riscv_vfmacc_vv_f32m1(sumf, tmp1, b_scales_vec, vl / 4);
|
||||
}
|
||||
__riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
#endif
|
||||
{
|
||||
float sumf[8];
|
||||
int sumi;
|
||||
|
||||
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
|
||||
const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
|
||||
sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
|
||||
}
|
||||
sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
|
||||
assert (n % qk == 0);
|
||||
assert (nr % 4 == 0);
|
||||
assert (nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(s);
|
||||
UNUSED(bs);
|
||||
UNUSED(vx);
|
||||
UNUSED(vy);
|
||||
UNUSED(nr);
|
||||
UNUSED(nc);
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined __riscv_v
|
||||
if (__riscv_vlenb() >= QK4_0) {
|
||||
const size_t vl = QK4_0;
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
|
||||
vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
||||
vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
||||
vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
||||
vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
||||
for (int l = 0; l < nb; l++) {
|
||||
const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
|
||||
const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
|
||||
const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
|
||||
const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);
|
||||
const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);
|
||||
const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);
|
||||
const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);
|
||||
|
||||
// vector version needs Zvfhmin extension
|
||||
const float a_scales[4] = {
|
||||
GGML_FP16_TO_FP32(a_ptr[l].d[0]),
|
||||
GGML_FP16_TO_FP32(a_ptr[l].d[1]),
|
||||
GGML_FP16_TO_FP32(a_ptr[l].d[2]),
|
||||
GGML_FP16_TO_FP32(a_ptr[l].d[3])
|
||||
};
|
||||
const float b_scales[8] = {
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[0]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[1]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[2]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[3]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[4]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[5]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[6]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[7])
|
||||
};
|
||||
const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
|
||||
|
||||
const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0];
|
||||
const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32];
|
||||
const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64];
|
||||
const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96];
|
||||
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
|
||||
vint16m4_t sumi_l0;
|
||||
{
|
||||
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4));
|
||||
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4));
|
||||
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4));
|
||||
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4));
|
||||
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
||||
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
||||
|
||||
sumi_l0 = sumi_hi_m;
|
||||
}
|
||||
|
||||
{
|
||||
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0));
|
||||
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
||||
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
||||
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
||||
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
||||
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
||||
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
||||
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
||||
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
||||
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
||||
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
||||
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
||||
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
||||
|
||||
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4);
|
||||
sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4);
|
||||
}
|
||||
|
||||
const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8];
|
||||
const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40];
|
||||
const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72];
|
||||
const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104];
|
||||
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
|
||||
vint16m4_t sumi_l1;
|
||||
{
|
||||
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4));
|
||||
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4));
|
||||
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4));
|
||||
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4));
|
||||
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
||||
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
||||
|
||||
sumi_l1 = sumi_hi_m;
|
||||
}
|
||||
|
||||
{
|
||||
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1));
|
||||
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
||||
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
||||
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
||||
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
||||
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
||||
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
||||
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
||||
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
||||
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
||||
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
||||
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
||||
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
||||
|
||||
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4);
|
||||
sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4);
|
||||
}
|
||||
|
||||
const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16];
|
||||
const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48];
|
||||
const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80];
|
||||
const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112];
|
||||
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
|
||||
vint16m4_t sumi_l2;
|
||||
{
|
||||
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4));
|
||||
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4));
|
||||
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4));
|
||||
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4));
|
||||
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
||||
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
||||
|
||||
sumi_l2 = sumi_hi_m;
|
||||
}
|
||||
|
||||
{
|
||||
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2));
|
||||
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
||||
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
||||
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
||||
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
||||
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
||||
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
||||
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
||||
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
||||
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
||||
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
||||
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
||||
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
||||
|
||||
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4);
|
||||
sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4);
|
||||
}
|
||||
|
||||
const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24];
|
||||
const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56];
|
||||
const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88];
|
||||
const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120];
|
||||
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
|
||||
vint16m4_t sumi_l3;
|
||||
{
|
||||
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4));
|
||||
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4));
|
||||
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4));
|
||||
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4));
|
||||
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
||||
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
||||
|
||||
sumi_l3 = sumi_hi_m;
|
||||
}
|
||||
|
||||
{
|
||||
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3));
|
||||
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
||||
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
||||
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
||||
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
||||
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
||||
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
||||
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
||||
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
||||
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
||||
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
||||
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
||||
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
||||
|
||||
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4);
|
||||
sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4);
|
||||
}
|
||||
}
|
||||
__riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4);
|
||||
__riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4);
|
||||
__riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4);
|
||||
__riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
|
||||
float sumf[4][8];
|
||||
int sumi;
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
|
||||
const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
|
||||
sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
|
||||
(v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
|
||||
}
|
||||
sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++)
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
1299
ggml/src/ggml-cpu/arch/s390/quants.c
Normal file
1299
ggml/src/ggml-cpu/arch/s390/quants.c
Normal file
File diff suppressed because it is too large
Load Diff
1480
ggml/src/ggml-cpu/arch/wasm/quants.c
Normal file
1480
ggml/src/ggml-cpu/arch/wasm/quants.c
Normal file
File diff suppressed because it is too large
Load Diff
4310
ggml/src/ggml-cpu/arch/x86/quants.c
Normal file
4310
ggml/src/ggml-cpu/arch/x86/quants.c
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-cpu-traits.h"
|
||||
#include "traits.h"
|
||||
#include "ggml-cpu-impl.h"
|
||||
#include "ggml-impl.h"
|
||||
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml-cpu-traits.h"
|
||||
#include "ggml.h"
|
||||
|
||||
// GGML internal header
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void);
|
||||
@@ -506,3 +506,25 @@ void ggml_barrier(struct ggml_threadpool * tp);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#define GGML_DO_PRAGMA_(x) _Pragma (#x)
|
||||
#define GGML_DO_PRAGMA(x) GGML_DO_PRAGMA_(x)
|
||||
#if defined(GGML_CPU_GENERIC) || defined(__HIPCC__)
|
||||
// Note for Apple targets:
|
||||
// - clang: aliases are not supported on darwin
|
||||
// - all native kernels need to be implemented in both x86 and arm files
|
||||
// - on iOS, tvOS, and visionOS, if cmake cannot determine the target architecture, all `_generic` names are replaced by defines
|
||||
# define GGML_WEAK_ALIAS(name, alias)
|
||||
#elif defined(__GNUC__)
|
||||
// GCC/Clang on *nix
|
||||
# define GGML_WEAK_ALIAS(name, alias) GGML_DO_PRAGMA(weak name = alias) // NOLINT
|
||||
#elif defined(_MSC_VER) && defined (_WIN64)
|
||||
// MSVC
|
||||
// Note: C name mangling varies across different calling conventions
|
||||
// see https://learn.microsoft.com/en-us/cpp/build/reference/decorated-names?view=msvc-170
|
||||
# define GGML_WEAK_ALIAS(name, alias) GGML_DO_PRAGMA(comment(linker, "/alternatename:" #name "=" #alias))
|
||||
#else
|
||||
# error "Unsupported compiler for GGML_WEAK_ALIAS"
|
||||
#endif
|
||||
|
||||
#define GGML_CPU_NATIVE_IMPL(name) GGML_WEAK_ALIAS(name, name ## _generic)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,63 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
// GGML CPU internal header
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Quantization
|
||||
void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
// Dot product
|
||||
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -3,11 +3,11 @@
|
||||
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-cpu-traits.h"
|
||||
#include "traits.h"
|
||||
#include "ggml-cpu-impl.h"
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-cpu-quants.h"
|
||||
#include "quants.h"
|
||||
#include "ggml-threading.h"
|
||||
#include "unary-ops.h"
|
||||
#include "binary-ops.h"
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-cpu-aarch64.h"
|
||||
#include "ggml-cpu-traits.h"
|
||||
#include "repack.h"
|
||||
#include "traits.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "amx/amx.h"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include <vector>
|
||||
|
||||
#ifdef GGML_USE_CPU_HBM
|
||||
# include "ggml-cpu-hbm.h"
|
||||
# include "hbm.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_CPU_KLEIDIAI
|
||||
@@ -51,9 +51,9 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_CPU_AARCH64
|
||||
if (ggml_backend_cpu_aarch64_buffer_type()) {
|
||||
bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
|
||||
#ifdef GGML_USE_CPU_REPACK
|
||||
if (ggml_backend_cpu_repack_buffer_type()) {
|
||||
bufts.push_back(ggml_backend_cpu_repack_buffer_type());
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -596,8 +596,8 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
|
||||
#ifdef GGML_USE_CPU_KLEIDIAI
|
||||
features.push_back({ "KLEIDIAI", "1" });
|
||||
#endif
|
||||
#ifdef GGML_USE_CPU_AARCH64
|
||||
features.push_back({ "AARCH64_REPACK", "1" });
|
||||
#ifdef GGML_USE_CPU_REPACK
|
||||
features.push_back({ "REPACK", "1" });
|
||||
#endif
|
||||
|
||||
features.push_back({ nullptr, nullptr });
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#include "ggml-cpu-hbm.h"
|
||||
#include "hbm.h"
|
||||
|
||||
// buffer type HBM
|
||||
|
||||
@@ -26,7 +26,7 @@
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-threading.h"
|
||||
#include "ggml-cpu-traits.h"
|
||||
#include "traits.h"
|
||||
|
||||
#include "kernels.h"
|
||||
|
||||
|
||||
@@ -8132,8 +8132,8 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
||||
#define WKV_VECTOR_SIZE 4
|
||||
#endif
|
||||
|
||||
int wkv_vector_size;
|
||||
#ifdef WKV_VECTOR_SIZE
|
||||
int wkv_vector_size;
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
wkv_vector_size = svcntw();
|
||||
#else
|
||||
@@ -8348,8 +8348,8 @@ static void ggml_compute_forward_gla_f32(
|
||||
#define GLA_VECTOR_SIZE 4
|
||||
#endif
|
||||
|
||||
int gla_vector_size;
|
||||
#ifdef GLA_VECTOR_SIZE
|
||||
int gla_vector_size;
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
gla_vector_size = svcntw();
|
||||
#else
|
||||
|
||||
1179
ggml/src/ggml-cpu/quants.c
Normal file
1179
ggml/src/ggml-cpu/quants.c
Normal file
File diff suppressed because it is too large
Load Diff
116
ggml/src/ggml-cpu/quants.h
Normal file
116
ggml/src/ggml-cpu/quants.h
Normal file
@@ -0,0 +1,116 @@
|
||||
#pragma once
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
// GGML CPU internal header
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Quantization
|
||||
void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
// Dot product
|
||||
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
// Generic implementation
|
||||
void quantize_row_q8_0_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q5_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q6_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq2_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq2_xs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq2_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq3_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq3_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq1_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq1_m_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq4_nl_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq4_xs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
#if defined(GGML_CPU_GENERIC)
|
||||
#define quantize_row_q8_0_generic quantize_row_q8_0
|
||||
#define quantize_row_q8_1_generic quantize_row_q8_1
|
||||
#define quantize_row_q8_K_generic quantize_row_q8_K
|
||||
#define ggml_vec_dot_q4_0_q8_0_generic ggml_vec_dot_q4_0_q8_0
|
||||
#define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
|
||||
#define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
|
||||
#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
|
||||
#define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
|
||||
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
|
||||
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
|
||||
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
|
||||
#define ggml_vec_dot_q3_K_q8_K_generic ggml_vec_dot_q3_K_q8_K
|
||||
#define ggml_vec_dot_q4_K_q8_K_generic ggml_vec_dot_q4_K_q8_K
|
||||
#define ggml_vec_dot_q5_K_q8_K_generic ggml_vec_dot_q5_K_q8_K
|
||||
#define ggml_vec_dot_q6_K_q8_K_generic ggml_vec_dot_q6_K_q8_K
|
||||
#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
|
||||
#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
|
||||
#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K
|
||||
#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K
|
||||
#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
|
||||
#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
|
||||
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
|
||||
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
|
||||
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
1566
ggml/src/ggml-cpu/repack.cpp
Normal file
1566
ggml/src/ggml-cpu/repack.cpp
Normal file
File diff suppressed because it is too large
Load Diff
119
ggml/src/ggml-cpu/repack.h
Normal file
119
ggml/src/ggml-cpu/repack.h
Normal file
@@ -0,0 +1,119 @@
|
||||
#pragma once
|
||||
|
||||
#define GGML_COMMON_DECL_CPP
|
||||
#include "ggml-common.h"
|
||||
|
||||
#include "traits.h"
|
||||
#include "ggml.h"
|
||||
|
||||
// GGML internal header
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void);
|
||||
|
||||
template <int K> constexpr int QK_0() {
|
||||
if constexpr (K == 4) {
|
||||
return QK4_0;
|
||||
}
|
||||
if constexpr (K == 8) {
|
||||
return QK8_0;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
template <int K, int N> struct block {
|
||||
ggml_half d[N]; // deltas for N qK_0 blocks
|
||||
int8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
|
||||
};
|
||||
|
||||
// control size
|
||||
static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
|
||||
static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
|
||||
static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
|
||||
static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
|
||||
|
||||
using block_q4_0x4 = block<4, 4>;
|
||||
using block_q4_0x8 = block<4, 8>;
|
||||
using block_q8_0x4 = block<8, 4>;
|
||||
using block_q8_0x8 = block<8, 8>;
|
||||
|
||||
struct block_q4_Kx8 {
|
||||
ggml_half d[8]; // super-block scale for quantized scales
|
||||
ggml_half dmin[8]; // super-block scale for quantized mins
|
||||
uint8_t scales[96]; // scales and mins, quantized with 6 bits
|
||||
uint8_t qs[1024]; // 4--bit quants
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
|
||||
|
||||
struct block_q8_Kx4 {
|
||||
float d[4]; // delta
|
||||
int8_t qs[QK_K * 4]; // quants
|
||||
int16_t bsums[QK_K / 4]; // sum of quants in groups of 16
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q8_Kx4) == sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), "wrong q8_K block size/padding");
|
||||
|
||||
struct block_iq4_nlx4 {
|
||||
ggml_half d[4]; // deltas for 4 iq4_nl blocks
|
||||
uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
|
||||
|
||||
#if defined(__cplusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Workaround for clang:
|
||||
// clang++ complains: ``error: call to 'ggml_gemm_q4_0_4x4_q8_0' is ambiguous''
|
||||
// repro: https://godbolt.org/z/oKdeWKonM (ICE), https://godbolt.org/z/1szq6P36v (ambiguous call)
|
||||
#if defined(GGML_CPU_CLANG_WORKAROUND) || !(defined(__GNUC__) && defined(__clang__)) || defined(__HIPCC__)
|
||||
void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
#endif // !defined(__clang__)
|
||||
|
||||
// Native implementations
|
||||
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
||||
#if defined(GGML_CPU_GENERIC)
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#endif
|
||||
|
||||
#if defined(__cplusplus)
|
||||
} // extern "C"
|
||||
#endif
|
||||
@@ -1,4 +1,4 @@
|
||||
#include "ggml-cpu-traits.h"
|
||||
#include "traits.h"
|
||||
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-backend.h"
|
||||
@@ -466,9 +466,6 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
}
|
||||
|
||||
// TODO: move to ggml-common.h
|
||||
static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
||||
|
||||
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
|
||||
|
||||
static __device__ __forceinline__ float get_alibi_slope(
|
||||
|
||||
@@ -652,9 +652,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
float KQ_max_scale[cols_per_thread];
|
||||
#pragma unroll
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]);
|
||||
const float KQ_max_diff = KQ_max[col] - KQ_max_new[col];
|
||||
KQ_max_scale[col] = expf(KQ_max_diff);
|
||||
KQ_max[col] = KQ_max_new[col];
|
||||
|
||||
*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
|
||||
|
||||
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
||||
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
|
||||
}
|
||||
|
||||
@@ -615,9 +615,8 @@ static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
|
||||
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
|
||||
|
||||
ggml_cuda_set_device(ctx->device);
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
CUDA_CHECK(cudaMemset(ctx->dev_ptr, value, buffer->size));
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
CUDA_CHECK(cudaMemsetAsync(ctx->dev_ptr, value, buffer->size, cudaStreamPerThread));
|
||||
CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
|
||||
}
|
||||
|
||||
static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
|
||||
@@ -1144,7 +1143,6 @@ typedef void (*ggml_cuda_op_mul_mat_t)(
|
||||
static cudaError_t ggml_cuda_cpy_tensor_2d(
|
||||
void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
|
||||
const char * src_ptr = (const char *) src->data;
|
||||
char * dst_ptr = (char *) dst;
|
||||
|
||||
@@ -1427,8 +1425,6 @@ static void ggml_cuda_op_mul_mat(
|
||||
const int64_t nb2 = dst->nb[2];
|
||||
const int64_t nb3 = dst->nb[3];
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
|
||||
GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
|
||||
ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
|
||||
ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context;
|
||||
|
||||
@@ -1750,7 +1746,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
||||
GGML_ASSERT(!ggml_is_transposed(src0));
|
||||
GGML_ASSERT(!ggml_is_transposed(src1));
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
|
||||
GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
|
||||
// Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
|
||||
|
||||
@@ -2425,8 +2425,6 @@ void dequantize_row_iq1_m(const block_iq1_m * GGML_RESTRICT x, float * GGML_REST
|
||||
}
|
||||
}
|
||||
|
||||
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
||||
|
||||
void dequantize_row_iq4_nl(const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
assert(k % QK4_NL == 0);
|
||||
const int64_t nb = k / QK4_NL;
|
||||
|
||||
@@ -149,8 +149,6 @@ typedef sycl::float2 dfloat2;
|
||||
|
||||
#define MMVQ_MAX_BATCH_SIZE 8
|
||||
|
||||
static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
||||
|
||||
static int g_all_sycl_device_count = -1;
|
||||
static bool g_ggml_backend_sycl_buffer_type_initialized = false;
|
||||
|
||||
|
||||
@@ -265,6 +265,17 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
|
||||
const int64_t nb = k / QK_K;
|
||||
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
dpct::queue_ptr stream) {
|
||||
@@ -530,7 +541,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
|
||||
case GGML_TYPE_Q5_K:
|
||||
return dequantize_row_q5_K_sycl;
|
||||
case GGML_TYPE_Q6_K:
|
||||
return dequantize_row_q6_K_sycl;
|
||||
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
return dequantize_row_q6_K_sycl_reorder;
|
||||
} else {
|
||||
return dequantize_row_q6_K_sycl;
|
||||
}
|
||||
case GGML_TYPE_IQ1_S:
|
||||
return dequantize_row_iq1_s_sycl;
|
||||
case GGML_TYPE_IQ1_M:
|
||||
@@ -587,7 +602,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
|
||||
case GGML_TYPE_Q5_K:
|
||||
return dequantize_row_q5_K_sycl;
|
||||
case GGML_TYPE_Q6_K:
|
||||
return dequantize_row_q6_K_sycl;
|
||||
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
return dequantize_row_q6_K_sycl_reorder;
|
||||
} else {
|
||||
return dequantize_row_q6_K_sycl;
|
||||
}
|
||||
case GGML_TYPE_IQ1_S:
|
||||
return dequantize_row_iq1_s_sycl;
|
||||
case GGML_TYPE_IQ1_M:
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
#include "cpy.hpp"
|
||||
|
||||
#include <float.h>
|
||||
#include <string>
|
||||
|
||||
#include "dequantize.hpp"
|
||||
#include "ggml-sycl/common.hpp"
|
||||
#include "ggml-sycl/presets.hpp"
|
||||
#include "ggml.h"
|
||||
|
||||
static __dpct_inline__ int best_index_int8(int n, const int8_t * val, float x) {
|
||||
if (x <= val[0]) {
|
||||
@@ -116,6 +120,15 @@ static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
||||
}
|
||||
}
|
||||
|
||||
/* quantized type same copy */
|
||||
template<typename T>
|
||||
static void cpy_blck_q_q(const char * cxi, char * cdsti) {
|
||||
const T * xi = (const T *) cxi;
|
||||
T * dsti = (T *) cdsti;
|
||||
*dsti = *xi;
|
||||
}
|
||||
|
||||
|
||||
static void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
|
||||
float * cdstf = (float *) (cdsti);
|
||||
|
||||
@@ -311,6 +324,34 @@ template <dequantize_kernel_t dequant, int qk> static void cpy_blck_q_f32(const
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, int qk>
|
||||
static void cpy_q_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
|
||||
const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
|
||||
const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
const sycl::nd_item<3> & item_ct1) {
|
||||
const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i03 = i / (ne00 * ne01 * ne02);
|
||||
const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
|
||||
const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
|
||||
const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
|
||||
const int x_offset = (i00 / qk) * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;
|
||||
|
||||
|
||||
const int i13 = i / (ne10 * ne11 * ne12);
|
||||
const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
|
||||
const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
|
||||
const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
|
||||
const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
|
||||
|
||||
cpy_blck_q_q<T>(cx + x_offset, cdst + dst_offset);
|
||||
}
|
||||
|
||||
template <cpy_kernel_t cpy_blck, int qk>
|
||||
static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
|
||||
const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
|
||||
@@ -322,6 +363,7 @@ static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
const int i03 = i / (ne00 * ne01 * ne02);
|
||||
const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
|
||||
const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
|
||||
@@ -615,6 +657,70 @@ static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, co
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
||||
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
||||
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
||||
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
||||
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
||||
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
|
||||
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1) try {
|
||||
// Unlike other operators ggml_sycl_cpy takes 2 distinct tensors instead of a dst ggml_tensor and rely on its src field
|
||||
scope_op_debug_print scope_dbg_print(__func__, src1, /*num_src=*/0,
|
||||
@@ -632,8 +738,10 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
|
||||
|
||||
char * src0_ddc = (char *) src0->data;
|
||||
char * src1_ddc = (char *) src1->data;
|
||||
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
if ((src0->type == src1->type) && (ggml_is_contiguous(src0) && ggml_is_contiguous(src1))) {
|
||||
GGML_SYCL_DEBUG("%s: memcpy path\n", __func__);
|
||||
main_stream->memcpy(src1_ddc, src0_ddc, ggml_nbytes(src0));
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_f32_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
|
||||
nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||
@@ -684,6 +792,16 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||
ggml_cpy_f32_iq4_nl_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
|
||||
nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
|
||||
ggml_cpy_q8_0_q8_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_Q5_0) {
|
||||
ggml_cpy_q5_0_q5_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_Q5_1) {
|
||||
ggml_cpy_q5_1_q5_1(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_Q4_0) {
|
||||
ggml_cpy_q4_0_q4_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_Q4_1) {
|
||||
ggml_cpy_q4_1_q4_1(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else {
|
||||
GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type),
|
||||
ggml_type_name(src1->type));
|
||||
|
||||
@@ -538,6 +538,38 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_block_q6_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||
const sycl::nd_item<3> & item_ct1, int64_t n_blocks) {
|
||||
const int64_t ib = item_ct1.get_group(2);
|
||||
|
||||
const int64_t tid = item_ct1.get_local_id(2);
|
||||
const int64_t ip = tid / 32; // ip is 0 or 1
|
||||
const int64_t il = tid - 32 * ip; // 0...32
|
||||
const int64_t is = 8 * ip + il / 16;
|
||||
|
||||
const uint8_t * base_ptr = static_cast<const uint8_t *>(vx);
|
||||
const auto ql_offset = ib * (QK_K / 2);
|
||||
const auto qh_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * ib;
|
||||
const auto base_scales_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * n_blocks + (QK_K / 16) * ib;
|
||||
const auto base_d_offset = ((QK_K / 2) + (QK_K / 4) + (QK_K / 16)) * n_blocks;
|
||||
const uint8_t * ql_ptr = base_ptr + ql_offset;
|
||||
const uint8_t * qh_ptr = base_ptr + qh_offset;
|
||||
const uint8_t * scales_ptr = base_ptr + base_scales_offset;
|
||||
const ggml_half * d = (const ggml_half *) (base_ptr + base_d_offset) + ib;
|
||||
|
||||
dst_t * y = yy + ib * QK_K + 128 * ip + il;
|
||||
|
||||
const uint8_t * ql = ql_ptr + 64 * ip + il;
|
||||
const uint8_t qh = *(qh_ptr + 32 * ip + il);
|
||||
const int8_t * sc = reinterpret_cast<const int8_t *>(scales_ptr + is);
|
||||
|
||||
y[0] = *d * sc[0] * ((int8_t) ((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
|
||||
y[32] = *d * sc[2] * ((int8_t) ((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
|
||||
y[64] = *d * sc[4] * ((int8_t) ((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
|
||||
y[96] = *d * sc[6] * ((int8_t) ((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||
const sycl::nd_item<3> &item_ct1,
|
||||
|
||||
@@ -354,7 +354,8 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
|
||||
assert(tensor->view_src->buffer->buft == buffer->buft);
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
|
||||
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
|
||||
!g_ggml_sycl_disable_optimize) {
|
||||
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
|
||||
tensor->extra = extra;
|
||||
ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
|
||||
@@ -2989,6 +2990,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
return true;
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
return !g_ggml_sycl_prioritize_dmmv;
|
||||
default:
|
||||
return false;
|
||||
@@ -3008,6 +3010,7 @@ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@@ -3092,6 +3095,50 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
|
||||
sycl::free(tmp_buf, *stream);
|
||||
}
|
||||
|
||||
static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(size % sizeof(block_q6_K) == 0);
|
||||
GGML_ASSERT(offset % sizeof(block_q6_K) == 0);
|
||||
|
||||
const int nblocks = size / sizeof(block_q6_K);
|
||||
|
||||
auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
|
||||
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
|
||||
|
||||
auto * ql_ptr = data_device;
|
||||
auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks;
|
||||
auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
|
||||
sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
|
||||
|
||||
stream
|
||||
->parallel_for(nblocks,
|
||||
[=](auto i) {
|
||||
const block_q6_K * x = (const block_q6_K *) tmp_buf;
|
||||
const int ib = i;
|
||||
|
||||
const uint8_t * ql = x[ib].ql;
|
||||
const uint8_t * qh = x[ib].qh;
|
||||
uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
|
||||
uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
|
||||
uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
|
||||
|
||||
for (int j = 0; j < QK_K / 2; ++j) {
|
||||
base_ql_ptr[j] = ql[j];
|
||||
}
|
||||
for (int j = 0; j < QK_K / 4; ++j) {
|
||||
base_qh_ptr[j] = qh[j];
|
||||
}
|
||||
|
||||
for (int j = 0; j < QK_K / 16; ++j) {
|
||||
base_scales_ptr[j] = x[ib].scales[j];
|
||||
}
|
||||
|
||||
dm_ptr[ib] = x[ib].d;
|
||||
})
|
||||
.wait_and_throw();
|
||||
|
||||
sycl::free(tmp_buf, *stream);
|
||||
}
|
||||
|
||||
static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
|
||||
uint8_t * data_device = (uint8_t *) src0->data;
|
||||
size_t ncols = src0->ne[0];
|
||||
@@ -3105,6 +3152,9 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
|
||||
case GGML_TYPE_Q4_K:
|
||||
reorder_qw_q4_k(data_device, size, 0, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
reorder_qw_q6_k(data_device, size, 0, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("reorder_qw() called with unsupported type");
|
||||
break;
|
||||
@@ -4226,6 +4276,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
ggml_type src1_type = op->src[1]->type;
|
||||
if (src0_type == src1_type && (ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) && src0_type != GGML_TYPE_BF16) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
@@ -4271,6 +4324,21 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
|
||||
return true;
|
||||
}
|
||||
if(src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_Q8_0) {
|
||||
return true;
|
||||
}
|
||||
if(src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_Q5_0) {
|
||||
return true;
|
||||
}
|
||||
if(src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_Q5_1) {
|
||||
return true;
|
||||
}
|
||||
if(src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_Q4_0) {
|
||||
return true;
|
||||
}
|
||||
if(src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_Q4_1) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
case GGML_OP_CONCAT:
|
||||
|
||||
@@ -31,11 +31,10 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
|
||||
|
||||
float partial_sum = 0.0f;
|
||||
for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
|
||||
const int ibx = row * blocks_per_row + i; // x block index
|
||||
// TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits
|
||||
const int bx_offset = block_type::get_block_offset(ibx);
|
||||
const int d_offset = block_type::get_d_offset(nrows, ncols, ibx);
|
||||
const int ibx = row * blocks_per_row + i; // x block index
|
||||
|
||||
const auto bx_offset = block_type::get_block_offset(ibx, nblocks);
|
||||
const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx);
|
||||
// Y block index that aligns with ibx
|
||||
const int iby = i * block_type::block_to_q8_1_ratio();
|
||||
const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1;
|
||||
@@ -46,7 +45,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
|
||||
// x block quant index when casting the quants to int
|
||||
const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
|
||||
|
||||
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs, nblocks);
|
||||
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -785,6 +784,24 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
}
|
||||
}
|
||||
|
||||
static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
||||
const int nrows, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
||||
constexpr size_t num_subgroups = 16;
|
||||
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
||||
|
||||
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
||||
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
||||
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
||||
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
|
||||
nd_item);
|
||||
});
|
||||
});
|
||||
}
|
||||
static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
float *dst, const int ncols,
|
||||
const int nrows,
|
||||
@@ -1070,7 +1087,14 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
|
||||
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
||||
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n");
|
||||
reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
} else {
|
||||
GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n");
|
||||
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
}
|
||||
break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
|
||||
@@ -14,12 +14,13 @@
|
||||
#ifndef GGML_SYCL_QUANTS_HPP
|
||||
#define GGML_SYCL_QUANTS_HPP
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "ggml-common.h"
|
||||
#include "ggml.h"
|
||||
|
||||
namespace ggml_sycl_reordered {
|
||||
|
||||
|
||||
// The reordered block moves quants (qs) and scales(d) to two
|
||||
// uniform regions of memory that is contiguous in the same tensor.
|
||||
// What this means is that instead of having:
|
||||
@@ -32,7 +33,6 @@ namespace ggml_sycl_reordered {
|
||||
|
||||
template <ggml_type type> struct block_q_t;
|
||||
|
||||
|
||||
// qk number of weights / quants in a block
|
||||
// qr number of weights in a byte (described as 'before dequantization')
|
||||
// for quantization types that has low and high bits split, qr is calculated with
|
||||
@@ -47,10 +47,12 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
|
||||
static constexpr uint32_t vdr_mmvq = 2;
|
||||
};
|
||||
|
||||
static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
|
||||
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
|
||||
return { block_index * (traits::qk / traits::qr), 0 };
|
||||
}
|
||||
|
||||
static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
|
||||
return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half);
|
||||
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
|
||||
return { (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half), 0 };
|
||||
}
|
||||
|
||||
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||
@@ -64,20 +66,46 @@ template <> struct block_q_t<GGML_TYPE_Q4_K> {
|
||||
static constexpr uint32_t vdr_mmvq = 2;
|
||||
};
|
||||
|
||||
static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
|
||||
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
|
||||
return { block_index * (traits::qk / traits::qr), 0 };
|
||||
}
|
||||
|
||||
static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
|
||||
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
|
||||
auto nblocks = (nrows * (ncols / traits::qk));
|
||||
return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
|
||||
return { nblocks * (QK_K / 2),
|
||||
(nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) };
|
||||
}
|
||||
|
||||
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||
|
||||
constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
|
||||
|
||||
constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
|
||||
};
|
||||
|
||||
template <> struct block_q_t<GGML_TYPE_Q6_K> {
|
||||
struct traits {
|
||||
static constexpr uint32_t qk = QK_K;
|
||||
static constexpr uint32_t qi = QI6_K;
|
||||
static constexpr uint32_t qr = QR6_K;
|
||||
static constexpr uint32_t vdr_mmvq = 1;
|
||||
};
|
||||
|
||||
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) {
|
||||
auto low_bits_index = block_index * (traits::qk / traits::qr);
|
||||
// the index of high bits it's after all low bits
|
||||
auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4));
|
||||
return { low_bits_index, high_bits_index };
|
||||
}
|
||||
|
||||
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
|
||||
auto nblocks = (nrows * (ncols / traits::qk));
|
||||
auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4);
|
||||
auto block_scales = total_qs_bytes + block_index * (QK_K / 16);
|
||||
auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16);
|
||||
return { block_scales, sb_scale };
|
||||
}
|
||||
|
||||
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||
};
|
||||
} // namespace ggml_sycl_reordered
|
||||
|
||||
#endif // GGML_SYCL_QUANTS_HPP
|
||||
|
||||
@@ -284,10 +284,11 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
|
||||
return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y());
|
||||
}
|
||||
|
||||
__dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
|
||||
const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int /* nblocks */) {
|
||||
const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
|
||||
const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
|
||||
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
|
||||
const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
|
||||
const sycl::half2 * q8_1_ds, const int & iqs) {
|
||||
const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset.first;
|
||||
const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset.first));
|
||||
int v[q4_0_traits::vdr_mmvq];
|
||||
int u[2 * q4_0_traits::vdr_mmvq];
|
||||
|
||||
@@ -346,15 +347,15 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
|
||||
using q4_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_K>;
|
||||
using q4_k_traits = typename q4_k_block::traits;
|
||||
|
||||
float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
|
||||
const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int nblocks) {
|
||||
const int ib = ibx_offset / (QK_K / 2);
|
||||
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
|
||||
const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
|
||||
const sycl::half2 * q8_1_ds, const int & iqs) {
|
||||
const int ib = ibx_offset.first / (QK_K / 2);
|
||||
|
||||
const uint8_t * base = static_cast<const uint8_t *>(vbq);
|
||||
const uint8_t * qs = base + ibx_offset;
|
||||
const int total_qs_bytes = nblocks * (QK_K / 2);
|
||||
const uint8_t * scs = base + total_qs_bytes + ib * K_SCALE_SIZE;
|
||||
const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset);
|
||||
const uint8_t * qs = base + ibx_offset.first;
|
||||
const uint8_t * scs = base + d_offset.first + ib * K_SCALE_SIZE;
|
||||
const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset.second);
|
||||
|
||||
const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
|
||||
const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
|
||||
@@ -395,6 +396,66 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
|
||||
}
|
||||
};
|
||||
|
||||
template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> {
|
||||
static constexpr ggml_type gtype = GGML_TYPE_Q6_K;
|
||||
|
||||
using q6_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q6_K>;
|
||||
using q6_k_traits = typename q6_k_block::traits;
|
||||
|
||||
__dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u,
|
||||
const int8_t * __restrict__ scales, const float d,
|
||||
const float * __restrict__ d8) {
|
||||
float sumf = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < QR6_K; ++i) {
|
||||
const int sc = scales[4 * i];
|
||||
|
||||
const int vil = (vl >> (4 * i)) & 0x0F0F0F0F;
|
||||
|
||||
const int vih = ((vh >> (4 * i)) << 4) & 0x30303030;
|
||||
|
||||
const int vi = dpct::vectorized_binary<sycl::char4>((vil | vih), 0x20202020,
|
||||
dpct::sub_sat()); // vi = (vil | vih) - 32
|
||||
|
||||
sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
|
||||
}
|
||||
|
||||
return d * sumf;
|
||||
}
|
||||
|
||||
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
|
||||
const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr, const sycl::half2 * q8_1_ds,
|
||||
const int iqs) {
|
||||
const int ib = ibx_offset.first / (QK_K / 2);
|
||||
|
||||
const uint8_t * base = static_cast<const uint8_t *>(vbq);
|
||||
const uint8_t * ql = base + ibx_offset.first;
|
||||
const uint8_t * qh = base + ibx_offset.second;
|
||||
const int8_t * scales = reinterpret_cast<const int8_t *>(base + d_offset.first);
|
||||
const ggml_half * d = (const ggml_half *) (base + d_offset.second) + ib;
|
||||
|
||||
const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 4);
|
||||
const int scale_offset = (QI6_K / 4) * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 8);
|
||||
const int vh_shift = 2 * ((iqs % (QI6_K / 2)) / (QI6_K / 4));
|
||||
|
||||
const int vl = get_int_from_uint8(ql, iqs);
|
||||
const int vh = get_int_from_uint8(qh, (QI6_K / 4) * (iqs / (QI6_K / 2)) + iqs % (QI6_K / 4)) >> vh_shift;
|
||||
|
||||
const int8_t * scs = scales + scale_offset;
|
||||
|
||||
int u[QR6_K];
|
||||
float d8[QR6_K];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < QR6_K; ++i) {
|
||||
u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1);
|
||||
const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i);
|
||||
d8[i] = ds_values[0];
|
||||
}
|
||||
return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8);
|
||||
}
|
||||
};
|
||||
#define VDR_Q4_0_Q8_1_MMVQ 2
|
||||
#define VDR_Q4_0_Q8_1_MMQ 4
|
||||
|
||||
|
||||
@@ -196,6 +196,7 @@ enum vk_device_architecture {
|
||||
AMD_RDNA1,
|
||||
AMD_RDNA2,
|
||||
AMD_RDNA3,
|
||||
INTEL_XE2,
|
||||
};
|
||||
|
||||
static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
|
||||
@@ -246,6 +247,34 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
|
||||
}
|
||||
return vk_device_architecture::AMD_RDNA2;
|
||||
}
|
||||
} else if (props.vendorID == VK_VENDOR_ID_INTEL) {
|
||||
const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
|
||||
|
||||
bool subgroup_size_control = false;
|
||||
|
||||
for (const auto& properties : ext_props) {
|
||||
if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
|
||||
subgroup_size_control = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!subgroup_size_control) {
|
||||
return vk_device_architecture::OTHER;
|
||||
}
|
||||
|
||||
vk::PhysicalDeviceProperties2 props2;
|
||||
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
|
||||
|
||||
props2.pNext = &subgroup_size_control_props;
|
||||
device.getProperties2(&props2);
|
||||
|
||||
if (subgroup_size_control_props.minSubgroupSize == 16) {
|
||||
// Xe2 architecture uses SIMD16 while previous Xe and Gen architecture uses SIMD8.
|
||||
// Minimum subgroup size matches the SIMD width so we distinguish architecture by checking this value.
|
||||
// https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html
|
||||
// https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
|
||||
return vk_device_architecture::INTEL_XE2;
|
||||
}
|
||||
}
|
||||
return vk_device_architecture::OTHER;
|
||||
}
|
||||
@@ -396,6 +425,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_count_equal_i32;
|
||||
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
|
||||
vk_pipeline pipeline_timestep_embedding_f32;
|
||||
vk_pipeline pipeline_conv_transpose_1d_f32;
|
||||
vk_pipeline pipeline_pool2d_f32;
|
||||
vk_pipeline pipeline_rwkv_wkv6_f32;
|
||||
vk_pipeline pipeline_rwkv_wkv7_f32;
|
||||
@@ -444,7 +474,7 @@ struct vk_device_struct {
|
||||
// for GGML_VK_PERF_LOGGER
|
||||
std::unique_ptr<vk_perf_logger> perf_logger;
|
||||
vk::QueryPool query_pool;
|
||||
uint32_t num_queries;
|
||||
int32_t num_queries;
|
||||
|
||||
~vk_device_struct() {
|
||||
VK_LOG_DEBUG("destroy device " << name);
|
||||
@@ -706,6 +736,21 @@ struct vk_op_timestep_embedding_push_constants {
|
||||
uint32_t max_period;
|
||||
};
|
||||
|
||||
struct vk_op_conv_transpose_1d_push_constants {
|
||||
uint32_t Cout;
|
||||
uint32_t Cin;
|
||||
uint32_t K;
|
||||
uint32_t L;
|
||||
uint32_t KL;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb11;
|
||||
uint32_t nb1;
|
||||
|
||||
int32_t s0;
|
||||
};
|
||||
|
||||
struct vk_op_pool2d_push_constants {
|
||||
uint32_t IW; uint32_t IH;
|
||||
uint32_t OW; uint32_t OH;
|
||||
@@ -2726,6 +2771,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
||||
@@ -4061,7 +4108,33 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo
|
||||
return s;
|
||||
}
|
||||
|
||||
static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
|
||||
template <typename T> size_t push_constant_size(const T &t) {
|
||||
static_assert(std::is_class<T>::value, "T must be a struct/class");
|
||||
GGML_UNUSED(t);
|
||||
return sizeof(T);
|
||||
}
|
||||
template <typename T> size_t push_constant_size(const std::vector<T> &t) {
|
||||
GGML_UNUSED(t);
|
||||
return sizeof(T) * t.size();
|
||||
}
|
||||
template <typename T, uint32_t N> size_t push_constant_size(const std::array<T, N> &t) {
|
||||
GGML_UNUSED(t);
|
||||
return sizeof(T) * N;
|
||||
}
|
||||
|
||||
template <typename T> const T *push_constant_data(const T &t) {
|
||||
static_assert(std::is_class<T>::value, "T must be a struct/class");
|
||||
return &t;
|
||||
}
|
||||
template <typename T> const T *push_constant_data(const std::vector<T> &t) {
|
||||
return t.data();
|
||||
}
|
||||
template <typename T, uint32_t N> const T *push_constant_data(const std::array<T, N> &t) {
|
||||
return t.data();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, const T &push_constants, std::array<uint32_t, 3> elements) {
|
||||
const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
|
||||
const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
|
||||
const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
|
||||
@@ -4077,7 +4150,7 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
|
||||
vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
|
||||
ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {});
|
||||
|
||||
subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants);
|
||||
subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants));
|
||||
subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
|
||||
subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
|
||||
pipeline->layout,
|
||||
@@ -4540,7 +4613,7 @@ static void ggml_vk_matmul(
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
if (split_k == 1) {
|
||||
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4548,10 +4621,10 @@ static void ggml_vk_matmul(
|
||||
|
||||
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
|
||||
// Make sure enough workgroups get assigned for split k to work
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
|
||||
}
|
||||
|
||||
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
|
||||
@@ -4599,7 +4672,7 @@ static void ggml_vk_matmul_id(
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
|
||||
nei0, nei1, nbi1, ne11, padded_n };
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as });
|
||||
}
|
||||
|
||||
static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
|
||||
@@ -4720,7 +4793,7 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
|
||||
};
|
||||
init_pushconst_fastdiv(pc);
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
|
||||
}
|
||||
|
||||
static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
|
||||
@@ -4739,7 +4812,7 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
|
||||
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 });
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
@@ -4939,7 +5012,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
} else if (qx_needs_dequant) {
|
||||
const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
|
||||
}
|
||||
if (y_non_contig) {
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
||||
@@ -5155,7 +5228,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
|
||||
{ vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
|
||||
sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
|
||||
pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
@@ -5243,7 +5316,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
|
||||
}
|
||||
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z });
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, workgroups_z });
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
@@ -5326,7 +5399,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
|
||||
const std::array<uint32_t, 9> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
|
||||
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
|
||||
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
@@ -5542,7 +5615,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
|
||||
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
|
||||
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
|
||||
}
|
||||
if (y_non_contig) {
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
||||
@@ -5762,7 +5835,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
|
||||
{ vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 },
|
||||
vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } },
|
||||
sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z });
|
||||
pc, { groups_x, (uint32_t)nei0, groups_z });
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
|
||||
@@ -6112,7 +6185,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
// there's no more than one tile of rows (i.e. workgroups_x would have been
|
||||
// one). We reuse workgroups_x to mean the number of splits, so we need to
|
||||
// cancel out the divide by wg_denoms[0].
|
||||
sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
|
||||
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
|
||||
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
|
||||
@@ -6121,7 +6194,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
|
||||
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
|
||||
},
|
||||
pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 });
|
||||
pc2, { (uint32_t)ne1, 1, 1 });
|
||||
} else {
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||
{
|
||||
@@ -6131,7 +6204,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
|
||||
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
|
||||
},
|
||||
sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
|
||||
pc, { workgroups_x, workgroups_y, workgroups_z });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6392,6 +6465,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_timestep_embedding_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_conv_transpose_1d_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_POOL_2D:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_pool2d_f32;
|
||||
@@ -6726,6 +6804,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
uint32_t half_ceil = (dim + 1) / 2;
|
||||
elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
|
||||
} break;
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
{
|
||||
elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1}
|
||||
} break;
|
||||
case GGML_OP_POOL_2D:
|
||||
{
|
||||
const uint32_t N = dst->ne[3];
|
||||
@@ -6800,7 +6882,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
}
|
||||
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
||||
} else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
|
||||
// Empty src2 is possible in rope, but the shader needs a buffer
|
||||
vk_subbuffer subbuf_z;
|
||||
@@ -6811,26 +6893,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
}
|
||||
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
||||
} else if (op == GGML_OP_IM2COL) {
|
||||
// im2col uses only src1 and dst buffers
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
||||
} else if (op == GGML_OP_COUNT_EQUAL) {
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
// count_equal assumes that destination buffer is initialized with zeroes
|
||||
ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
||||
} else if (use_src2) {
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
||||
} else if (use_src1) {
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
||||
} else {
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6999,7 +7081,7 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
|
||||
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
|
||||
vk_subbuffer{ d_D, dst_offset, dst_size }
|
||||
}, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
|
||||
}, pc, elements);
|
||||
} else if (version == 7) {
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
|
||||
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
|
||||
@@ -7010,7 +7092,7 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
|
||||
vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
|
||||
vk_subbuffer{ d_D, dst_offset, dst_size }
|
||||
}, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
|
||||
}, pc, elements);
|
||||
} else {
|
||||
// shouldn't happen
|
||||
GGML_ASSERT(false);
|
||||
@@ -7147,7 +7229,7 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont
|
||||
vk_subbuffer{ d_GM, gm_offset, gm_size },
|
||||
vk_subbuffer{ d_GV, gv_offset, gv_size },
|
||||
vk_subbuffer{ d_P, p_offset, p_size },
|
||||
}, sizeof(vk_op_push_constants), &pc, elements);
|
||||
}, pc, elements);
|
||||
}
|
||||
|
||||
static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
|
||||
@@ -7529,6 +7611,37 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context
|
||||
}, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
// src0: (K, Cout, Cin, 1) -- kernel
|
||||
// src1: (L, Cin, 1, 1) -- input
|
||||
// dst: (*, Cout, 1, 1)
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT(nb00 == sizeof(float));
|
||||
GGML_ASSERT(nb10 == sizeof(float));
|
||||
|
||||
const int32_t s0 = dst->op_params[0];
|
||||
|
||||
vk_op_conv_transpose_1d_push_constants p{};
|
||||
p.Cout = static_cast<uint32_t>(ne01);
|
||||
p.Cin = static_cast<uint32_t>(ne02);
|
||||
p.K = static_cast<uint32_t>(ne00);
|
||||
p.L = static_cast<uint32_t>(ne10);
|
||||
p.KL = static_cast<uint32_t>(ne0);
|
||||
p.nb01 = static_cast<uint32_t>(nb01 / nb00);
|
||||
p.nb02 = static_cast<uint32_t>(nb02 / nb00);
|
||||
p.nb11 = static_cast<uint32_t>(nb11 / nb10);
|
||||
p.nb1 = static_cast<uint32_t>(nb1 / nb0);
|
||||
p.s0 = static_cast<uint32_t>(s0);
|
||||
|
||||
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
|
||||
const int32_t k1 = dst->op_params[1];
|
||||
@@ -8005,7 +8118,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
|
||||
vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
|
||||
ggml_vk_ctx_begin(ctx->device, subctx);
|
||||
const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc, { (uint32_t)ne, 1, 1});
|
||||
ggml_vk_ctx_end(subctx);
|
||||
|
||||
auto begin = std::chrono::high_resolution_clock::now();
|
||||
@@ -8600,6 +8713,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
@@ -8664,6 +8778,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
@@ -8835,6 +8950,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
|
||||
|
||||
break;
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node, dryrun);
|
||||
|
||||
break;
|
||||
case GGML_OP_POOL_2D:
|
||||
ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
|
||||
@@ -8963,6 +9082,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
@@ -9513,8 +9633,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
if (ctx->device->query_pool) {
|
||||
ctx->device->device.destroyQueryPool(ctx->device->query_pool);
|
||||
}
|
||||
VkQueryPoolCreateInfo query_create_info = { VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO };
|
||||
query_create_info.queryType = VK_QUERY_TYPE_TIMESTAMP;
|
||||
vk::QueryPoolCreateInfo query_create_info;
|
||||
query_create_info.queryType = vk::QueryType::eTimestamp;
|
||||
query_create_info.queryCount = cgraph->n_nodes + 100;
|
||||
ctx->device->query_pool = ctx->device->device.createQueryPool(query_create_info);
|
||||
ctx->device->num_queries = query_create_info.queryCount;
|
||||
@@ -9600,7 +9720,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
|
||||
// Get the results and pass them to the logger
|
||||
std::vector<uint64_t> timestamps(cgraph->n_nodes + 1);
|
||||
ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait);
|
||||
VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results");
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
if (!ggml_vk_is_empty(cgraph->nodes[i])) {
|
||||
ctx->device->perf_logger->log_timing(cgraph->nodes[i], uint64_t((timestamps[i+1] - timestamps[i]) * ctx->device->properties.limits.timestampPeriod));
|
||||
@@ -10024,6 +10144,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
return true;
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -10170,8 +10292,9 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
|
||||
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
|
||||
switch (props.vendorID) {
|
||||
case VK_VENDOR_ID_INTEL:
|
||||
// Intel drivers don't support coopmat properly yet
|
||||
return false;
|
||||
// Only allowing Xe2 GPU at the moment since Xe2 GPU can gain significant performance boost,
|
||||
// while some older hardware (ex. Arc A770) has performance regressions
|
||||
return arch == vk_device_architecture::INTEL_XE2;
|
||||
case VK_VENDOR_ID_AMD:
|
||||
if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
|
||||
// Workaround for AMD proprietary driver reporting support on all GPUs
|
||||
@@ -10515,6 +10638,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
const int32_t dim = tensor->op_params[0];
|
||||
const int32_t max_period = tensor->op_params[1];
|
||||
tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);
|
||||
} else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){
|
||||
const int32_t s0 = tensor->op_params[0];
|
||||
const int32_t p0 = tensor->op_params[1];
|
||||
const int32_t d0 = tensor->op_params[2];
|
||||
tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0);
|
||||
} else if (tensor->op == GGML_OP_POOL_2D) {
|
||||
enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
|
||||
const int32_t k0 = tensor->op_params[1];
|
||||
|
||||
98
ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp
Normal file
98
ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp
Normal file
@@ -0,0 +1,98 @@
|
||||
#version 450
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin]
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin]
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout]
|
||||
|
||||
layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint32_t Cout;
|
||||
uint32_t Cin;
|
||||
uint32_t K;
|
||||
uint32_t L;
|
||||
uint32_t KL;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb11;
|
||||
uint32_t nb1;
|
||||
|
||||
int32_t s0;
|
||||
} p;
|
||||
|
||||
|
||||
uint32_t Cout_idx = gl_WorkGroupID.x;
|
||||
const uint32_t bs = gl_WorkGroupSize.x;
|
||||
uint32_t tid = gl_LocalInvocationID.x;
|
||||
// Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K.
|
||||
uint32_t tmp_len = bs*p.s0+p.K;
|
||||
shared D_TYPE tmp[4096];
|
||||
|
||||
uint splitWork(uint workSize){
|
||||
return (bs + workSize -1) / bs;
|
||||
}
|
||||
|
||||
void main(){
|
||||
for(uint32_t i = 0; i < splitWork(tmp_len); i++){
|
||||
uint32_t idx = i*bs+tid;
|
||||
if(idx < tmp_len){
|
||||
tmp[idx] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t L_blocks = splitWork(p.L);
|
||||
for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){
|
||||
if(L_block_id > 0){
|
||||
barrier();
|
||||
// Shift values in tmp to the current processing window
|
||||
for(int i = 0; i < splitWork(tmp_len); i++){
|
||||
uint32_t idx = i*bs+tid;
|
||||
if(idx >= bs*p.s0 && idx < tmp_len){
|
||||
tmp[idx-bs*p.s0] = tmp[idx];
|
||||
tmp[idx] = 0.0;
|
||||
}else if(idx >= p.K && idx < bs*p.s0){
|
||||
tmp[idx] = 0.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
// Save contributions of the block to tmp
|
||||
uint32_t L_idx = L_block_id*bs + tid;
|
||||
for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){
|
||||
D_TYPE dp = 0.0;
|
||||
for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){
|
||||
A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02];
|
||||
if(L_idx < p.L){
|
||||
B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11];
|
||||
dp = fma(elemKrn, elemInp, dp);
|
||||
}
|
||||
}
|
||||
tmp[tid*p.s0 + K_idx] += dp;
|
||||
barrier();
|
||||
}
|
||||
|
||||
// Save the computed values except the last block that can have different size
|
||||
uint32_t KLb_idx = L_block_id*bs*p.s0;
|
||||
if(L_block_id < L_blocks-1){
|
||||
for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){
|
||||
uint32_t sh_idx = p.s0*tid+s0_idx;
|
||||
uint32_t KL_idx = KLb_idx+sh_idx;
|
||||
if(KL_idx < p.KL){
|
||||
data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(uint32_t i = 0; i < splitWork(tmp_len); i++){
|
||||
uint32_t idx = i*bs+tid;
|
||||
uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx;
|
||||
if(KL_idx < p.KL){
|
||||
data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -622,6 +622,8 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
@@ -935,6 +935,9 @@ class GGUFWriter:
|
||||
def add_eom_token_id(self, id: int) -> None:
|
||||
self.add_uint32(Keys.Tokenizer.EOM_ID, id)
|
||||
|
||||
def add_classifier_output_labels(self, labels: Sequence[str]) -> None:
|
||||
self.add_array(Keys.Classifier.OUTPUT_LABELS.format(arch=self.arch), labels)
|
||||
|
||||
# for vision models
|
||||
|
||||
def add_clip_has_vision_encoder(self, value: bool) -> None:
|
||||
|
||||
153
include/llama.h
153
include/llama.h
@@ -61,7 +61,10 @@ extern "C" {
|
||||
struct llama_model;
|
||||
struct llama_context;
|
||||
struct llama_sampler;
|
||||
struct llama_kv_cache;
|
||||
|
||||
typedef struct llama_memory_i * llama_memory_t;
|
||||
|
||||
struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
|
||||
|
||||
typedef int32_t llama_pos;
|
||||
typedef int32_t llama_token;
|
||||
@@ -493,9 +496,11 @@ extern "C" {
|
||||
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
|
||||
|
||||
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
|
||||
LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx);
|
||||
LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
|
||||
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
|
||||
|
||||
DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
|
||||
|
||||
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
|
||||
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
|
||||
|
||||
@@ -509,6 +514,13 @@ extern "C" {
|
||||
// Get the model's RoPE frequency scaling factor
|
||||
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
|
||||
|
||||
// Returns the number of classifier outputs (only valid for classifier models)
|
||||
// Undefined behavior for non-classifier models
|
||||
LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model);
|
||||
|
||||
// Returns label of classifier output by index (<n_cls_out). Returns nullptr if no label provided
|
||||
LLAMA_API const char * llama_model_cls_label(const struct llama_model * model, uint32_t i);
|
||||
|
||||
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
|
||||
|
||||
LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
|
||||
@@ -609,7 +621,81 @@ extern "C" {
|
||||
int32_t il_end);
|
||||
|
||||
//
|
||||
// KV cache
|
||||
// Memory
|
||||
//
|
||||
|
||||
// Clear the memory contents
|
||||
// If data == true, the data buffers will also be cleared together with the metadata
|
||||
LLAMA_API void llama_memory_clear(
|
||||
llama_memory_t mem,
|
||||
bool data);
|
||||
|
||||
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
|
||||
// seq_id < 0 : match any sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API bool llama_memory_seq_rm(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1);
|
||||
|
||||
// Copy all tokens that belong to the specified sequence to another sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_memory_seq_cp(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1);
|
||||
|
||||
// Removes all tokens that do not belong to the specified sequence
|
||||
LLAMA_API void llama_memory_seq_keep(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_memory_seq_add(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta);
|
||||
|
||||
// Integer division of the positions by factor of `d > 1`
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_memory_seq_div(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d);
|
||||
|
||||
// Returns the smallest position present in the memory for the specified sequence
|
||||
// This is typically non-zero only for SWA caches
|
||||
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
|
||||
// Return -1 if the sequence is empty
|
||||
LLAMA_API llama_pos llama_memory_seq_pos_min(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Returns the largest position present in the memory for the specified sequence
|
||||
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
|
||||
// Return -1 if the sequence is empty
|
||||
LLAMA_API llama_pos llama_memory_seq_pos_max(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Check if the memory supports shifting
|
||||
LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
|
||||
|
||||
//
|
||||
// KV cache for self-attention (TODO: deprecate in favor of llama_memory)
|
||||
//
|
||||
|
||||
// Returns the number of tokens in the KV cache (slow, use only for debug)
|
||||
@@ -622,86 +708,95 @@ extern "C" {
|
||||
"Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
|
||||
|
||||
// Clear the KV cache - both cell info is erased and KV data is zeroed
|
||||
LLAMA_API void llama_kv_self_clear(
|
||||
struct llama_context * ctx);
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_clear(
|
||||
struct llama_context * ctx),
|
||||
"Use llama_memory_clear() instead");
|
||||
|
||||
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
|
||||
// seq_id < 0 : match any sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API bool llama_kv_self_seq_rm(
|
||||
DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1);
|
||||
llama_pos p1),
|
||||
"Use llama_memory_seq_rm() instead");
|
||||
|
||||
// Copy all tokens that belong to the specified sequence to another sequence
|
||||
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_kv_self_seq_cp(
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_seq_cp(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1);
|
||||
llama_pos p1),
|
||||
"Use llama_memory_seq_cp() instead");
|
||||
|
||||
// Removes all tokens that do not belong to the specified sequence
|
||||
LLAMA_API void llama_kv_self_seq_keep(
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_seq_keep(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
llama_seq_id seq_id),
|
||||
"Use llama_memory_seq_keep() instead");
|
||||
|
||||
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||
// - lazily on next llama_decode()
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_kv_self_seq_add(
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_seq_add(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta);
|
||||
llama_pos delta),
|
||||
"Use llama_memory_seq_add() instead");
|
||||
|
||||
// Integer division of the positions by factor of `d > 1`
|
||||
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||
// - lazily on next llama_decode()
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_kv_self_seq_div(
|
||||
DEPRECATED(void llama_kv_self_seq_div(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d);
|
||||
int d),
|
||||
"Use llama_memory_seq_div() instead");
|
||||
|
||||
// Returns the smallest position present in the KV cache for the specified sequence
|
||||
// This is typically non-zero only for SWA caches
|
||||
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
|
||||
// Return -1 if the sequence is empty
|
||||
LLAMA_API llama_pos llama_kv_self_seq_pos_min(
|
||||
DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
llama_seq_id seq_id),
|
||||
"Use llama_memory_seq_pos_min() instead");
|
||||
|
||||
// Returns the largest position present in the KV cache for the specified sequence
|
||||
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
|
||||
// Return -1 if the sequence is empty
|
||||
LLAMA_API llama_pos llama_kv_self_seq_pos_max(
|
||||
DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
llama_seq_id seq_id),
|
||||
"Use llama_memory_seq_pos_max() instead");
|
||||
|
||||
// Defragment the KV cache
|
||||
// This will be applied:
|
||||
// - lazily on next llama_decode()
|
||||
LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx),
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
|
||||
"simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
|
||||
|
||||
// Check if the context supports KV cache shifting
|
||||
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
|
||||
DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx),
|
||||
"use llama_memory_can_shift() instead");
|
||||
|
||||
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
||||
LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx),
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
|
||||
"simply remove this call, updates are applied lazily on the next llama_decode()");
|
||||
|
||||
//
|
||||
@@ -709,7 +804,7 @@ extern "C" {
|
||||
//
|
||||
|
||||
// Returns the *actual* size in bytes of the state
|
||||
// (logits, embedding and kv_cache)
|
||||
// (logits, embedding and memory)
|
||||
// Only use when saving the state, not when restoring it, otherwise the size may be too small.
|
||||
LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
|
||||
LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
|
||||
@@ -765,12 +860,12 @@ extern "C" {
|
||||
size_t n_token_count),
|
||||
"use llama_state_save_file instead");
|
||||
|
||||
// Get the exact size needed to copy the KV cache of a single sequence
|
||||
// Get the exact size needed to copy the state of a single sequence
|
||||
LLAMA_API size_t llama_state_seq_get_size(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Copy the KV cache of a single sequence into the specified buffer
|
||||
// Copy the state of a single sequence into the specified buffer
|
||||
LLAMA_API size_t llama_state_seq_get_data(
|
||||
struct llama_context * ctx,
|
||||
uint8_t * dst,
|
||||
@@ -836,16 +931,16 @@ extern "C" {
|
||||
// For encode-decoder contexts, processes the batch using the encoder.
|
||||
// Can store the encoder output internally for later use by the decoder's cross-attention layers.
|
||||
// 0 - success
|
||||
// < 0 - error. the KV cache state is restored to the state before this call
|
||||
// < 0 - error. the memory state is restored to the state before this call
|
||||
LLAMA_API int32_t llama_encode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch);
|
||||
|
||||
// Process a batch of tokens.
|
||||
// Requires KV cache.
|
||||
// Requires the context to have a memory.
|
||||
// For encode-decoder contexts, processes the batch using the decoder.
|
||||
// Positive return values does not mean a fatal error, but rather a warning.
|
||||
// Upon non-zero return values, the KV cache state is restored to the state before this call
|
||||
// Upon non-zero return values, the memory state is restored to the state before this call
|
||||
// 0 - success
|
||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||
// 2 - aborted
|
||||
@@ -916,7 +1011,7 @@ extern "C" {
|
||||
|
||||
// Get the embeddings for a sequence id
|
||||
// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
|
||||
// when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
|
||||
// when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out] with the rank(s) of the sequence
|
||||
// otherwise: float[n_embd] (1-dimensional)
|
||||
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ add_library(llama
|
||||
llama-hparams.cpp
|
||||
llama-impl.cpp
|
||||
llama-io.cpp
|
||||
llama-kv-cache.cpp
|
||||
llama-kv-cache-unified.cpp
|
||||
llama-kv-cache-unified-iswa.cpp
|
||||
llama-kv-cache-recurrent.cpp
|
||||
|
||||
@@ -200,7 +200,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
|
||||
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
|
||||
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" },
|
||||
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" },
|
||||
{ LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
|
||||
@@ -1707,8 +1706,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
|
||||
|
||||
std::string LLM_KV::operator()(llm_kv kv) const {
|
||||
return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
|
||||
: ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
|
||||
std::string name = ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
|
||||
|
||||
if (suffix != nullptr) {
|
||||
name += ".";
|
||||
name += suffix;
|
||||
}
|
||||
|
||||
return name;
|
||||
}
|
||||
|
||||
std::string LLM_TN_IMPL::str() const {
|
||||
|
||||
@@ -196,7 +196,6 @@ enum llm_kv {
|
||||
LLM_KV_TOKENIZER_HF_JSON,
|
||||
LLM_KV_TOKENIZER_RWKV,
|
||||
LLM_KV_TOKENIZER_CHAT_TEMPLATE,
|
||||
LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
|
||||
LLM_KV_TOKENIZER_FIM_PRE_ID,
|
||||
LLM_KV_TOKENIZER_FIM_SUF_ID,
|
||||
LLM_KV_TOKENIZER_FIM_MID_ID,
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-io.h"
|
||||
#include "llama-memory.h"
|
||||
#include "llama-mmap.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-kv-cache.h"
|
||||
|
||||
#include <cinttypes>
|
||||
#include <cstring>
|
||||
@@ -123,7 +123,7 @@ llama_context::llama_context(
|
||||
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
||||
}
|
||||
|
||||
if (!params.swa_full && cparams.n_seq_max > 1) {
|
||||
if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
|
||||
LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
|
||||
__func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
|
||||
}
|
||||
@@ -277,10 +277,9 @@ llama_context::llama_context(
|
||||
int n_nodes_tg = -1;
|
||||
|
||||
// simulate full KV cache
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
const auto kv_state = kv_self->init_full();
|
||||
if (!kv_state) {
|
||||
const auto mstate = memory->init_full();
|
||||
if (!mstate) {
|
||||
throw std::runtime_error("failed to initialize KV cache");
|
||||
}
|
||||
|
||||
@@ -288,7 +287,7 @@ llama_context::llama_context(
|
||||
|
||||
// reserve pp graph first so that buffers are only allocated once
|
||||
{
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
}
|
||||
@@ -299,7 +298,7 @@ llama_context::llama_context(
|
||||
|
||||
// reserve with tg graph to get the number of splits and nodes
|
||||
{
|
||||
auto * gf = graph_reserve(1, 1, 1, kv_state.get());
|
||||
auto * gf = graph_reserve(1, 1, 1, mstate.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute tg buffers");
|
||||
}
|
||||
@@ -310,7 +309,7 @@ llama_context::llama_context(
|
||||
|
||||
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
||||
{
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
}
|
||||
@@ -419,40 +418,68 @@ uint32_t llama_context::n_threads_batch() const {
|
||||
return cparams.n_threads_batch;
|
||||
}
|
||||
|
||||
llama_kv_cache * llama_context::get_kv_self() {
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
return kv_self;
|
||||
llama_memory_t llama_context::get_memory() const {
|
||||
return memory.get();
|
||||
}
|
||||
|
||||
const llama_kv_cache * llama_context::get_kv_self() const {
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
return kv_self;
|
||||
// deprecated
|
||||
void llama_context::kv_self_defrag_sched() {
|
||||
if (!memory) {
|
||||
return;
|
||||
}
|
||||
|
||||
memory_force_optimize = true;
|
||||
}
|
||||
|
||||
bool llama_context::kv_self_update() {
|
||||
// deprecated
|
||||
bool llama_context::kv_self_update(bool optimize) {
|
||||
if (!memory) {
|
||||
return false;
|
||||
}
|
||||
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
{
|
||||
// TODO: remove in the future
|
||||
optimize |= memory_force_optimize;
|
||||
memory_force_optimize = false;
|
||||
|
||||
if (!kv_self->update(*this)) {
|
||||
// no updates have been performed
|
||||
return false;
|
||||
const auto mstate = memory->init_update(this, optimize);
|
||||
switch (mstate->get_status()) {
|
||||
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||
{
|
||||
// noop
|
||||
} break;
|
||||
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
||||
{
|
||||
// no updates need to be performed
|
||||
return false;
|
||||
}
|
||||
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
||||
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
||||
{
|
||||
LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!mstate->apply()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
// if the KV cache did any computation, we have to reserve a new worst-case graph
|
||||
const auto kv_state = kv_self->init_full();
|
||||
if (!kv_state) {
|
||||
throw std::runtime_error("failed to initialize KV cache");
|
||||
}
|
||||
// if the memory module did any computation, we have to reserve a new worst-case graph
|
||||
{
|
||||
const auto mstate = memory->init_full();
|
||||
if (!mstate) {
|
||||
throw std::runtime_error("failed to initialize memory state");
|
||||
}
|
||||
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
||||
if (!gf) {
|
||||
LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
||||
if (!gf) {
|
||||
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -814,16 +841,17 @@ int llama_context::encode(llama_batch & inp_batch) {
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_RANK:
|
||||
{
|
||||
// extract the rerank score - a single float per sequence
|
||||
// extract the rerank score - n_cls_out floats per sequence
|
||||
auto & embd_seq_out = embd_seq;
|
||||
const uint32_t n_cls_out = hparams.n_cls_out;
|
||||
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
||||
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
||||
continue;
|
||||
}
|
||||
embd_seq_out[seq_id].resize(1);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
|
||||
embd_seq_out[seq_id].resize(n_cls_out);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||
@@ -880,10 +908,8 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
}
|
||||
}
|
||||
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
|
||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
|
||||
|
||||
const llama_batch & batch = batch_allocr.batch;
|
||||
|
||||
@@ -940,42 +966,49 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
n_outputs_all = 1;
|
||||
}
|
||||
|
||||
bool did_optimize = false;
|
||||
|
||||
// handle any pending defrags/shifts
|
||||
kv_self_update();
|
||||
kv_self_update(false);
|
||||
|
||||
llama_memory_state_ptr kv_state;
|
||||
|
||||
bool did_defrag = false;
|
||||
llama_memory_state_ptr mstate;
|
||||
|
||||
while (true) {
|
||||
kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
|
||||
if (!kv_state) {
|
||||
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
|
||||
if (!mstate) {
|
||||
return -2;
|
||||
}
|
||||
|
||||
switch (kv_state->get_status()) {
|
||||
switch (mstate->get_status()) {
|
||||
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||
{
|
||||
} break;
|
||||
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
||||
{
|
||||
LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
|
||||
|
||||
return -2;
|
||||
}
|
||||
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
||||
{
|
||||
if (!did_defrag) {
|
||||
did_defrag = true;
|
||||
if (!did_optimize) {
|
||||
did_optimize = true;
|
||||
|
||||
kv_self->defrag_sched(-1.0f);
|
||||
if (kv_self_update()) {
|
||||
LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
|
||||
if (kv_self_update(true)) {
|
||||
LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
|
||||
LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
|
||||
|
||||
return 1;
|
||||
}
|
||||
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
||||
{
|
||||
LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
|
||||
|
||||
return -2;
|
||||
}
|
||||
}
|
||||
@@ -992,7 +1025,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
int64_t n_outputs_prev = 0;
|
||||
|
||||
do {
|
||||
const auto & ubatch = kv_state->get_ubatch();
|
||||
const auto & ubatch = mstate->get_ubatch();
|
||||
|
||||
// count the outputs in this u_batch
|
||||
{
|
||||
@@ -1015,11 +1048,14 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||
|
||||
ggml_status status;
|
||||
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
|
||||
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
|
||||
|
||||
if (!res) {
|
||||
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
||||
llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
|
||||
llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
|
||||
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
||||
pos_min[s] = std::numeric_limits<llama_pos>::max();
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||
const auto & seq_id = ubatch.seq_id[i][0];
|
||||
@@ -1034,7 +1070,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
|
||||
LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
|
||||
|
||||
llama_kv_self_seq_rm(this, s, pos_min[s], -1);
|
||||
memory->seq_rm(s, pos_min[s], -1);
|
||||
}
|
||||
|
||||
switch (status) {
|
||||
@@ -1128,7 +1164,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
}
|
||||
|
||||
n_outputs_prev += n_outputs;
|
||||
} while (kv_state->next());
|
||||
} while (mstate->next());
|
||||
|
||||
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
||||
n_outputs = n_outputs_all;
|
||||
@@ -1137,7 +1173,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
{
|
||||
bool sorted_output = true;
|
||||
|
||||
auto & out_ids = kv_state->out_ids();
|
||||
auto & out_ids = mstate->out_ids();
|
||||
|
||||
GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
|
||||
|
||||
@@ -1189,11 +1225,6 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
// wait for the computation to finish (automatically done when obtaining the model output)
|
||||
//synchronize();
|
||||
|
||||
// decide if we need to defrag the kv cache
|
||||
if (cparams.defrag_thold > 0.0f) {
|
||||
kv_self->defrag_sched(cparams.defrag_thold);
|
||||
}
|
||||
|
||||
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
||||
// overlap with device computation.
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
@@ -1810,11 +1841,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
||||
}
|
||||
}
|
||||
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
if (kv_self != nullptr) {
|
||||
if (memory != nullptr) {
|
||||
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
||||
kv_self->state_write(io);
|
||||
memory->state_write(io);
|
||||
}
|
||||
|
||||
return io.n_bytes();
|
||||
@@ -1901,9 +1930,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
||||
if (memory) {
|
||||
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
|
||||
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->state_read(io);
|
||||
memory->state_read(io);
|
||||
}
|
||||
|
||||
return io.n_bytes();
|
||||
@@ -1913,9 +1940,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
|
||||
GGML_UNUSED(seq_id);
|
||||
|
||||
if (memory) {
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->state_write(io, seq_id);
|
||||
memory->state_write(io, seq_id);
|
||||
}
|
||||
|
||||
return io.n_bytes();
|
||||
@@ -1925,9 +1950,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
|
||||
GGML_UNUSED(seq_id);
|
||||
|
||||
if (memory) {
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->state_read(io, seq_id);
|
||||
memory->state_read(io, seq_id);
|
||||
}
|
||||
|
||||
return io.n_bytes();
|
||||
@@ -2032,9 +2055,7 @@ void llama_context::opt_epoch_iter(
|
||||
const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
|
||||
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
|
||||
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->clear();
|
||||
memory->clear(true);
|
||||
|
||||
for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
|
||||
batch.n_tokens = n_batch;
|
||||
@@ -2057,8 +2078,8 @@ void llama_context::opt_epoch_iter(
|
||||
|
||||
int64_t n_outputs_all = n_tokens_all;
|
||||
|
||||
auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
|
||||
if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||
auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
|
||||
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
|
||||
break;
|
||||
}
|
||||
@@ -2071,17 +2092,17 @@ void llama_context::opt_epoch_iter(
|
||||
|
||||
uint32_t pos_batch = 0;
|
||||
do {
|
||||
const auto & ubatch = kv_state->get_ubatch();
|
||||
const auto & ubatch = mstate->get_ubatch();
|
||||
|
||||
n_outputs = ubatch.n_tokens;
|
||||
|
||||
if (!kv_state->apply()) {
|
||||
if (!mstate->apply()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
|
||||
break;
|
||||
}
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
|
||||
|
||||
struct ggml_context * ctx_compute_opt;
|
||||
{
|
||||
@@ -2116,7 +2137,7 @@ void llama_context::opt_epoch_iter(
|
||||
ggml_free(ctx_compute_opt);
|
||||
|
||||
pos_batch += ubatch.n_tokens;
|
||||
} while (kv_state->next());
|
||||
} while (mstate->next());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2277,13 +2298,14 @@ const llama_model * llama_get_model(const llama_context * ctx) {
|
||||
return &ctx->get_model();
|
||||
}
|
||||
|
||||
// deprecated
|
||||
llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
|
||||
return ctx->get_kv_self();
|
||||
return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_update(llama_context * ctx) {
|
||||
ctx->kv_self_update();
|
||||
ctx->kv_self_update(false);
|
||||
}
|
||||
|
||||
enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
|
||||
@@ -2398,13 +2420,118 @@ int32_t llama_apply_adapter_cvec(
|
||||
return res ? 0 : -1;
|
||||
}
|
||||
|
||||
//
|
||||
// memory
|
||||
//
|
||||
|
||||
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
|
||||
return ctx->get_memory();
|
||||
}
|
||||
|
||||
void llama_memory_clear(llama_memory_t mem, bool data) {
|
||||
if (!mem) {
|
||||
return;
|
||||
}
|
||||
|
||||
mem->clear(data);
|
||||
}
|
||||
|
||||
bool llama_memory_seq_rm(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
if (!mem) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return mem->seq_rm(seq_id, p0, p1);
|
||||
}
|
||||
|
||||
void llama_memory_seq_cp(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
if (!mem) {
|
||||
return;
|
||||
}
|
||||
|
||||
mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
void llama_memory_seq_keep(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id) {
|
||||
if (!mem) {
|
||||
return;
|
||||
}
|
||||
|
||||
mem->seq_keep(seq_id);
|
||||
}
|
||||
|
||||
void llama_memory_seq_add(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta) {
|
||||
if (!mem) {
|
||||
return;
|
||||
}
|
||||
|
||||
mem->seq_add(seq_id, p0, p1, delta);
|
||||
}
|
||||
|
||||
void llama_memory_seq_div(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d) {
|
||||
if (!mem) {
|
||||
return;
|
||||
}
|
||||
|
||||
mem->seq_div(seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
llama_pos llama_memory_seq_pos_min(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id) {
|
||||
if (!mem) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return mem->seq_pos_min(seq_id);
|
||||
}
|
||||
|
||||
llama_pos llama_memory_seq_pos_max(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id) {
|
||||
if (!mem) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return mem->seq_pos_max(seq_id);
|
||||
}
|
||||
|
||||
bool llama_memory_can_shift(llama_memory_t mem) {
|
||||
if (!mem) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return mem->get_can_shift();
|
||||
}
|
||||
|
||||
//
|
||||
// kv cache
|
||||
//
|
||||
|
||||
// deprecated
|
||||
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
||||
const auto * kv = ctx->get_kv_self();
|
||||
const auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return 0;
|
||||
}
|
||||
@@ -2426,7 +2553,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
||||
// deprecated
|
||||
// note: this is the same as above - will be removed anyway, so it's ok
|
||||
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
||||
const auto * kv = ctx->get_kv_self();
|
||||
const auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return 0;
|
||||
}
|
||||
@@ -2445,115 +2572,119 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
||||
return res;
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_clear(llama_context * ctx) {
|
||||
auto * kv = ctx->get_kv_self();
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
kv->clear();
|
||||
llama_memory_clear(kv, true);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
bool llama_kv_self_seq_rm(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
auto * kv = ctx->get_kv_self();
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return kv->seq_rm(seq_id, p0, p1);
|
||||
return llama_memory_seq_rm(kv, seq_id, p0, p1);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_seq_cp(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
auto * kv = ctx->get_kv_self();
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||
llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
|
||||
auto * kv = ctx->get_kv_self();
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
kv->seq_keep(seq_id);
|
||||
llama_memory_seq_keep(kv, seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_seq_add(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta) {
|
||||
auto * kv = ctx->get_kv_self();
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
kv->seq_add(seq_id, p0, p1, delta);
|
||||
llama_memory_seq_add(kv, seq_id, p0, p1, delta);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_seq_div(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d) {
|
||||
auto * kv = ctx->get_kv_self();
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
kv->seq_div(seq_id, p0, p1, d);
|
||||
llama_memory_seq_div(kv, seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
|
||||
const auto * kv = ctx->get_kv_self();
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return kv->seq_pos_min(seq_id);
|
||||
return llama_memory_seq_pos_min(kv, seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
||||
const auto * kv = ctx->get_kv_self();
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return kv->seq_pos_max(seq_id);
|
||||
return llama_memory_seq_pos_max(kv, seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_defrag(llama_context * ctx) {
|
||||
auto * kv = ctx->get_kv_self();
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
// force defrag
|
||||
kv->defrag_sched(-1.0f);
|
||||
ctx->kv_self_defrag_sched();
|
||||
}
|
||||
|
||||
// deprecated
|
||||
bool llama_kv_self_can_shift(const llama_context * ctx) {
|
||||
const auto * kv = ctx->get_kv_self();
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return kv->get_can_shift();
|
||||
return llama_memory_can_shift(kv);
|
||||
}
|
||||
|
||||
// llama state API
|
||||
|
||||
@@ -13,13 +13,12 @@
|
||||
#include <vector>
|
||||
|
||||
struct llama_model;
|
||||
struct llama_kv_cache;
|
||||
|
||||
class llama_io_read_i;
|
||||
class llama_io_write_i;
|
||||
|
||||
class llama_memory_i;
|
||||
class llama_memory_state_i;
|
||||
struct llama_memory_i;
|
||||
struct llama_memory_state_i;
|
||||
|
||||
struct llama_context {
|
||||
// init scheduler and compute buffers, reserve worst-case graphs
|
||||
@@ -47,12 +46,12 @@ struct llama_context {
|
||||
uint32_t n_threads() const;
|
||||
uint32_t n_threads_batch() const;
|
||||
|
||||
llama_kv_cache * get_kv_self();
|
||||
const llama_kv_cache * get_kv_self() const;
|
||||
llama_memory_t get_memory() const;
|
||||
|
||||
// return true of the KV cache was updated
|
||||
// TODO: remove
|
||||
bool kv_self_update();
|
||||
bool kv_self_update(bool optimize);
|
||||
void kv_self_defrag_sched();
|
||||
|
||||
enum llama_pooling_type pooling_type() const;
|
||||
|
||||
@@ -231,6 +230,9 @@ private:
|
||||
|
||||
std::unique_ptr<llama_memory_i> memory;
|
||||
|
||||
// TODO: temporary, until the llama_kv_self_defrag() API is removed
|
||||
bool memory_force_optimize = false;
|
||||
|
||||
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
||||
size_t logits_size = 0; // capacity (of floats) for logits
|
||||
float * logits = nullptr;
|
||||
|
||||
@@ -250,22 +250,6 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
|
||||
GGML_UNUSED(ubatch);
|
||||
|
||||
const int64_t n_kv = kv_state->get_n_kv();
|
||||
|
||||
if (s_mask) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
|
||||
float * data = (float *) s_mask->data;
|
||||
|
||||
// clear unused states
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
data[i] = kv_state->s_mask(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
|
||||
GGML_UNUSED(ubatch);
|
||||
|
||||
@@ -659,6 +643,20 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||
cur = ggml_mul(ctx0, x0, x1);
|
||||
cb(cur, "ffn_mul", il);
|
||||
} break;
|
||||
case LLM_FFN_GEGLU:
|
||||
{
|
||||
// Split into two equal parts
|
||||
int64_t split_point = cur->ne[0] / 2;
|
||||
// TODO: these conts should not be needed
|
||||
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
||||
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
||||
|
||||
x0 = ggml_gelu(ctx0, x0);
|
||||
cb(x0, "ffn_gelu", il);
|
||||
|
||||
cur = ggml_mul(ctx0, x0, x1);
|
||||
cb(cur, "ffn_geglu", il);
|
||||
} break;
|
||||
}
|
||||
|
||||
if (gate && type_gate == LLM_FFN_PAR) {
|
||||
@@ -769,9 +767,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
|
||||
|
||||
if (weight_before_ffn) {
|
||||
// TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
|
||||
ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
|
||||
repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
|
||||
// repeat cur to [n_embd, n_expert_used, n_tokens]
|
||||
ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
|
||||
cur = ggml_mul(ctx0, repeated, weights);
|
||||
cb(cur, "ffn_moe_weighted", il);
|
||||
}
|
||||
@@ -973,23 +970,6 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_s_mask() const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
|
||||
|
||||
const auto n_kv = kv_state->get_n_kv();
|
||||
|
||||
auto & cur = inp->s_mask;
|
||||
|
||||
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
|
||||
ggml_set_input(cur);
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
|
||||
auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
|
||||
|
||||
@@ -1442,43 +1422,53 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_copy_mask_state(
|
||||
ggml_tensor * llm_graph_context::build_recurrent_state(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
ggml_tensor * state_copy,
|
||||
ggml_tensor * state_mask,
|
||||
int32_t n_state,
|
||||
int32_t n_seqs) const {
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
bool avoid_copies) const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
|
||||
const auto n_kv = kv_state->get_n_kv();
|
||||
const auto kv_head = kv_state->get_head();
|
||||
const auto rs_zero = kv_state->get_rs_z();
|
||||
|
||||
ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
|
||||
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
|
||||
|
||||
// copy states
|
||||
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
|
||||
// this shrinks the tensors's ne[1] to n_kv
|
||||
states = ggml_get_rows(ctx0, states, state_copy);
|
||||
// Clear a single state which will then be copied to the other cleared states.
|
||||
// Note that this is a no-op when the view is zero-sized.
|
||||
ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
|
||||
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
|
||||
|
||||
// clear states of sequences which are starting at the beginning of this batch
|
||||
// FIXME: zero-out NANs?
|
||||
states = ggml_mul(ctx0, states, state_mask);
|
||||
ggml_tensor * output_states;
|
||||
|
||||
// copy states which won't be changed further (between n_seqs and n_kv)
|
||||
if (!avoid_copies) {
|
||||
// copy states
|
||||
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
|
||||
// {state_size, kv_size} -> {state_size, n_seqs}
|
||||
output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
|
||||
ggml_build_forward_expand(gf, output_states);
|
||||
} else {
|
||||
// FIXME: make the gathering operation happen before the copy below
|
||||
// (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
|
||||
output_states = states;
|
||||
}
|
||||
|
||||
// copy extra states which won't be changed further (between n_seqs and n_kv)
|
||||
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
|
||||
ggml_build_forward_expand(gf,
|
||||
ggml_cpy(ctx0,
|
||||
ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)),
|
||||
ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
|
||||
states_extra,
|
||||
ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
|
||||
|
||||
// the part of the states that will be used and modified
|
||||
return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
|
||||
return output_states;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * state_copy,
|
||||
ggml_tensor * state_mask,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
@@ -1489,8 +1479,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
||||
|
||||
ggml_tensor * token_shift_all = kv_state->get_k_l(il);
|
||||
|
||||
ggml_tensor * token_shift = build_copy_mask_state(
|
||||
gf, token_shift_all, state_copy, state_mask,
|
||||
ggml_tensor * token_shift = build_recurrent_state(
|
||||
gf, token_shift_all, state_copy,
|
||||
hparams.n_embd_k_s(), n_seqs);
|
||||
|
||||
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
|
||||
|
||||
@@ -17,7 +17,7 @@ struct ggml_tensor;
|
||||
struct llama_ubatch;
|
||||
struct llama_cparams;
|
||||
|
||||
class llama_memory_state_i;
|
||||
struct llama_memory_state_i;
|
||||
|
||||
class llama_kv_cache_unified_state;
|
||||
class llama_kv_cache_unified_iswa_state;
|
||||
@@ -36,6 +36,7 @@ enum llm_ffn_op_type {
|
||||
LLM_FFN_RELU,
|
||||
LLM_FFN_RELU_SQR,
|
||||
LLM_FFN_SWIGLU,
|
||||
LLM_FFN_GEGLU,
|
||||
};
|
||||
|
||||
enum llm_ffn_gate_type {
|
||||
@@ -199,18 +200,6 @@ public:
|
||||
const llama_kv_cache_recurrent_state * kv_state;
|
||||
};
|
||||
|
||||
class llm_graph_input_s_mask : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
|
||||
virtual ~llm_graph_input_s_mask() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * s_mask; // F32 [1, n_kv]
|
||||
|
||||
const llama_kv_cache_recurrent_state * kv_state;
|
||||
};
|
||||
|
||||
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_cross_embd(
|
||||
@@ -520,7 +509,6 @@ struct llm_graph_context {
|
||||
ggml_tensor * build_inp_mean() const;
|
||||
ggml_tensor * build_inp_cls() const;
|
||||
ggml_tensor * build_inp_s_copy() const;
|
||||
ggml_tensor * build_inp_s_mask() const;
|
||||
|
||||
ggml_tensor * build_inp_cross_embd() const;
|
||||
ggml_tensor * build_inp_pos_bucket_enc() const;
|
||||
@@ -605,18 +593,17 @@ struct llm_graph_context {
|
||||
// recurrent
|
||||
//
|
||||
|
||||
ggml_tensor * build_copy_mask_state(
|
||||
ggml_tensor * build_recurrent_state(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
ggml_tensor * state_copy,
|
||||
ggml_tensor * state_mask,
|
||||
int32_t n_state,
|
||||
int32_t n_seqs) const;
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
bool avoid_copies = false) const;
|
||||
|
||||
ggml_tensor * build_rwkv_token_shift_load(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * state_copy,
|
||||
ggml_tensor * state_mask,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const;
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "llama-kv-cache-recurrent.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-io.h"
|
||||
#include "llama-batch.h"
|
||||
#include "llama-model.h"
|
||||
|
||||
@@ -116,18 +117,21 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_recurrent::clear() {
|
||||
void llama_kv_cache_recurrent::clear(bool data) {
|
||||
for (int32_t i = 0; i < (int32_t) size; ++i) {
|
||||
cells[i].pos = -1;
|
||||
cells[i].seq_id.clear();
|
||||
cells[i].src = -1;
|
||||
cells[i].tail = -1;
|
||||
}
|
||||
|
||||
head = 0;
|
||||
used = 0;
|
||||
|
||||
for (auto & buf : bufs) {
|
||||
ggml_backend_buffer_clear(buf.get(), 0);
|
||||
if (data) {
|
||||
for (auto & buf : bufs) {
|
||||
ggml_backend_buffer_clear(buf.get(), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -386,6 +390,13 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
|
||||
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
|
||||
GGML_UNUSED(lctx);
|
||||
GGML_UNUSED(optimize);
|
||||
|
||||
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
|
||||
}
|
||||
|
||||
bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
|
||||
// simply remember the full state because it is very small for this type of cache
|
||||
// TODO: optimize
|
||||
@@ -395,21 +406,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
|
||||
|
||||
bool success = true;
|
||||
|
||||
// TODO: here we have to verify that all ubatches can fit in the cells
|
||||
// however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
|
||||
// during the compute of each ubatch. to reproduce, uncomment the following loop and run:
|
||||
//
|
||||
// $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
|
||||
//
|
||||
// recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
|
||||
//
|
||||
GGML_UNUSED(ubatches);
|
||||
//for (const auto & ubatch : ubatches) {
|
||||
// if (!find_slot(ubatch)) {
|
||||
// success = false;
|
||||
// break;
|
||||
// }
|
||||
//}
|
||||
for (const auto & ubatch : ubatches) {
|
||||
if (!find_slot(ubatch)) {
|
||||
success = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// restore the original state
|
||||
cells = std::move(org_cells);
|
||||
@@ -419,26 +421,14 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
|
||||
return success;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_recurrent::update(llama_context & lctx) {
|
||||
GGML_UNUSED(lctx);
|
||||
// noop
|
||||
return false;
|
||||
}
|
||||
|
||||
void llama_kv_cache_recurrent::defrag_sched(float thold) {
|
||||
GGML_UNUSED(thold);
|
||||
// noop
|
||||
}
|
||||
|
||||
bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
const uint32_t n_tokens = ubatch.n_tokens;
|
||||
const uint32_t n_seqs = ubatch.n_seqs;
|
||||
const uint32_t n_seqs = ubatch.n_seqs;
|
||||
|
||||
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
||||
// if we have enough unused cells before the current head ->
|
||||
// better to start searching from the beginning of the cache, hoping to fill it
|
||||
if (head > used + 2*n_tokens) {
|
||||
if (head > used + 2*n_seqs) {
|
||||
head = 0;
|
||||
}
|
||||
|
||||
@@ -534,16 +524,16 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
empty_cell.src = orig_cell.src;
|
||||
orig_cell.seq_id.erase(seq_id);
|
||||
empty_cell.seq_id.insert(seq_id); // will be overwritten
|
||||
GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
|
||||
}
|
||||
seq_meta.tail = next_empty_cell;
|
||||
// find next empty cell
|
||||
if (s + 1 < n_seqs) {
|
||||
next_empty_cell += 1;
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
next_empty_cell += 1;
|
||||
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
||||
kv_cell & cell = cells[next_empty_cell];
|
||||
if (cell.is_empty()) { break; }
|
||||
next_empty_cell += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -553,8 +543,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
|
||||
// gather and re-order
|
||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||
int32_t dst_id = s + min;
|
||||
int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
|
||||
const int32_t dst_id = s + min;
|
||||
const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
|
||||
if (dst_id != src_id) {
|
||||
kv_cell & dst_cell = cells[dst_id];
|
||||
kv_cell & src_cell = cells[src_id];
|
||||
@@ -563,12 +553,14 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
std::swap(dst_cell.src, src_cell.src);
|
||||
std::swap(dst_cell.seq_id, src_cell.seq_id);
|
||||
|
||||
// swap tails (assuming they NEVER overlap)
|
||||
for (const llama_seq_id seq_id : src_cell.seq_id) {
|
||||
cells[seq_id].tail = src_id;
|
||||
}
|
||||
for (const llama_seq_id seq_id : dst_cell.seq_id) {
|
||||
cells[seq_id].tail = dst_id;
|
||||
// swap tails
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
int32_t & tail = cells[i].tail;
|
||||
if (tail == src_id) {
|
||||
tail = dst_id;
|
||||
} else if (tail == dst_id) {
|
||||
tail = src_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -576,7 +568,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
// update the pos of the used seqs
|
||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
|
||||
int32_t cell_id = s + min;
|
||||
const int32_t cell_id = s + min;
|
||||
kv_cell & cell = cells[cell_id];
|
||||
|
||||
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
|
||||
@@ -594,6 +586,38 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
}
|
||||
}
|
||||
|
||||
// Find first cell without src refs, to use as the zero-ed state
|
||||
{
|
||||
// TODO: bake-in src refcounts in the cell metadata
|
||||
std::vector<int32_t> refcounts(size, 0);
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
const int32_t src = cells[i].src;
|
||||
if (src >= 0) {
|
||||
refcounts[src] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
rs_z = -1;
|
||||
for (int i = min; i <= max; ++i) {
|
||||
if (refcounts[i] == 0) {
|
||||
rs_z = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = min; i <= max; ++i) {
|
||||
if (cells[i].src < 0) {
|
||||
GGML_ASSERT(rs_z >= 0);
|
||||
cells[i].src0 = rs_z;
|
||||
} else {
|
||||
// Stage the source ids for all used cells to allow correct seq_* behavior
|
||||
// and still make these values available when setting the inputs
|
||||
cells[i].src0 = cells[i].src;
|
||||
}
|
||||
cells[i].src = i; // avoid moving or clearing twice
|
||||
}
|
||||
}
|
||||
|
||||
// allow getting the range of used cells, from head to head + n
|
||||
head = min;
|
||||
n = max - min + 1;
|
||||
@@ -605,47 +629,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
}
|
||||
|
||||
bool llama_kv_cache_recurrent::get_can_shift() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t llama_kv_cache_recurrent::s_copy(int i) const {
|
||||
const uint32_t cell_id = i + head;
|
||||
|
||||
//////////////////////////////////////////////
|
||||
// TODO: this should not mutate the KV cache !
|
||||
kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
|
||||
|
||||
// prevent out-of-bound sources
|
||||
if (cell.src < 0 || (uint32_t) cell.src >= size) {
|
||||
cell.src = cell_id;
|
||||
}
|
||||
|
||||
int32_t res = cell.src;
|
||||
|
||||
// TODO: do not mutate the KV cache
|
||||
// ensure copy only happens once
|
||||
if (cell.src != (int32_t) cell_id) {
|
||||
cell.src = cell_id;
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
float llama_kv_cache_recurrent::s_mask(int i) const {
|
||||
const uint32_t cell_id = i + head;
|
||||
|
||||
//////////////////////////////////////////////
|
||||
// TODO: this should not mutate the KV cache !
|
||||
kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
|
||||
|
||||
float res = (float) (cell.src >= 0);
|
||||
|
||||
// only clear once
|
||||
if (cell.src < 0) {
|
||||
cell.src = cell_id;
|
||||
}
|
||||
|
||||
return res;
|
||||
// shifting the pos is trivial for recurrent models
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t llama_kv_cache_recurrent::total_size() const {
|
||||
@@ -726,7 +711,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
|
||||
|
||||
if (!res) {
|
||||
if (seq_id == -1) {
|
||||
clear();
|
||||
clear(true);
|
||||
} else {
|
||||
seq_rm(seq_id, -1, -1);
|
||||
}
|
||||
@@ -883,7 +868,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
||||
return false;
|
||||
}
|
||||
|
||||
clear();
|
||||
clear(true);
|
||||
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
kv_cell & cell = cells[i];
|
||||
@@ -1111,6 +1096,10 @@ uint32_t llama_kv_cache_recurrent_state::get_head() const {
|
||||
return is_full ? 0 : kv->head;
|
||||
}
|
||||
|
||||
int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
|
||||
return is_full ? 0 : kv->rs_z;
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_recurrent_state::get_size() const {
|
||||
return kv->size;
|
||||
}
|
||||
@@ -1124,9 +1113,5 @@ ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
|
||||
}
|
||||
|
||||
int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
|
||||
return kv->s_copy(i);
|
||||
}
|
||||
|
||||
float llama_kv_cache_recurrent_state::s_mask(int i) const {
|
||||
return kv->s_mask(i);
|
||||
return kv->cells[i + kv->head].src0;
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
#include "llama-batch.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-kv-cache.h"
|
||||
#include "llama-memory.h"
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
|
||||
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it
|
||||
class llama_kv_cache_recurrent : public llama_kv_cache {
|
||||
class llama_kv_cache_recurrent : public llama_memory_i {
|
||||
public:
|
||||
llama_kv_cache_recurrent(
|
||||
const llama_model & model,
|
||||
@@ -29,7 +29,17 @@ public:
|
||||
// llama_memory_i
|
||||
//
|
||||
|
||||
void clear() override;
|
||||
llama_memory_state_ptr init_batch(
|
||||
const llama_batch & batch,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled,
|
||||
bool logits_all) override;
|
||||
|
||||
llama_memory_state_ptr init_full() override;
|
||||
|
||||
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||
|
||||
void clear(bool data) override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
@@ -40,22 +50,6 @@ public:
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache
|
||||
//
|
||||
|
||||
llama_memory_state_ptr init_batch(
|
||||
const llama_batch & batch,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled,
|
||||
bool logits_all) override;
|
||||
|
||||
llama_memory_state_ptr init_full() override;
|
||||
|
||||
bool update(llama_context & lctx) override;
|
||||
|
||||
void defrag_sched(float thold) override;
|
||||
|
||||
bool prepare(const std::vector<llama_ubatch> & ubatches);
|
||||
|
||||
// find a contiguous slot of kv cells and emplace the ubatch there
|
||||
@@ -63,10 +57,6 @@ public:
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
// TODO: temporary methods - they are not really const as they do const_cast<>, fix this
|
||||
int32_t s_copy(int i) const;
|
||||
float s_mask(int i) const;
|
||||
|
||||
// state write/load
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
@@ -79,10 +69,14 @@ public:
|
||||
// computed before each graph build
|
||||
uint32_t n = 0;
|
||||
|
||||
// first zero-ed state
|
||||
int32_t rs_z = -1;
|
||||
|
||||
// TODO: optimize for recurrent state needs
|
||||
struct kv_cell {
|
||||
llama_pos pos = -1;
|
||||
int32_t src = -1; // used to copy states
|
||||
int32_t src = -1; // used to know where states should be copied from
|
||||
int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
|
||||
int32_t tail = -1;
|
||||
|
||||
std::set<llama_seq_id> seq_id;
|
||||
@@ -163,13 +157,13 @@ public:
|
||||
|
||||
uint32_t get_n_kv() const;
|
||||
uint32_t get_head() const;
|
||||
int32_t get_rs_z() const;
|
||||
uint32_t get_size() const;
|
||||
|
||||
ggml_tensor * get_k_l(int32_t il) const;
|
||||
ggml_tensor * get_v_l(int32_t il) const;
|
||||
|
||||
int32_t s_copy(int i) const;
|
||||
float s_mask(int i) const;
|
||||
|
||||
private:
|
||||
const llama_memory_status status;
|
||||
|
||||
@@ -52,9 +52,9 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
||||
hparams.n_swa, hparams.swa_type);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::clear() {
|
||||
kv_base->clear();
|
||||
kv_swa ->clear();
|
||||
void llama_kv_cache_unified_iswa::clear(bool data) {
|
||||
kv_base->clear(data);
|
||||
kv_swa ->clear(data);
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
@@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
|
||||
|
||||
assert(heads_base.size() == heads_swa.size());
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
||||
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
|
||||
bool res = false;
|
||||
|
||||
res = res | kv_base->update(lctx);
|
||||
res = res | kv_swa ->update(lctx);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
|
||||
kv_base->defrag_sched(thold);
|
||||
kv_swa ->defrag_sched(thold);
|
||||
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified_iswa::get_can_shift() const {
|
||||
@@ -174,26 +164,38 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
|
||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
|
||||
|
||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_unified_iswa * kv) : status(status) {
|
||||
state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
|
||||
state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
|
||||
llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||
state_base = kv->get_base()->init_full();
|
||||
state_swa = kv->get_swa ()->init_full();
|
||||
|
||||
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
llama_context * lctx,
|
||||
bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||
state_base = kv->get_base()->init_update(lctx, optimize);
|
||||
state_swa = kv->get_swa ()->init_update(lctx, optimize);
|
||||
|
||||
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
llama_sbatch sbatch,
|
||||
std::vector<uint32_t> heads_base,
|
||||
std::vector<uint32_t> heads_swa,
|
||||
std::vector<llama_ubatch> ubatches)
|
||||
: status(status),
|
||||
sbatch(std::move(sbatch)),
|
||||
ubatches(std::move(ubatches)) {
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
|
||||
state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
|
||||
}
|
||||
: status(LLAMA_MEMORY_STATUS_SUCCESS),
|
||||
sbatch(std::move(sbatch)),
|
||||
ubatches(std::move(ubatches)) {
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
|
||||
state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
|
||||
|
||||
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
|
||||
|
||||
@@ -233,17 +235,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
|
||||
|
||||
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return ubatches[i_next];
|
||||
}
|
||||
|
||||
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return state_base.get();
|
||||
return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
|
||||
}
|
||||
|
||||
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return state_swa.get();
|
||||
return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
// utilizes two instances of llama_kv_cache_unified
|
||||
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
|
||||
|
||||
class llama_kv_cache_unified_iswa : public llama_kv_cache {
|
||||
class llama_kv_cache_unified_iswa : public llama_memory_i {
|
||||
public:
|
||||
llama_kv_cache_unified_iswa(
|
||||
const llama_model & model,
|
||||
@@ -31,7 +31,19 @@ public:
|
||||
// llama_memory_i
|
||||
//
|
||||
|
||||
void clear() override;
|
||||
llama_memory_state_ptr init_batch(
|
||||
const llama_batch & batch,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled,
|
||||
bool logits_all) override;
|
||||
|
||||
llama_memory_state_ptr init_full() override;
|
||||
|
||||
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
void clear(bool data) override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
@@ -42,24 +54,6 @@ public:
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache
|
||||
//
|
||||
|
||||
llama_memory_state_ptr init_batch(
|
||||
const llama_batch & batch,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled,
|
||||
bool logits_all) override;
|
||||
|
||||
llama_memory_state_ptr init_full() override;
|
||||
|
||||
bool update(llama_context & lctx) override;
|
||||
|
||||
void defrag_sched(float thold) override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
// state write/load
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
@@ -86,12 +80,16 @@ public:
|
||||
|
||||
// used to create a full-cache state
|
||||
llama_kv_cache_unified_iswa_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_unified_iswa * kv);
|
||||
|
||||
// used to create an update state
|
||||
llama_kv_cache_unified_iswa_state(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
llama_context * lctx,
|
||||
bool optimize);
|
||||
|
||||
// used to create a state from a batch
|
||||
llama_kv_cache_unified_iswa_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
llama_sbatch sbatch,
|
||||
std::vector<uint32_t> heads_base,
|
||||
@@ -120,7 +118,7 @@ public:
|
||||
const llama_kv_cache_unified_state * get_swa() const;
|
||||
|
||||
private:
|
||||
const llama_memory_status status;
|
||||
llama_memory_status status;
|
||||
|
||||
//llama_kv_cache_unified_iswa * kv;
|
||||
|
||||
@@ -131,6 +129,6 @@ private:
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
std::unique_ptr<llama_kv_cache_unified_state> state_base;
|
||||
std::unique_ptr<llama_kv_cache_unified_state> state_swa;
|
||||
llama_memory_state_ptr state_base;
|
||||
llama_memory_state_ptr state_swa;
|
||||
};
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "llama-kv-cache-unified.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-io.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-context.h"
|
||||
|
||||
@@ -128,13 +129,15 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::clear() {
|
||||
void llama_kv_cache_unified::clear(bool data) {
|
||||
cells.reset();
|
||||
|
||||
head = 0;
|
||||
|
||||
for (auto & buf : bufs) {
|
||||
ggml_backend_buffer_clear(buf.get(), 0);
|
||||
if (data) {
|
||||
for (auto & buf : bufs) {
|
||||
ggml_backend_buffer_clear(buf.get(), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -149,12 +152,27 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
||||
p1 = std::numeric_limits<llama_pos>::max();
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (!cells.pos_in(i, p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
if (seq_id >= 0) {
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (!cells.pos_in(i, p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
|
||||
if (new_head == cells.size()) {
|
||||
new_head = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// match any sequence
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (!cells.pos_in(i, p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
cells.rm(i);
|
||||
|
||||
if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
|
||||
if (new_head == cells.size()) {
|
||||
new_head = i;
|
||||
}
|
||||
@@ -305,16 +323,49 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
|
||||
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS,
|
||||
return std::make_unique<llama_kv_cache_unified_state>(
|
||||
this, std::move(sbatch), std::move(heads), std::move(ubatches));
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_unified::init_full() {
|
||||
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
|
||||
return std::make_unique<llama_kv_cache_unified_state>(this);
|
||||
}
|
||||
|
||||
std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
|
||||
std::vector<uint32_t> res;
|
||||
llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
|
||||
bool do_shift = get_has_shift();
|
||||
|
||||
defrag_info dinfo;
|
||||
|
||||
// see if we need to defrag
|
||||
{
|
||||
bool do_defrag = optimize;
|
||||
|
||||
const auto thold = lctx->get_cparams().defrag_thold;
|
||||
|
||||
if (!do_defrag && thold > 0.0f) {
|
||||
const auto n_kv = cells.used_max_p1();
|
||||
|
||||
// - do not defrag small contexts (i.e. < 2048 tokens)
|
||||
// - count the padding towards the number of used tokens
|
||||
const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
|
||||
|
||||
if (fragmentation > thold) {
|
||||
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
|
||||
|
||||
do_defrag = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (do_defrag) {
|
||||
dinfo = defrag_prepare(lctx->graph_max_nodes());
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
|
||||
}
|
||||
|
||||
llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
|
||||
llama_kv_cache_unified::ubatch_heads res;
|
||||
|
||||
struct state {
|
||||
uint32_t head_old; // old position of the head, before placing the ubatch
|
||||
@@ -359,12 +410,12 @@ std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ub
|
||||
return res;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified::update(llama_context & lctx) {
|
||||
bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
|
||||
bool updated = false;
|
||||
|
||||
auto * sched = lctx.get_sched();
|
||||
auto * sched = lctx->get_sched();
|
||||
|
||||
if (cells.get_has_shift()) {
|
||||
if (do_shift) {
|
||||
if (!get_can_shift()) {
|
||||
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
|
||||
}
|
||||
@@ -375,9 +426,9 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
||||
if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
|
||||
ggml_backend_sched_reset(sched);
|
||||
|
||||
auto * gf = lctx.graph_init();
|
||||
auto * gf = lctx->graph_init();
|
||||
|
||||
auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
|
||||
auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
|
||||
if (!res) {
|
||||
LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
|
||||
return updated;
|
||||
@@ -390,7 +441,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
||||
|
||||
res->set_inputs(nullptr);
|
||||
|
||||
if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
|
||||
if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
|
||||
LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
|
||||
return updated;
|
||||
}
|
||||
@@ -401,56 +452,55 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
||||
cells.reset_shift();
|
||||
}
|
||||
|
||||
if (do_defrag) {
|
||||
if (!dinfo.empty()) {
|
||||
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
|
||||
|
||||
if (defrag_prepare(lctx.graph_max_nodes())) {
|
||||
ggml_backend_sched_reset(sched);
|
||||
// apply moves:
|
||||
{
|
||||
const auto n_kv = dinfo.ids.size();
|
||||
|
||||
auto * gf = lctx.graph_init();
|
||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||
assert(dinfo.ids[i] <= n_kv);
|
||||
|
||||
auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
|
||||
if (!res) {
|
||||
LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
|
||||
return updated;
|
||||
if (dinfo.ids[i] == n_kv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
cells.mv(i, dinfo.ids[i]);
|
||||
}
|
||||
|
||||
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
|
||||
return updated;
|
||||
}
|
||||
|
||||
res->set_inputs(nullptr);
|
||||
|
||||
if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
|
||||
LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
|
||||
return updated;
|
||||
}
|
||||
|
||||
updated = true;
|
||||
// reset the head so we can find the first free slot during the next ubatch
|
||||
head = 0;
|
||||
}
|
||||
|
||||
do_defrag = false;
|
||||
ggml_backend_sched_reset(sched);
|
||||
|
||||
auto * gf = lctx->graph_init();
|
||||
|
||||
auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
|
||||
if (!res) {
|
||||
LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
|
||||
return updated;
|
||||
}
|
||||
|
||||
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
|
||||
return updated;
|
||||
}
|
||||
|
||||
res->set_inputs(nullptr);
|
||||
|
||||
if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
|
||||
LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
|
||||
return updated;
|
||||
}
|
||||
|
||||
updated = true;
|
||||
}
|
||||
|
||||
return updated;
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::defrag_sched(float thold) {
|
||||
const auto n_kv = cells.used_max_p1();
|
||||
|
||||
// - do not defrag small contexts (i.e. < 2048 tokens)
|
||||
// - count the padding towards the number of used tokens
|
||||
const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
|
||||
|
||||
// queue defragmentation for next llama_kv_cache_update
|
||||
if (fragmentation > thold) {
|
||||
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
|
||||
|
||||
do_defrag = true;
|
||||
}
|
||||
}
|
||||
|
||||
int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
||||
const uint32_t n_tokens = ubatch.n_tokens;
|
||||
|
||||
@@ -462,8 +512,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
||||
head_cur = 0;
|
||||
}
|
||||
|
||||
// otherwise, one cell per token.
|
||||
|
||||
if (n_tokens > cells.size()) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
|
||||
return -1;
|
||||
@@ -597,6 +645,10 @@ uint32_t llama_kv_cache_unified::get_size() const {
|
||||
return cells.size();
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified::get_has_shift() const {
|
||||
return cells.get_has_shift();
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_unified::get_n_kv() const {
|
||||
return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
|
||||
}
|
||||
@@ -926,12 +978,13 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
|
||||
}
|
||||
|
||||
llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf) const {
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const defrag_info & dinfo) const {
|
||||
auto res = std::make_unique<llm_graph_result>();
|
||||
|
||||
const auto & ids = defrag_info.ids;
|
||||
const auto & ids = dinfo.ids;
|
||||
|
||||
#if 0
|
||||
// CPU defrag
|
||||
@@ -1072,7 +1125,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
||||
return res;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
||||
llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
|
||||
const uint32_t n_layer = layers.size();
|
||||
|
||||
const uint32_t n_kv = cells.used_max_p1();
|
||||
@@ -1093,14 +1146,9 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
||||
const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
|
||||
|
||||
// determine which KV cells to move where
|
||||
//
|
||||
// cell i moves to ids[i]
|
||||
//
|
||||
// if ids[i] == i || ids[i] == n_kv, then cell i is not moved
|
||||
//
|
||||
auto & ids = defrag_info.ids;
|
||||
defrag_info res;
|
||||
auto & ids = res.ids;
|
||||
|
||||
ids.clear();
|
||||
ids.resize(n_kv, n_kv);
|
||||
|
||||
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
||||
@@ -1164,11 +1212,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
||||
// this cell goes to (i0 + nf)
|
||||
ids[i1] = i0 + nf;
|
||||
|
||||
// move the cell meta data
|
||||
cells.mv(i1, i0 + nf);
|
||||
|
||||
head = n_used;
|
||||
|
||||
if (!cont) {
|
||||
n_moves++;
|
||||
cont = true;
|
||||
@@ -1191,14 +1234,14 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
||||
}
|
||||
|
||||
if (n_moves == 0) {
|
||||
return false;
|
||||
return {};
|
||||
}
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
|
||||
|
||||
return true;
|
||||
return res;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
||||
@@ -1276,7 +1319,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
|
||||
|
||||
if (!res) {
|
||||
if (seq_id == -1) {
|
||||
clear();
|
||||
clear(true);
|
||||
} else {
|
||||
seq_rm(seq_id, -1, -1);
|
||||
}
|
||||
@@ -1457,7 +1500,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
||||
return false;
|
||||
}
|
||||
|
||||
clear();
|
||||
clear(true);
|
||||
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
llama_pos pos;
|
||||
@@ -1621,24 +1664,27 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
|
||||
|
||||
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_unified * kv) : status(status), kv(kv) {
|
||||
n_kv = kv->get_size();
|
||||
head = 0;
|
||||
}
|
||||
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
|
||||
n_kv = kv->get_size();
|
||||
head = 0;
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_unified * kv,
|
||||
llama_sbatch sbatch,
|
||||
std::vector<uint32_t> heads,
|
||||
std::vector<llama_ubatch> ubatches)
|
||||
: status(status),
|
||||
kv(kv),
|
||||
sbatch(std::move(sbatch)),
|
||||
heads(std::move(heads)),
|
||||
ubatches(std::move(ubatches)) {
|
||||
llama_kv_cache_unified * kv,
|
||||
llama_context * lctx,
|
||||
bool do_shift,
|
||||
defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
|
||||
if (!do_shift && dinfo.empty()) {
|
||||
status = LLAMA_MEMORY_STATUS_NO_UPDATE;
|
||||
}
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
||||
llama_kv_cache_unified * kv,
|
||||
llama_sbatch sbatch,
|
||||
llama_kv_cache_unified::ubatch_heads heads,
|
||||
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
|
||||
|
||||
@@ -1655,6 +1701,13 @@ bool llama_kv_cache_unified_state::next() {
|
||||
bool llama_kv_cache_unified_state::apply() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
// no ubatches -> this is a KV cache update
|
||||
if (ubatches.empty()) {
|
||||
kv->update(lctx, do_shift, dinfo);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
kv->apply_ubatch(heads[i_next], ubatches[i_next]);
|
||||
|
||||
n_kv = kv->get_n_kv();
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
#include "llama-batch.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-kv-cache.h"
|
||||
#include "llama-kv-cells.h"
|
||||
#include "llama-memory.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
@@ -17,13 +17,26 @@ struct llama_context;
|
||||
// llama_kv_cache_unified
|
||||
//
|
||||
|
||||
class llama_kv_cache_unified : public llama_kv_cache {
|
||||
class llama_kv_cache_unified : public llama_memory_i {
|
||||
public:
|
||||
static uint32_t get_padding(const llama_cparams & cparams);
|
||||
|
||||
// this callback is used to filter out layers that should not be included in the cache
|
||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||
|
||||
using ubatch_heads = std::vector<uint32_t>;
|
||||
|
||||
struct defrag_info {
|
||||
bool empty() const {
|
||||
return ids.empty();
|
||||
}
|
||||
|
||||
// contains information about which cell moves where:
|
||||
// - cell i moves to ids[i]
|
||||
// - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
|
||||
std::vector<uint32_t> ids;
|
||||
};
|
||||
|
||||
llama_kv_cache_unified(
|
||||
const llama_model & model,
|
||||
layer_filter_cb && filter,
|
||||
@@ -43,7 +56,19 @@ public:
|
||||
// llama_memory_i
|
||||
//
|
||||
|
||||
void clear() override;
|
||||
llama_memory_state_ptr init_batch(
|
||||
const llama_batch & batch,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled,
|
||||
bool logits_all) override;
|
||||
|
||||
llama_memory_state_ptr init_full() override;
|
||||
|
||||
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
void clear(bool data) override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
@@ -54,24 +79,6 @@ public:
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache
|
||||
//
|
||||
|
||||
llama_memory_state_ptr init_batch(
|
||||
const llama_batch & batch,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled,
|
||||
bool logits_all) override;
|
||||
|
||||
llama_memory_state_ptr init_full() override;
|
||||
|
||||
bool update(llama_context & lctx) override;
|
||||
|
||||
void defrag_sched(float thold) override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
// state write/load
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
@@ -83,6 +90,8 @@ public:
|
||||
|
||||
uint32_t get_size() const;
|
||||
|
||||
bool get_has_shift() const;
|
||||
|
||||
//
|
||||
// graph_build API
|
||||
//
|
||||
@@ -103,7 +112,9 @@ public:
|
||||
|
||||
// find places for the provided ubatches in the cache, returns the head locations
|
||||
// return empty vector on failure
|
||||
std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
|
||||
ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
|
||||
|
||||
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
|
||||
|
||||
// return the cell position where we can insert the ubatch
|
||||
// return -1 on failure to find a contiguous slot of kv cells
|
||||
@@ -133,8 +144,7 @@ private:
|
||||
ggml_tensor * v;
|
||||
};
|
||||
|
||||
bool do_defrag = false;
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
|
||||
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
@@ -160,13 +170,8 @@ private:
|
||||
// model layer id -> KV cache layer id
|
||||
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
||||
|
||||
// defrag
|
||||
struct {
|
||||
std::vector<uint32_t> ids;
|
||||
} defrag_info;
|
||||
|
||||
// return true if cells have been moved
|
||||
bool defrag_prepare(int32_t n_max_nodes);
|
||||
// return non-empty vector if cells have been moved
|
||||
defrag_info defrag_prepare(int32_t n_max_nodes) const;
|
||||
|
||||
size_t total_size() const;
|
||||
|
||||
@@ -192,7 +197,8 @@ private:
|
||||
llm_graph_result_ptr build_graph_defrag(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf) const;
|
||||
ggml_cgraph * gf,
|
||||
const defrag_info & dinfo) const;
|
||||
|
||||
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
||||
@@ -203,20 +209,29 @@ private:
|
||||
|
||||
class llama_kv_cache_unified_state : public llama_memory_state_i {
|
||||
public:
|
||||
// some shorthands
|
||||
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
|
||||
using defrag_info = llama_kv_cache_unified::defrag_info;
|
||||
|
||||
// used for errors
|
||||
llama_kv_cache_unified_state(llama_memory_status status);
|
||||
|
||||
// used to create a full-cache state
|
||||
llama_kv_cache_unified_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_unified * kv);
|
||||
|
||||
// used to create a state from a batch
|
||||
// used to create an update state
|
||||
llama_kv_cache_unified_state(
|
||||
llama_kv_cache_unified * kv,
|
||||
llama_context * lctx,
|
||||
bool do_shift,
|
||||
defrag_info dinfo);
|
||||
|
||||
// used to create a decode state from a batch
|
||||
llama_kv_cache_unified_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_unified * kv,
|
||||
llama_sbatch sbatch,
|
||||
std::vector<uint32_t> heads,
|
||||
ubatch_heads heads,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
virtual ~llama_kv_cache_unified_state();
|
||||
@@ -253,16 +268,30 @@ public:
|
||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
private:
|
||||
const llama_memory_status status;
|
||||
llama_memory_status status;
|
||||
|
||||
llama_kv_cache_unified * kv;
|
||||
llama_context * lctx;
|
||||
|
||||
//
|
||||
// update state
|
||||
//
|
||||
|
||||
bool do_shift = false;
|
||||
|
||||
defrag_info dinfo;
|
||||
|
||||
//
|
||||
// batch processing state
|
||||
//
|
||||
|
||||
llama_sbatch sbatch;
|
||||
|
||||
// the index of the next ubatch to process
|
||||
size_t i_next = 0;
|
||||
|
||||
std::vector<uint32_t> heads;
|
||||
ubatch_heads heads;
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
//
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
#include "llama-kv-cache.h"
|
||||
@@ -1,44 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
#include "llama-io.h"
|
||||
#include "llama-memory.h"
|
||||
|
||||
struct llama_kv_cache : public llama_memory_i {
|
||||
virtual ~llama_kv_cache() = default;
|
||||
|
||||
// split the input batch into a set of ubatches and verify that they can fit into the cache
|
||||
// return a state object containing the ubatches and KV cache state required to process them
|
||||
// check the llama_memory_state_i::get_status() for the result
|
||||
virtual llama_memory_state_ptr init_batch(
|
||||
const llama_batch & batch,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled,
|
||||
bool logits_all) = 0;
|
||||
|
||||
// simulate full cache, used for allocating worst-case compute buffers
|
||||
virtual llama_memory_state_ptr init_full() = 0;
|
||||
|
||||
// process any pending defrag/shift/etc. operations
|
||||
// optionally call once before processing a new batch
|
||||
// return true if any operations were performed
|
||||
virtual bool update(llama_context & lctx) = 0;
|
||||
|
||||
// schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
|
||||
// TODO: change to
|
||||
// llama_memory_state_ptr init_defrag(float thold) = 0;
|
||||
//
|
||||
virtual void defrag_sched(float thold) = 0;
|
||||
|
||||
// getters
|
||||
virtual bool get_can_shift() const = 0;
|
||||
|
||||
bool get_can_edit() const override { return get_can_shift(); }
|
||||
|
||||
//
|
||||
// state write/read
|
||||
//
|
||||
|
||||
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
|
||||
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
|
||||
};
|
||||
@@ -1 +1,42 @@
|
||||
#include "llama-memory.h"
|
||||
|
||||
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) {
|
||||
bool has_update = false;
|
||||
|
||||
switch (s0) {
|
||||
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||
{
|
||||
has_update = true;
|
||||
break;
|
||||
}
|
||||
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
||||
{
|
||||
break;
|
||||
}
|
||||
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
||||
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
||||
{
|
||||
return s0;
|
||||
}
|
||||
}
|
||||
|
||||
switch (s1) {
|
||||
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||
{
|
||||
has_update = true;
|
||||
break;
|
||||
}
|
||||
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
||||
{
|
||||
break;
|
||||
}
|
||||
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
||||
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
||||
{
|
||||
return s1;
|
||||
}
|
||||
}
|
||||
|
||||
// if either status has an update, then the combined status has an update
|
||||
return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
|
||||
}
|
||||
|
||||
@@ -7,6 +7,9 @@
|
||||
|
||||
struct llama_ubatch;
|
||||
|
||||
class llama_io_write_i;
|
||||
class llama_io_read_i;
|
||||
|
||||
struct llama_memory_params {
|
||||
// kv cache
|
||||
ggml_type type_k;
|
||||
@@ -16,32 +19,17 @@ struct llama_memory_params {
|
||||
bool swa_full;
|
||||
};
|
||||
|
||||
// general concept of LLM memory
|
||||
// the KV cache is a type of LLM memory, but there can be other types
|
||||
class llama_memory_i {
|
||||
public:
|
||||
virtual ~llama_memory_i() = default;
|
||||
|
||||
virtual void clear() = 0;
|
||||
|
||||
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
|
||||
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
|
||||
virtual void seq_keep(llama_seq_id seq_id) = 0;
|
||||
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
|
||||
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
|
||||
|
||||
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
|
||||
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
|
||||
|
||||
virtual bool get_can_edit() const = 0;
|
||||
};
|
||||
|
||||
enum llama_memory_status {
|
||||
LLAMA_MEMORY_STATUS_SUCCESS = 0,
|
||||
LLAMA_MEMORY_STATUS_NO_UPDATE,
|
||||
LLAMA_MEMORY_STATUS_FAILED_PREPARE,
|
||||
LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
|
||||
};
|
||||
|
||||
// helper function for combining the status of two memory states
|
||||
// useful for implementing hybrid memory types (e.g. iSWA)
|
||||
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
|
||||
|
||||
// the interface for managing the memory state during batch processing
|
||||
// this interface is implemented per memory type. see:
|
||||
// - llama_kv_cache_unified_state
|
||||
@@ -51,8 +39,7 @@ enum llama_memory_status {
|
||||
// the only method that can mutate the memory and the memory state is llama_memory_i::apply()
|
||||
//
|
||||
// TODO: rename to llama_memory_context_i ?
|
||||
class llama_memory_state_i {
|
||||
public:
|
||||
struct llama_memory_state_i {
|
||||
virtual ~llama_memory_state_i() = default;
|
||||
|
||||
// consume the current ubatch from the state and proceed to the next one
|
||||
@@ -69,8 +56,63 @@ public:
|
||||
// get the current ubatch
|
||||
virtual const llama_ubatch & get_ubatch() const = 0;
|
||||
|
||||
// get the status of the memory state
|
||||
// get the status of the memory state - used for error handling and checking if any updates would be applied
|
||||
virtual llama_memory_status get_status() const = 0;
|
||||
};
|
||||
|
||||
using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
|
||||
|
||||
// general concept of LLM memory
|
||||
// the KV cache is a type of LLM memory, but there can be other types
|
||||
struct llama_memory_i {
|
||||
virtual ~llama_memory_i() = default;
|
||||
|
||||
// split the input batch into a set of ubatches and verify that they can fit into the cache
|
||||
// return a state object containing the ubatches and KV cache state required to process them
|
||||
// check the llama_memory_state_i::get_status() for the result
|
||||
virtual llama_memory_state_ptr init_batch(
|
||||
const llama_batch & batch,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled,
|
||||
bool logits_all) = 0;
|
||||
|
||||
// simulate full cache, used for allocating worst-case compute buffers
|
||||
virtual llama_memory_state_ptr init_full() = 0;
|
||||
|
||||
// prepare for any pending memory updates, such as shifts, defrags, etc.
|
||||
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
|
||||
virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
|
||||
|
||||
// getters
|
||||
virtual bool get_can_shift() const = 0;
|
||||
|
||||
//
|
||||
// ops
|
||||
//
|
||||
|
||||
// if data == true, the data buffers will also be cleared together with the metadata
|
||||
virtual void clear(bool data) = 0;
|
||||
|
||||
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
|
||||
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
|
||||
virtual void seq_keep(llama_seq_id seq_id) = 0;
|
||||
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
|
||||
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
|
||||
|
||||
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
|
||||
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
|
||||
|
||||
//
|
||||
// state write/read
|
||||
//
|
||||
|
||||
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
|
||||
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
|
||||
};
|
||||
|
||||
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
|
||||
|
||||
// TODO: temporary until the llama_kv_cache is removed from the public API
|
||||
struct llama_kv_cache : public llama_memory_i {
|
||||
virtual ~llama_kv_cache() = default;
|
||||
};
|
||||
|
||||
@@ -401,7 +401,7 @@ struct llama_mmap::impl {
|
||||
}
|
||||
}
|
||||
#else
|
||||
throw std::runtime_error("PrefetchVirtualMemory unavailable");
|
||||
LLAMA_LOG_DEBUG("skipping PrefetchVirtualMemory because _WIN32_WINNT < 0x602\n");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user