Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2026-02-05 13:53:23 +02:00)
Compare commits
15 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | 4d3726278b |  |
|  | 08f9d3cc1d |  |
|  | 0a540f9abd |  |
|  | 22577583a3 |  |
|  | d9e03db1e7 |  |
|  | db97837385 |  |
|  | 017761daf5 |  |
|  | c42712b056 |  |
|  | 09c7c50e64 |  |
|  | f334b79494 |  |
|  | a28e3c7567 |  |
|  | e31b5c55c3 |  |
|  | 21f24f27a9 |  |
|  | 7b43f55753 |  |
|  | 444f00b0ec |  |
31 .github/actions/windows-setup-cuda/action.yml (vendored)
@@ -65,3 +65,34 @@ runs:
        echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
        echo "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
        echo "CUDA_PATH_V12_4=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8

    - name: Install Cuda Toolkit 13.1
      if: ${{ inputs.cuda_version == '13.1' }}
      shell: pwsh
      run: |
        mkdir -p "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1"
        choco install unzip -y
        curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_crt/windows-x86_64/cuda_crt-windows-x86_64-13.1.80-archive.zip"
        curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-13.1.80-archive.zip"
        curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-13.1.80-archive.zip"
        curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-13.1.80-archive.zip"
        curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-13.2.0.9-archive.zip"
        curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libnvvm/windows-x86_64/libnvvm-windows-x86_64-13.1.80-archive.zip"
        curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-13.1.68-archive.zip"
        curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-13.1.80-archive.zip"
        curl -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-13.1.68-archive.zip"
        curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-13.1.78-archive.zip"
        unzip '*.zip' -d "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1"
        xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1\cuda_crt-windows-x86_64-13.1.80-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1" /E /I /H /Y
        xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1\cuda_cudart-windows-x86_64-13.1.80-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1" /E /I /H /Y
        xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1\cuda_nvcc-windows-x86_64-13.1.80-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1" /E /I /H /Y
        xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1\cuda_nvrtc-windows-x86_64-13.1.80-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1" /E /I /H /Y
        xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1\libcublas-windows-x86_64-13.2.0.9-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1" /E /I /H /Y
        xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1\libnvvm-windows-x86_64-13.1.80-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1" /E /I /H /Y
        xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1\cuda_nvtx-windows-x86_64-13.1.68-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1" /E /I /H /Y
        xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1\cuda_profiler_api-windows-x86_64-13.1.80-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1" /E /I /H /Y
        xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1\visual_studio_integration-windows-x86_64-13.1.68-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1" /E /I /H /Y
        xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1\cuda_cccl-windows-x86_64-13.1.78-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1" /E /I /H /Y
        echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
        echo "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
        echo "CUDA_PATH_V13_1=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
6 .github/workflows/release.yml (vendored)
@@ -421,7 +421,7 @@ jobs:

    strategy:
      matrix:
        cuda: ['12.4']
        cuda: ['12.4', '13.1']

    steps:
      - name: Clone
@@ -476,6 +476,7 @@ jobs:
          $dst='.\build\bin\cudart\'
          robocopy "${{env.CUDA_PATH}}\bin" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll
          robocopy "${{env.CUDA_PATH}}\lib" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll
          robocopy "${{env.CUDA_PATH}}\bin\x64" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll
          7z a cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip $dst\*

      - name: Upload Cuda runtime
@@ -835,7 +836,8 @@ jobs:
            **Windows:**
            - [Windows x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-x64.zip)
            - [Windows arm64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-arm64.zip)
            - [Windows x64 (CUDA)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip)
            - [Windows x64 (CUDA 12)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip)
            - [Windows x64 (CUDA 13)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-13.1-x64.zip)
            - [Windows x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-vulkan-x64.zip)
            - [Windows x64 (SYCL)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip)
            - [Windows x64 (HIP)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-hip-radeon-x64.zip)
@@ -276,6 +276,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
| [MUSA](docs/build.md#musa) | Moore Threads GPU |
| [CUDA](docs/build.md#cuda) | Nvidia GPU |
| [HIP](docs/build.md#hip) | AMD GPU |
| [ZenDNN](docs/build.md#zendnn) | AMD CPU |
| [Vulkan](docs/build.md#vulkan) | GPU |
| [CANN](docs/build.md#cann) | Ascend NPU |
| [OpenCL](docs/backend/OPENCL.md) | Adreno GPU |
@@ -708,6 +708,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        params.use_jinja = true;
    }

    params.use_color = tty_can_use_colors();

    // load dynamic backends
    ggml_backend_load_all();
@@ -790,10 +792,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        }
    ).set_examples({LLAMA_EXAMPLE_MAIN}));
    add_opt(common_arg(
        {"-co", "--color"},
        string_format("colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false"),
        [](common_params & params) {
            params.use_color = true;
        {"-co", "--color"}, "[on|off|auto]",
        "Colorize output to distinguish prompt and user input from generations ('on', 'off', or 'auto', default: 'auto')\n"
        "'auto' enables colors when output is to a terminal",
        [](common_params & params, const std::string & value) {
            if (is_truthy(value)) {
                params.use_color = true;
            } else if (is_falsey(value)) {
                params.use_color = false;
            } else if (is_autoy(value)) {
                params.use_color = tty_can_use_colors();
            } else {
                throw std::invalid_argument(
                    string_format("error: unknown value for --color: '%s'\n", value.c_str()));
            }
        }
    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP}));
    add_opt(common_arg(
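For example (a usage sketch implied by the new parser code above; the model path is hypothetical): `llama-cli -m model.gguf --color off` disables coloring, `--color on` forces it, and `--color auto` (the default) enables it only when `tty_can_use_colors()` detects a suitable terminal.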
@@ -1022,7 +1034,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
            params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
        } else {
            throw std::runtime_error(
                string_format("error: unkown value for --flash-attn: '%s'\n", value.c_str()));
                string_format("error: unknown value for --flash-attn: '%s'\n", value.c_str()));
        }
    }).set_env("LLAMA_ARG_FLASH_ATTN"));
    add_opt(common_arg(
@@ -2696,7 +2708,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
            common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
        } else {
            throw std::invalid_argument(
                string_format("error: unkown value for --log-colors: '%s'\n", value.c_str()));
                string_format("error: unknown value for --log-colors: '%s'\n", value.c_str()));
        }
    }
    ).set_env("LLAMA_LOG_COLORS"));
@@ -982,6 +982,32 @@ std::vector<common_file_info> fs_list(const std::string & path, bool include_dir
    return files;
}

//
// TTY utils
//

bool tty_can_use_colors() {
    // Check NO_COLOR environment variable (https://no-color.org/)
    if (const char * no_color = std::getenv("NO_COLOR")) {
        if (no_color[0] != '\0') {
            return false;
        }
    }

    // Check TERM environment variable
    if (const char * term = std::getenv("TERM")) {
        if (std::strcmp(term, "dumb") == 0) {
            return false;
        }
    }

    // Check if stdout and stderr are connected to a terminal
    // We check both because log messages can go to either
    bool stdout_is_tty = isatty(fileno(stdout));
    bool stderr_is_tty = isatty(fileno(stderr));

    return stdout_is_tty || stderr_is_tty;
}

//
// Model utils
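In practice (an illustrative consequence of the checks above): running with `NO_COLOR=1` in the environment, or with `TERM=dumb`, disables colors even on a terminal, and colors are also disabled automatically when neither stdout nor stderr is a TTY, e.g. when both are redirected to files.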
@@ -655,6 +655,13 @@ struct common_file_info {
};
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories);

//
// TTY utils
//

// Auto-detect if colors can be enabled based on terminal and environment
bool tty_can_use_colors();

//
// Model utils
//
@@ -1,3 +1,4 @@
#include "common.h"
#include "log.h"

#include <chrono>
@@ -26,30 +27,6 @@ void common_log_set_verbosity_thold(int verbosity) {
    common_log_verbosity_thold = verbosity;
}

// Auto-detect if colors should be enabled based on terminal and environment
static bool common_log_should_use_colors_auto() {
    // Check NO_COLOR environment variable (https://no-color.org/)
    if (const char * no_color = std::getenv("NO_COLOR")) {
        if (no_color[0] != '\0') {
            return false;
        }
    }

    // Check TERM environment variable
    if (const char * term = std::getenv("TERM")) {
        if (std::strcmp(term, "dumb") == 0) {
            return false;
        }
    }

    // Check if stdout and stderr are connected to a terminal
    // We check both because log messages can go to either
    bool stdout_is_tty = isatty(fileno(stdout));
    bool stderr_is_tty = isatty(fileno(stderr));

    return stdout_is_tty || stderr_is_tty;
}

static int64_t t_us() {
    return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
}
@@ -391,7 +368,7 @@ struct common_log * common_log_main() {
    static std::once_flag init_flag;
    std::call_once(init_flag, [&]() {
        // Set default to auto-detect colors
        log.set_colors(common_log_should_use_colors_auto());
        log.set_colors(tty_can_use_colors());
    });

    return &log;
@@ -422,7 +399,7 @@ void common_log_set_file(struct common_log * log, const char * file) {

void common_log_set_colors(struct common_log * log, log_colors colors) {
    if (colors == LOG_COLORS_AUTO) {
        log->set_colors(common_log_should_use_colors_auto());
        log->set_colors(tty_can_use_colors());
        return;
    }
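A minimal sketch of the resulting behavior (assumed usage; `common_log_main`, `common_log_set_colors`, and `LOG_COLORS_AUTO` are the symbols shown in this diff, and the rest of the program is illustrative):

```cpp
#include "log.h"

int main() {
    // AUTO now delegates detection to the shared helper from common.cpp:
    // it honors NO_COLOR, rejects TERM=dumb, and requires a TTY on stdout or stderr.
    common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
    LOG_INF("colors were resolved via tty_can_use_colors()\n");
    return 0;
}
```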
258 docs/backend/ZenDNN.md (new file)
@@ -0,0 +1,258 @@
# llama.cpp for AMD ZenDNN

> [!WARNING]
> **Note:** ZenDNN is **not** the same as zDNN.
> - **ZenDNN** (this page): AMD's deep learning library for AMD EPYC CPUs
> - **zDNN**: IBM's Deep Neural Network acceleration library for IBM Z & LinuxONE Mainframes ([see zDNN documentation](zDNN.md))

- [Background](#background)
- [OS](#os)
- [Hardware](#hardware)
- [Supported Operations](#supported-operations)
- [Data Type Support](#data-type-support)
- [Linux](#linux)
- [Environment Variables](#environment-variables)
- [Performance Optimization](#performance-optimization)
- [Known Issues](#known-issues)
- [Q&A](#qa)
- [TODO](#todo)

## Background

**ZenDNN** (Zen Deep Neural Network Library) is AMD's high-performance deep learning inference library, optimized for AMD EPYC™ CPUs. It provides optimized implementations of key deep learning primitives and operations, delivering significant performance improvements for neural network workloads on AMD Zen-based processor architectures.

**Llama.cpp + ZenDNN**

The llama.cpp ZenDNN backend leverages AMD's optimized matrix multiplication primitives to accelerate inference on AMD CPUs. It uses ZenDNN's **LowOHA (Low Overhead Hardware Accelerated)** MatMul operator for efficient GEMM operations with minimal execution overhead, built-in weight caching, and direct access to backend libraries (AOCL BLIS, LibXSMM, OneDNN).

For more information about ZenDNN, visit: https://www.amd.com/en/developer/zendnn.html

## OS

| OS    | Status  | Verified                   |
|:-----:|:-------:|:--------------------------:|
| Linux | Support | Ubuntu 20.04, 22.04, 24.04 |

For the latest list of supported operating systems, see the [ZenDNN Supported OS](https://github.com/amd/ZenDNN/blob/zendnnl/README.md#15-supported-os).

## Hardware

### AMD CPUs

**Recommended Processors**

ZenDNN is optimized for AMD EPYC™ and AMD Ryzen™ processors based on the "Zen" microarchitecture and newer.

| CPU Family                     | Status  | Notes                              |
|:------------------------------:|:-------:|:----------------------------------:|
| AMD EPYC™ 9005 Series (Turin)  | Support | 5th Gen - Zen 5 architecture       |
| AMD EPYC™ 9004 Series (Genoa)  | Support | 4th Gen - Zen 4 architecture       |
| AMD EPYC™ 7003 Series (Milan)  | Support | 3rd Gen - Zen 3 architecture       |
| AMD Ryzen™ AI MAX (Strix Halo) | Support | High-performance mobile processors |

*Notes:*

- Best performance is achieved on AMD EPYC™ processors with high core counts (e.g., the EPYC 9005 series).
- ZenDNN leverages AMD's advanced CPU features, including the AVX2 and AVX-512 instruction sets.
- For optimal performance, ensure your system has sufficient memory bandwidth.

## Supported Operations

The ZenDNN backend currently accelerates **matrix multiplication (MUL_MAT)** operations only. Other operations are handled by the standard CPU backend.

| Operation | Status | Notes                                |
|:----------|:------:|:------------------------------------:|
| MUL_MAT   | ✓      | Accelerated via ZenDNN LowOHA MatMul |

*Note:* Since only MUL_MAT is accelerated, models benefit most from ZenDNN when matrix multiplications dominate the computational workload, which is typical for transformer-based LLMs.

## Data Type Support

| Data Type | Status  | Notes                                      |
|:---------:|:-------:|:------------------------------------------:|
| FP32      | Support | Full precision floating point              |
| BF16      | Support | BFloat16 (best performance on Zen 4/Zen 5) |

*Notes:*

- **BF16** provides the best performance on Zen 4 and Zen 5 EPYC™ processors (Genoa, Turin).

## Linux

### I. Setup Environment

You have two options to set up ZenDNN:

#### Option 1: Automatic Download and Build (Recommended)

CMake will automatically download and build ZenDNN for you:

```sh
# Build llama.cpp - ZenDNN will be automatically downloaded and built
cmake -B build -DGGML_ZENDNN=ON -DCMAKE_BUILD_TYPE=Release
cmake --build build --config Release -j $(nproc)
```

No manual ZenDNN installation is required; CMake handles everything automatically.

#### Option 2: Use a Custom ZenDNN Installation

If you want to build ZenDNN yourself or use a specific version:

**Step 1: Build ZenDNN from source**

```sh
# Clone the ZenDNN repository
git clone https://github.com/amd/ZenDNN.git
cd ZenDNN
git checkout zendnnl

# Build and install (requires CMake >= 3.25)
mkdir build && cd build
cmake ..
cmake --build . --target all
```

Default installation path: `ZenDNN/build/install`

**For detailed build instructions**, refer to the [ZenDNN README](https://github.com/amd/ZenDNN/blob/zendnnl/README.md).

**Step 2: Build llama.cpp with the custom ZenDNN path**

```sh
# Using an environment variable
export ZENDNN_ROOT=/path/to/ZenDNN/build/install
cmake -B build -DGGML_ZENDNN=ON -DCMAKE_BUILD_TYPE=Release
cmake --build build --config Release -j $(nproc)

# OR specify the path directly in CMake
cmake -B build -DGGML_ZENDNN=ON -DZENDNN_ROOT=/path/to/ZenDNN/build/install -DCMAKE_BUILD_TYPE=Release
cmake --build build --config Release -j $(nproc)
```

### II. Run the Server

#### 1. Download Model

Download the LLaMA 3.1 8B Instruct BF16 model:

```sh
# Download from Hugging Face
huggingface-cli download meta-llama/Llama-3.1-8B-Instruct-GGUF --local-dir models/
```

#### 2. Start Server

Run the llama.cpp server with ZenDNN acceleration:

```sh
# Set the optimal configuration
export OMP_NUM_THREADS=64    # Adjust to your CPU core count
export ZENDNNL_MATMUL_ALGO=2 # Blocked AOCL BLIS for best performance

# Start the server
./build/bin/llama-server \
    -m models/Llama-3.1-8B-Instruct.BF16.gguf \
    --host 0.0.0.0 \
    --port 8080 \
    -t 64
```

Access the server at `http://localhost:8080`.

**Performance tips**:
- Set `OMP_NUM_THREADS` to match your physical core count
- Use `ZENDNNL_MATMUL_ALGO=2` for optimal performance
- For NUMA systems: `numactl --cpunodebind=0 --membind=0 ./build/bin/llama-server ...`

## Environment Variables

### Build Time

| Name        | Value                       | Function                          |
|-------------|-----------------------------|-----------------------------------|
| GGML_ZENDNN | ON/OFF                      | Enable ZenDNN backend support     |
| ZENDNN_ROOT | Path to ZenDNN installation | Set ZenDNN installation directory |
| GGML_OPENMP | ON/OFF (recommended: ON)    | Enable OpenMP for multi-threading |

### Runtime

| Name                      | Value             | Function                                                         |
|---------------------------|-------------------|------------------------------------------------------------------|
| OMP_NUM_THREADS           | Number (e.g., 64) | Set number of OpenMP threads (recommended: physical core count)   |
| ZENDNNL_MATMUL_ALGO       | 0-5               | Select MatMul backend algorithm (see Performance Optimization)    |
| ZENDNNL_PROFILE_LOG_LEVEL | 0-4               | Profiling log level (0 = disabled, 4 = verbose)                   |
| ZENDNNL_ENABLE_PROFILER   | 0 or 1            | Enable detailed profiling (1 = enabled)                           |
| ZENDNNL_API_LOG_LEVEL     | 0-4               | API log level (0 = disabled, 4 = verbose)                         |

**Example**:

```sh
export OMP_NUM_THREADS=64
export ZENDNNL_MATMUL_ALGO=2 # Use Blocked AOCL BLIS for best performance
./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "Test" -n 100
```

## Performance Optimization

### MatMul Algorithm Selection

ZenDNN's LowOHA MatMul supports multiple backend algorithms. For **best performance**, use the **Blocked AOCL BLIS** algorithm:

```sh
export ZENDNNL_MATMUL_ALGO=2 # Blocked AOCL BLIS (recommended)
```

**Available algorithms**:

| Value | Algorithm         | Description                            |
|:-----:|:------------------|:---------------------------------------|
| 0     | Dynamic Dispatch  | Automatic backend selection (default)  |
| 1     | AOCL BLIS         | AOCL BLIS backend                      |
| 2     | AOCL BLIS Blocked | **Blocked AOCL BLIS (recommended)**    |
| 3     | OneDNN            | OneDNN backend                         |
| 4     | OneDNN Blocked    | Blocked OneDNN backend                 |
| 5     | LibXSMM           | LibXSMM backend                        |

### Profiling and Debugging

For detailed profiling and logging options, refer to the [ZenDNN Logging Documentation](https://github.com/amd/ZenDNN/blob/zendnnl/docs/logging.md).

## Known Issues

- **Limited operation support**: Currently only matrix multiplication (MUL_MAT) is accelerated via ZenDNN. Other operations fall back to the standard CPU backend.
- **BF16 support**: BF16 operations require the AMD Zen 4 or Zen 5 architecture (EPYC 9004/9005 series). On older CPUs, operations will use FP32.
- **NUMA awareness**: On multi-socket systems, manual NUMA binding may be required for optimal performance.

## Q&A

**Q: How do I verify that the ZenDNN backend is being used?**

A: Check the log output when running llama.cpp. You should see messages indicating that the ZenDNN backend has been initialized. You can also check the backend name in the output.

**Q: What performance improvement can I expect?**

A: Performance gains vary with model size, batch size, and CPU architecture. On AMD EPYC processors, you can typically expect a 1.1x-2x speedup over standard CPU inference for matrix multiplication operations.

**Q: Can I use ZenDNN on non-AMD processors?**

A: ZenDNN is optimized specifically for AMD processors. While it may work on other x86-64 CPUs, performance benefits are only guaranteed on AMD Zen-based architectures.

**Q: Does ZenDNN support quantized models?**

A: Currently, ZenDNN primarily supports the FP32 and BF16 data types. Quantized model support is not available at this time.

**Q: Why is my inference not faster with ZenDNN?**

A: Check that:
1. You're using an AMD EPYC or Ryzen processor (Zen 2 or newer)
2. `OMP_NUM_THREADS` is set appropriately (physical core count)
3. `ZENDNNL_MATMUL_ALGO=2` is set for best performance (Blocked AOCL BLIS)
4. You're using a sufficiently large model (small models may not benefit as much)
5. Profiling is enabled, so you can verify that the ZenDNN MatMul is actually being called

### GitHub Contribution

Please add the **[ZenDNN]** prefix/tag to issue and PR titles to help the ZenDNN team triage and address them without delay.

## TODO

- Expand operation support beyond MUL_MAT (attention operations, activations, etc.)
@@ -1,5 +1,10 @@
# llama.cpp for IBM zDNN Accelerator

> [!WARNING]
> **Note:** zDNN is **not** the same as ZenDNN.
> - **zDNN** (this page): IBM's Deep Neural Network acceleration library for IBM Z & LinuxONE Mainframes
> - **ZenDNN**: AMD's deep learning library for AMD EPYC CPUs ([see ZenDNN documentation](ZenDNN.md))

## Background

IBM zDNN (Z Deep Neural Network) is a hardware acceleration library designed specifically to leverage the IBM NNPA (Neural Network Processor Assist) accelerator located within IBM Telum I and II processors. It provides significant performance improvements for neural network inference operations.
@@ -495,6 +495,38 @@ llama_new_context_with_model: CANN compute buffer size = 1260.81 MiB

For detailed information, such as model/device support and CANN installation, please refer to [llama.cpp for CANN](./backend/CANN.md).

## ZenDNN

ZenDNN provides optimized deep learning primitives for AMD EPYC™ CPUs. It accelerates matrix multiplication operations for inference workloads.

### Compilation

- Using `CMake` on Linux (automatic build):

  ```bash
  cmake -B build -DGGML_ZENDNN=ON
  cmake --build build --config Release
  ```

  The first build will automatically download and build ZenDNN, which may take 5-10 minutes. Subsequent builds will be much faster.

- Using `CMake` with a custom ZenDNN installation:

  ```bash
  cmake -B build -DGGML_ZENDNN=ON -DZENDNN_ROOT=/path/to/zendnn/install
  cmake --build build --config Release
  ```

### Testing

You can test with:

```bash
./build/bin/llama-cli -m PATH_TO_MODEL -p "Building a website can be done in 10 steps:" -n 50
```

For detailed information about hardware support, setup instructions, and performance optimization, refer to [llama.cpp for ZenDNN](./backend/ZenDNN.md).

## Arm® KleidiAI™

KleidiAI is a library of optimized microkernels for AI workloads, specifically designed for Arm CPUs. These microkernels enhance performance and can be enabled for use by the CPU backend.
216 docs/ops.md
@@ -12,111 +12,111 @@ Legend:
- 🟡 Partially supported by this backend
- ❌ Not supported by this backend

| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | WebGPU | zDNN |
|-----------|------|------|------|------|------|------|------|------|------|------|
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ✅ | ❌ |
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ✅ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ |
| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ✅ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ |
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ |
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| PAD | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ |
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ✅ | ❌ |
| SUM | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| TOP_K | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ |
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | WebGPU | ZenDNN | zDNN |
|-----------|------|------|------|------|------|------|------|------|------|------|------|
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | ❌ |
| PAD | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ | ❌ | ❌ |
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | ❌ |
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
| SUM | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| TOP_K | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
18741 docs/ops/ZenDNN.csv (new file)
File diff suppressed because it is too large
@@ -253,6 +253,9 @@ option(GGML_HEXAGON "ggml: enable Hexagon backend"
# toolchain for vulkan-shaders-gen
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")

option(GGML_ZENDNN "ggml: use ZenDNN" OFF)
option(ZENDNN_ROOT "ggml: path to ZenDNN installation" "")

# extra artifacts
option(GGML_BUILD_TESTS    "ggml: build tests"    ${GGML_STANDALONE})
option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE})
@@ -314,6 +317,7 @@ set(GGML_PUBLIC_HEADERS
    include/ggml-sycl.h
    include/ggml-vulkan.h
    include/ggml-webgpu.h
    include/ggml-zendnn.h
    include/gguf.h)

set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
22 ggml/include/ggml-zendnn.h (new file)
@@ -0,0 +1,22 @@
#pragma once

#include "ggml-backend.h"
#include "ggml.h"

#ifdef __cplusplus
extern "C" {
#endif

// backend API
GGML_BACKEND_API ggml_backend_t ggml_backend_zendnn_init(void);

GGML_BACKEND_API bool ggml_backend_is_zendnn(ggml_backend_t backend);

// number of threads used for zendnn operations
GGML_BACKEND_API void ggml_backend_zendnn_set_n_threads(ggml_backend_t backend_zendnn, int n_threads);

GGML_BACKEND_API ggml_backend_reg_t ggml_backend_zendnn_reg(void);

#ifdef __cplusplus
}
#endif
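A minimal usage sketch for this new header (an illustration under assumptions, not code from the PR: it presumes a build with `GGML_ZENDNN=ON`, and the graph-compute and cleanup calls come from the generic ggml-backend API):

```cpp
#include "ggml-zendnn.h"

#include <cstdio>

int main() {
    ggml_backend_t backend = ggml_backend_zendnn_init();
    if (backend == NULL || !ggml_backend_is_zendnn(backend)) {
        fprintf(stderr, "ZenDNN backend unavailable\n");
        return 1;
    }

    // per the ZenDNN docs above, match the physical core count
    ggml_backend_zendnn_set_n_threads(backend, 64);

    // ... build a ggml graph and run it with ggml_backend_graph_compute(backend, graph) ...

    ggml_backend_free(backend);
    return 0;
}
```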
@@ -2196,6 +2196,15 @@ extern "C" {
            int                   p2,
            int                   p3);

    // pad each dimension with values on the other side of the torus (looping around)
    GGML_API struct ggml_tensor * ggml_pad_circular(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   p0,
            int                   p1,
            int                   p2,
            int                   p3);

    GGML_API struct ggml_tensor * ggml_pad_ext(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,

@@ -2209,6 +2218,19 @@ extern "C" {
            int                   rp3
            );

    // pad each dimension with values on the other side of the torus (looping around)
    GGML_API struct ggml_tensor * ggml_pad_ext_circular(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   lp0,
            int                   rp0,
            int                   lp1,
            int                   rp1,
            int                   lp2,
            int                   rp2,
            int                   lp3,
            int                   rp3);

    // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
    GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
            struct ggml_context * ctx,
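A small sketch of the new circular-padding call (assumptions: the tensor sizes are illustrative, the surrounding context/graph calls are the standard ggml API, and `ggml_pad_circular` is assumed to pad the upper end of each dimension like `ggml_pad`):

```cpp
#include "ggml.h"

#include <stdio.h>

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // a 4-element row: [0, 1, 2, 3]
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    for (int i = 0; i < 4; ++i) {
        ggml_set_f32_1d(a, i, (float) i);
    }

    // pad dim 0 by one element; with circular padding the new slot wraps
    // around to the start of the row, so the expected result is [0, 1, 2, 3, 0]
    struct ggml_tensor * p = ggml_pad_circular(ctx, a, 1, 0, 0, 0);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, p);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    for (int i = 0; i < 5; ++i) {
        printf("%g ", ggml_get_f32_1d(p, i));
    }
    printf("\n");

    ggml_free(ctx);
    return 0;
}
```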
@@ -440,6 +440,7 @@ ggml_add_backend(WebGPU)
ggml_add_backend(zDNN)
ggml_add_backend(OpenCL)
ggml_add_backend(Hexagon)
ggml_add_backend(ZenDNN)

foreach (target ggml-base ggml)
    target_include_directories(${target} PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include> $<INSTALL_INTERFACE:include>)
@@ -73,6 +73,10 @@
#include "ggml-cann.h"
#endif

#ifdef GGML_USE_ZENDNN
#include "ggml-zendnn.h"
#endif

// disable C++17 deprecation warning for std::codecvt_utf8
#if defined(__clang__)
#    pragma clang diagnostic push

@@ -203,6 +207,9 @@ struct ggml_backend_registry {
#ifdef GGML_USE_OPENCL
        register_backend(ggml_backend_opencl_reg());
#endif
#ifdef GGML_USE_ZENDNN
        register_backend(ggml_backend_zendnn_reg());
#endif
#ifdef GGML_USE_HEXAGON
        register_backend(ggml_backend_hexagon_reg());
#endif

@@ -534,8 +541,12 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
    fs::path best_path;

    for (const auto & search_path : search_paths) {
        if (!fs::exists(search_path)) {
            GGML_LOG_DEBUG("%s: search path %s does not exist\n", __func__, path_str(search_path).c_str());
        if (std::error_code ec; !fs::exists(search_path, ec)) {
            if (ec) {
                GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(search_path).c_str(), ec.message().c_str());
            } else {
                GGML_LOG_DEBUG("%s: search path %s does not exist\n", __func__, path_str(search_path).c_str());
            }
            continue;
        }
        fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);

@@ -575,8 +586,12 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
    for (const auto & search_path : search_paths) {
        fs::path filename = backend_filename_prefix().native() + name_path.native() + backend_filename_extension().native();
        fs::path path = search_path / filename;
        if (fs::exists(path)) {
        if (std::error_code ec; fs::exists(path, ec)) {
            return get_reg().load_backend(path, silent);
        } else {
            if (ec) {
                GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(path).c_str(), ec.message().c_str());
            }
        }
    }
    return nullptr;

@@ -597,6 +612,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
#endif

    ggml_backend_load_best("blas", silent, dir_path);
    ggml_backend_load_best("zendnn", silent, dir_path);
    ggml_backend_load_best("cann", silent, dir_path);
    ggml_backend_load_best("cuda", silent, dir_path);
    ggml_backend_load_best("hip", silent, dir_path);
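The pattern adopted in these two hunks, sketched in isolation (plain C++17, nothing llama.cpp-specific; the path is illustrative): the `std::error_code` overload of `fs::exists` reports failures such as permission errors through `ec` instead of throwing, so a bad search path can be logged and skipped.

```cpp
#include <cstdio>
#include <filesystem>
#include <system_error>

namespace fs = std::filesystem;

int main() {
    const fs::path p = "/some/search/path"; // illustrative
    // C++17 init-statement: ec is scoped to the if, and fs::exists(p, ec) never throws
    if (std::error_code ec; !fs::exists(p, ec)) {
        if (ec) {
            std::printf("stat failed: %s\n", ec.message().c_str()); // e.g. permission denied
        } else {
            std::printf("path does not exist\n");
        }
    }
    return 0;
}
```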
@@ -2551,6 +2551,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
        case GGML_OP_ACC:
        case GGML_OP_GROUP_NORM:
        case GGML_OP_PAD:
            // TODO: add circular padding support for cann, see https://github.com/ggml-org/llama.cpp/pull/16985
            return ggml_get_op_params_i32(op, 8) == 0;
        case GGML_OP_ARANGE:
        case GGML_OP_TIMESTEP_EMBEDDING:
        case GGML_OP_LEAKY_RELU:
@@ -6554,8 +6554,13 @@ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params
    ggml_compute_forward_mul_mat(params, &dst);
}

static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
    return (coord + size) % size; // adding size avoids negative number weirdness
}

// ggml_compute_forward_conv_2d

static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
                                              const ggml_tensor * kernel, // [KW, KH, IC, OC]
                                              const ggml_tensor * src,    // [W, H, C, N]
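A quick numeric check of the wrap-around helper above (a standalone sketch; the values are chosen for illustration, and like the helper it assumes `coord > -size`, which holds for the pad offsets used here):

```cpp
#include <cassert>
#include <cstdint>

// same definition as the helper introduced in this hunk
static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
    return (coord + size) % size; // adding size avoids negative number weirdness
}

int main() {
    assert(ggml_wrap_around(-1, 5) == 4); // (-1 + 5) % 5 = 4: a left pad reads the last element
    assert(ggml_wrap_around( 5, 5) == 0); // ( 5 + 5) % 5 = 0: a right pad wraps to the first element
    assert(ggml_wrap_around( 2, 5) == 2); // in-range coordinates pass through unchanged
    return 0;
}
```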
@@ -7591,6 +7596,7 @@ void ggml_compute_forward_upscale(

// ggml_compute_forward_pad

template<bool circular_t>
static void ggml_compute_forward_pad_f32(
    const ggml_compute_params * params,
          ggml_tensor * dst) {
@@ -7615,23 +7621,40 @@ static void ggml_compute_forward_pad_f32(
    const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
    const int32_t rp3 = ggml_get_op_params_i32(dst, 7);

    // TODO: optimize

    for (int64_t i2 = 0; i2 < ne2; ++i2) {
        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
                for (int64_t i3 = 0; i3 < ne3; ++i3) {
                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
                    if ((i0 >= lp0 && i0 < ne0 - rp0) \
                        && (i1 >= lp1 && i1 < ne1 - rp1) \
                        && (i2 >= lp2 && i2 < ne2 - rp2) \
                        && (i3 >= lp3 && i3 < ne3 - rp3)) {
                        const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
                    // circular means wrap around on a torus, so x and y loop around
                    if constexpr (circular_t) {
                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
                        const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);
                        const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);
                        const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);
                        const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);

                        const int64_t src_idx =
                            src_i3*nb03 +
                            src_i2*nb02 +
                            src_i1*nb01 +
                            src_i0*nb00;

                        const float * src_ptr = (const float *)((char *) src0->data + src_idx);
                        dst_ptr[dst_idx] = *src_ptr;
                    } else {
                        dst_ptr[dst_idx] = 0;
                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
                        if ((i0 >= lp0 && i0 < ne0 - rp0) \
                            && (i1 >= lp1 && i1 < ne1 - rp1) \
                            && (i2 >= lp2 && i2 < ne2 - rp2) \
                            && (i3 >= lp3 && i3 < ne3 - rp3)) {
                            const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
                            const float * src_ptr = (const float *)((char *) src0->data + src_idx);
                            dst_ptr[dst_idx] = *src_ptr;
                        } else {
                            dst_ptr[dst_idx] = 0;
                        }
                    }
                }
            }
        }
@@ -7639,16 +7662,20 @@ static void ggml_compute_forward_pad_f32(
    }
}

void ggml_compute_forward_pad(
    const ggml_compute_params * params,
          ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    const bool circular = (bool) ggml_get_op_params_i32(dst, 8);
    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_pad_f32(params, dst);
                if (circular) {
                    ggml_compute_forward_pad_f32<true>(params, dst);
                } else {
                    ggml_compute_forward_pad_f32<false>(params, dst);
                }
            } break;
        default:
            {
@@ -160,9 +160,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
        case GGML_TYPE_F32:
            return ampere_mma_available(cc);
        case GGML_TYPE_F16:
            return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
            return volta_mma_available(cc) || turing_mma_available(cc) || (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc));
        case GGML_TYPE_BF16:
            return ampere_mma_available(cc) || amd_wmma_available(cc);
            return ampere_mma_available(cc) || (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc));
        default:
            return false;
    }
@@ -1,9 +1,17 @@
|
||||
#include "pad.cuh"
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
__device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) {
|
||||
// + size ensures negatives are handled properly
|
||||
return (coord + size) % size;
|
||||
}
|
||||
|
||||
static __global__ void pad_f32(const float * src, float * dst,
|
||||
const int lp0, const int rp0, const int lp1, const int rp1,
|
||||
const int lp2, const int rp2, const int lp3, const int rp3,
|
||||
const int ne0, const int ne1, const int ne2, const int ne3) {
|
||||
const int ne0, const int ne1, const int ne2, const int ne3,
|
||||
const bool circular) {
|
||||
// blockIdx.z: i3*ne2+i2
|
||||
// blockIdx.y: i1
|
||||
// blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE
|
||||
@@ -12,61 +20,84 @@ static __global__ void pad_f32(const float * src, float * dst,
|
||||
int i1 = blockIdx.y;
|
||||
int i2 = blockIdx.z % ne2;
|
||||
int i3 = blockIdx.z / ne2;
|
||||
|
||||
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
// operation
|
||||
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
||||
if ((i0 >= lp0 && i0 < ne0 - rp0) &&
|
||||
(i1 >= lp1 && i1 < ne1 - rp1) &&
|
||||
(i2 >= lp2 && i2 < ne2 - rp2) &&
|
||||
(i3 >= lp3 && i3 < ne3 - rp3)) {
|
||||
const int64_t i00 = i0 - lp0;
|
||||
const int64_t i01 = i1 - lp1;
|
||||
const int64_t i02 = i2 - lp2;
|
||||
const int64_t i03 = i3 - lp3;
|
||||
const int64_t ne02 = ne2 - lp2 - rp2;
|
||||
const int64_t ne01 = ne1 - lp1 - rp1;
|
||||
const int64_t ne00 = ne0 - lp0 - rp0;
|
||||
const int64_t dst_idx = i3 * (ne0 * ne1 * ne2) + i2 * (ne0 * ne1) + i1 * ne0 + i0;
|
||||
|
||||
const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00;
|
||||
if (!circular) {
|
||||
if ((i0 >= lp0 && i0 < ne0 - rp0) && (i1 >= lp1 && i1 < ne1 - rp1) && (i2 >= lp2 && i2 < ne2 - rp2) &&
|
||||
(i3 >= lp3 && i3 < ne3 - rp3)) {
|
||||
const int64_t i00 = i0 - lp0;
|
||||
const int64_t i01 = i1 - lp1;
|
||||
const int64_t i02 = i2 - lp2;
|
||||
const int64_t i03 = i3 - lp3;
|
||||
const int64_t ne02 = ne2 - lp2 - rp2;
|
||||
const int64_t ne01 = ne1 - lp1 - rp1;
|
||||
const int64_t ne00 = ne0 - lp0 - rp0;
|
||||
|
||||
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
|
||||
|
||||
dst[dst_idx] = src[src_idx];
|
||||
} else {
|
||||
dst[dst_idx] = 0.0f;
|
||||
}
|
||||
}
|
||||
// circular means on a torus, so x and y wrap around
|
||||
else {
|
||||
const int64_t ne00 = ne0 - lp0 - rp0;
|
||||
const int64_t ne01 = ne1 - lp1 - rp1;
|
||||
const int64_t ne02 = ne2 - lp2 - rp2;
|
||||
const int64_t ne03 = ne3 - lp3 - rp3;
|
||||
|
||||
const int64_t i00 = wrap_around(i0 - lp0, ne00);
|
||||
const int64_t i01 = wrap_around(i1 - lp1, ne01);
|
||||
const int64_t i02 = wrap_around(i2 - lp2, ne02);
|
||||
const int64_t i03 = wrap_around(i3 - lp3, ne03);
|
||||
|
||||
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
|
||||
|
||||
dst[dst_idx] = src[src_idx];
|
||||
} else {
|
||||
dst[dst_idx] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
|
static void pad_f32_cuda(const float * src, float * dst,
const int lp0, const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3, const int rp3,
const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne1, ne2*ne3);
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3);
const int ne0, const int ne1, const int ne2, const int ne3,
const bool circular, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne1, ne2 * ne3);
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst,
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
ne0, ne1, ne2, ne3, circular);
}

void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));

const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
const int32_t lp0 = ((const int32_t *) (dst->op_params))[0];
const int32_t rp0 = ((const int32_t *) (dst->op_params))[1];
const int32_t lp1 = ((const int32_t *) (dst->op_params))[2];
const int32_t rp1 = ((const int32_t *) (dst->op_params))[3];
const int32_t lp2 = ((const int32_t *) (dst->op_params))[4];
const int32_t rp2 = ((const int32_t *) (dst->op_params))[5];
const int32_t lp3 = ((const int32_t *) (dst->op_params))[6];
const int32_t rp3 = ((const int32_t *) (dst->op_params))[7];
const int32_t circular = ((const int32_t *) (dst->op_params))[8];

pad_f32_cuda(src0_d, dst_d,
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
(bool) circular, stream);
}

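For reference, this is the GGML_OP_PAD op_params layout the backend code above reads: slots 0..7 are the per-dimension left/right paddings and slot 8 is the new circular flag. The struct and field names below are illustrative, not part of the ggml API:

#include <cstdint>

// Illustrative view of dst->op_params for GGML_OP_PAD (nine int32 slots).
struct pad_params {
    int32_t lp0, rp0; // dim 0: left/right padding
    int32_t lp1, rp1; // dim 1
    int32_t lp2, rp2; // dim 2
    int32_t lp3, rp3; // dim 3
    int32_t circular; // 0 = zero padding, 1 = wrap around
};

static pad_params read_pad_params(const int32_t * op_params) {
    pad_params p;
    p.lp0 = op_params[0]; p.rp0 = op_params[1];
    p.lp1 = op_params[2]; p.rp1 = op_params[3];
    p.lp2 = op_params[4]; p.rp2 = op_params[5];
    p.lp3 = op_params[6]; p.rp3 = op_params[7];
    p.circular = op_params[8];
    return p;
}
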
@@ -1037,6 +1037,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_POOL_2D:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_PAD:
// TODO: add circular padding support for metal, see https://github.com/ggml-org/llama.cpp/pull/16985
if (ggml_get_op_params_i32(op, 8) != 0) {
return false;
}

return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
case GGML_OP_PAD_REFLECT_1D:

@@ -3083,6 +3083,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
case GGML_OP_REPEAT:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
case GGML_OP_PAD:
// TODO: add circular padding support for opencl, see https://github.com/ggml-org/llama.cpp/pull/16985
if (ggml_get_op_params_i32(op, 8) != 0) {
return false;
}
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_UPSCALE: {
ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF);

@@ -1257,7 +1257,8 @@ bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
char hash_str[17];
snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
fs::path cache_file = fs::path(cache_dir) / hash_str;
if (!fs::exists(cache_file)) {
std::error_code ec;
if (!fs::exists(cache_file, ec)) {
return false;
}
std::ifstream ifs(cache_file, std::ios::binary);

@@ -2,6 +2,13 @@
#include "dequantize.hpp"
#include "presets.hpp"

#if defined(__INTEL_LLVM_COMPILER)
#if __has_include(<sycl/ext/oneapi/bfloat16.hpp>)
#include <sycl/ext/oneapi/bfloat16.hpp>
#define GGML_SYCL_HAS_BF16
#endif
#endif

template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
const sycl::nd_item<3> &item_ct1) {
@@ -566,6 +573,10 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_F32:
return convert_unary_sycl<float>;
#ifdef GGML_SYCL_HAS_BF16
case GGML_TYPE_BF16:
return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
#endif
default:
return nullptr;
}
@@ -627,6 +638,10 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_F16:
return convert_unary_sycl<sycl::half>;
#ifdef GGML_SYCL_HAS_BF16
case GGML_TYPE_BF16:
return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
#endif
default:
return nullptr;
}
@@ -636,6 +651,10 @@ to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
return convert_unary_nc_sycl<float>;
#ifdef GGML_SYCL_HAS_BF16
case GGML_TYPE_BF16:
return convert_unary_nc_sycl<sycl::ext::oneapi::bfloat16>;
#endif
default:
return nullptr;
}

@@ -4613,6 +4613,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_ACC:
return true;
case GGML_OP_PAD:
// TODO: add circular padding support for sycl, see https://github.com/ggml-org/llama.cpp/pull/16985
if (ggml_get_op_params_i32(op, 8) != 0) {
return false;
}
return ggml_is_contiguous(op->src[0]);
case GGML_OP_LEAKY_RELU:
case GGML_OP_TIMESTEP_EMBEDDING:

@@ -777,11 +777,6 @@ struct vk_device_struct {
std::unique_ptr<vk_memory_logger> memory_logger;
#endif

// for GGML_VK_PERF_LOGGER
std::unique_ptr<vk_perf_logger> perf_logger;
vk::QueryPool query_pool;
int32_t num_queries;

~vk_device_struct() {
VK_LOG_DEBUG("destroy device " << name);

@@ -1050,6 +1045,7 @@ struct vk_op_pad_push_constants {
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
uint32_t misalign_offsets;
uint32_t circular;

uint32_t lp0; uint32_t rp0;
uint32_t lp1; uint32_t rp1;
@@ -1092,6 +1088,7 @@ static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor
p.rp2 = dst->op_params[5];
p.lp3 = dst->op_params[6];
p.rp3 = dst->op_params[7];
p.circular = dst->op_params[8];

return p; // fastdiv values and offsets are initialized later in ggml_vk_op
}
@@ -1521,12 +1518,21 @@ private:
#define VK_LOG_MEMORY(msg) ((void) 0)
#endif // GGML_VULKAN_MEMORY_DEBUG

static bool vk_perf_logger_enabled = false;
// number of calls between perf logger prints
static uint32_t vk_perf_logger_frequency = 1;

class vk_perf_logger {
public:
void print_timings() {
void print_timings(bool force = false) {
if (timings.empty()) {
return;
}
print_count++;
if ((print_count % vk_perf_logger_frequency) != 0 && !force) {
return;
}
print_count = 0;
uint64_t total_all_op_times = 0;
std::cerr << "----------------\nVulkan Timings:" << std::endl;
for (const auto & t : timings) {
@@ -1563,16 +1569,20 @@ class vk_perf_logger {
flops.clear();
}

void log_timing(const ggml_tensor * node, uint64_t time) {
void log_timing(const ggml_tensor * node, const char *fusion_name, uint64_t time) {
std::string fusion_str;
if (fusion_name) {
fusion_str = fusion_name + std::string(" ");
}
if (node->op == GGML_OP_UNARY) {
timings[ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time);
timings[fusion_str + ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time);
return;
}
if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
const uint64_t m = node->src[0]->ne[1];
const uint64_t n = (node->op == GGML_OP_MUL_MAT) ? node->ne[1] : node->ne[2];
const uint64_t m = node->ne[0];
const uint64_t n = node->ne[1];
const uint64_t k = node->src[1]->ne[0];
const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3];
const uint64_t batch = node->ne[2] * node->ne[3];
std::string name = ggml_op_name(node->op);
if ((node->op == GGML_OP_MUL_MAT && n <= mul_mat_vec_max_cols) ||
(node->op == GGML_OP_MUL_MAT_ID && node->src[2]->ne[1] == 1)) {
@@ -1581,9 +1591,13 @@ class vk_perf_logger {
name += " ";
name += ggml_type_name(node->src[0]->type);
name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
if (node->op == GGML_OP_MUL_MAT_ID) {
name += " n_expert=" + std::to_string(node->src[0]->ne[2]);
}
if (batch > 1) {
name += " batch=" + std::to_string(batch);
}
name = fusion_str + name;
timings[name].push_back(time);
flops[name].push_back(m * n * (k + (k - 1)) * batch);
return;
@@ -1605,6 +1619,7 @@ class vk_perf_logger {
uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1));
name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
", N=N*OW*OH=" + std::to_string(size_N);
name = fusion_str + name;
flops[name].push_back(n_flops);
timings[name].push_back(time);
return;
@@ -1612,6 +1627,7 @@ class vk_perf_logger {
if (node->op == GGML_OP_RMS_NORM) {
std::string name = ggml_op_name(node->op);
name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
name = fusion_str + name;
timings[name].push_back(time);
return;
}
@@ -1622,6 +1638,7 @@ class vk_perf_logger {
const ggml_tensor * v = node->src[2];
const ggml_tensor * m = node->src[3];
std::stringstream name;
name << fusion_str;
name << ggml_op_name(node->op) <<
" dst(" << dst->ne[0] << "," << dst->ne[1] << "," << dst->ne[2] << "," << dst->ne[3] << "), " <<
" q(" << q->ne[0] << "," << q->ne[1] << "," << q->ne[2] << "," << q->ne[3] << "), " <<
@@ -1633,17 +1650,19 @@ class vk_perf_logger {
}
if (node->op == GGML_OP_TOP_K) {
std::stringstream name;
name << fusion_str;
name << ggml_op_name(node->op) <<
" K=" << node->ne[0] <<
" (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")";
timings[name.str()].push_back(time);
return;
}
timings[ggml_op_name(node->op)].push_back(time);
timings[fusion_str + ggml_op_name(node->op)].push_back(time);
}
private:
std::map<std::string, std::vector<uint64_t>> timings;
std::map<std::string, std::vector<uint64_t>> flops;
uint32_t print_count {};
};

struct ggml_backend_vk_context {
@@ -1697,6 +1716,14 @@ struct ggml_backend_vk_context {
// Bit 'i' means nodes[start_of_fusion + i] writes to memory.
// If there's no fusion, bit 0 is still set.
int fused_ops_write_mask {};

// for GGML_VK_PERF_LOGGER
std::unique_ptr<vk_perf_logger> perf_logger;
vk::QueryPool query_pool;
std::vector<const char *> query_fusion_names;
std::vector<ggml_tensor *> query_nodes;
int32_t num_queries {};
int32_t query_idx {};
};

static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
@@ -1822,8 +1849,6 @@ struct vk_instance_t {
static bool vk_instance_initialized = false;
static vk_instance_t vk_instance;

static bool vk_perf_logger_enabled = false;

#ifdef GGML_VULKAN_CHECK_RESULTS
static size_t vk_skip_checks;
static size_t vk_output_tensor;
@@ -4203,9 +4228,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
#ifdef GGML_VULKAN_MEMORY_DEBUG
device->memory_logger = std::unique_ptr<vk_memory_logger>(new vk_memory_logger());
#endif
if (vk_perf_logger_enabled) {
device->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
}

size_t dev_num = vk_instance.device_indices[idx];

@@ -5151,6 +5173,11 @@ static void ggml_vk_instance_init() {
}

vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY");

if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) {
vk_perf_logger_frequency = std::stoul(GGML_VK_PERF_LOGGER_FREQUENCY);
}

// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
VULKAN_HPP_DEFAULT_DISPATCHER.init(vk_instance.instance);
@@ -5328,6 +5355,10 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue);
ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue);

if (vk_perf_logger_enabled) {
ctx->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
}

#ifdef GGML_VULKAN_CHECK_RESULTS
const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS");
vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks));
@@ -12203,6 +12234,9 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {

ctx->compute_cmd_pool.destroy(ctx->device->device);
ctx->transfer_cmd_pool.destroy(ctx->device->device);
if (vk_perf_logger_enabled) {
ctx->perf_logger->print_timings(true);
}
}

static int ggml_vk_get_device_count() {
@@ -13001,24 +13035,29 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
vk_context compute_ctx;
if (vk_perf_logger_enabled) {
// allocate/resize the query pool
if (ctx->device->num_queries < cgraph->n_nodes + 1) {
if (ctx->device->query_pool) {
ctx->device->device.destroyQueryPool(ctx->device->query_pool);
if (ctx->num_queries < cgraph->n_nodes + 1) {
if (ctx->query_pool) {
ctx->device->device.destroyQueryPool(ctx->query_pool);
}
vk::QueryPoolCreateInfo query_create_info;
query_create_info.queryType = vk::QueryType::eTimestamp;
query_create_info.queryCount = cgraph->n_nodes + 100;
ctx->device->query_pool = ctx->device->device.createQueryPool(query_create_info);
ctx->device->num_queries = query_create_info.queryCount;
ctx->query_pool = ctx->device->device.createQueryPool(query_create_info);
ctx->num_queries = query_create_info.queryCount;
ctx->query_fusion_names.resize(ctx->num_queries);
ctx->query_nodes.resize(ctx->num_queries);
}

ctx->device->device.resetQueryPool(ctx->device->query_pool, 0, cgraph->n_nodes+1);
ctx->device->device.resetQueryPool(ctx->query_pool, 0, cgraph->n_nodes+1);
std::fill(ctx->query_fusion_names.begin(), ctx->query_fusion_names.end(), nullptr);
std::fill(ctx->query_nodes.begin(), ctx->query_nodes.end(), nullptr);

GGML_ASSERT(ctx->compute_ctx.expired());
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->compute_ctx = compute_ctx;
ggml_vk_ctx_begin(ctx->device, compute_ctx);
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
ctx->query_idx = 0;
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
}

ctx->prealloc_y_last_pipeline_used = nullptr;
@@ -13059,52 +13098,66 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
total_mul_mat_bytes += bytes;
}

const char *fusion_string {};
if (!ctx->device->disable_fusion) {
uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
if (num_adds) {
ctx->num_additional_fused_ops = num_adds - 1;
fusion_string = "MULTI_ADD";
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) {
ctx->num_additional_fused_ops = 2;
fusion_string = "MUL_MAT_ADD_ADD";
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
ctx->num_additional_fused_ops = 1;
fusion_string = "MUL_MAT_ADD";
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 2;
fusion_string = "MUL_MAT_ID_ADD_ID_MUL";
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
ctx->num_additional_fused_ops = 1;
fusion_string = "MUL_MAT_ID_ADD_ID";
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
fusion_string = "MUL_MAT_ID_MUL";
} else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) {
ctx->num_additional_fused_ops = 4;
fusion_string = "RMS_NORM_MUL_ROPE_VIEW_SET_ROWS";
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&&
ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) {
ctx->num_additional_fused_ops = 2;
fusion_string = "RMS_NORM_MUL_ROPE";
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
fusion_string = "RMS_NORM_MUL";
} else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
ctx->num_additional_fused_ops = 2;
fusion_string = "ROPE_VIEW_SET_ROWS";
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
// view of argsort writes to memory
ctx->fused_ops_write_mask |= 1 << 3;
fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
// view of argsort writes to memory
ctx->fused_ops_write_mask |= 1 << 3;
fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
// view of argsort writes to memory
ctx->fused_ops_write_mask |= 1 << 1;
fusion_string = "TOPK_MOE_LATE_SOFTMAX";
}
}
ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
@@ -13118,7 +13171,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg

bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);

if (vk_perf_logger_enabled) {
if (vk_perf_logger_enabled && enqueued) {
if (ctx->compute_ctx.expired()) {
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->compute_ctx = compute_ctx;
@@ -13126,10 +13179,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
} else {
compute_ctx = ctx->compute_ctx.lock();
}
// If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
}
ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];
ctx->query_fusion_names[ctx->query_idx] = fusion_string;
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
}

if (enqueued) {
@@ -13170,14 +13222,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg

// Get the results and pass them to the logger
std::vector<uint64_t> timestamps(cgraph->n_nodes + 1);
VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results");
for (int i = 0; i < cgraph->n_nodes; i++) {
if (!ggml_vk_is_empty(cgraph->nodes[i])) {
ctx->device->perf_logger->log_timing(cgraph->nodes[i], uint64_t((timestamps[i+1] - timestamps[i]) * ctx->device->properties.limits.timestampPeriod));
}
VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->query_pool, 0, ctx->query_idx, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results");
for (int i = 1; i < ctx->query_idx; i++) {
auto node = ctx->query_nodes[i];
auto name = ctx->query_fusion_names[i];
ctx->perf_logger->log_timing(node, name, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod));
}

ctx->device->perf_logger->print_timings();
ctx->perf_logger->print_timings();
}

if (!ctx->device->support_async) {

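A minimal sketch of the timestamp accounting used above: query slot 0 is written before the first node and one slot per enqueued (possibly fused) group after it, so the GPU time of entry i is the delta between consecutive timestamps scaled by the device's timestampPeriod (nanoseconds per tick). Names are illustrative:

#include <cstdint>
#include <vector>

// Convert raw query-pool timestamps into per-entry durations in nanoseconds.
static std::vector<uint64_t> node_times_ns(const std::vector<uint64_t> & timestamps,
                                           float timestamp_period) {
    std::vector<uint64_t> times;
    for (size_t i = 1; i < timestamps.size(); ++i) {
        times.push_back(uint64_t((timestamps[i] - timestamps[i - 1]) * timestamp_period));
    }
    return times;
}
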
@@ -7,35 +7,85 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];

void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i,
const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
// Compute starting index in matrix B for this superblock
const uint y_idx = i * QUANT_K + 32 * ib32;

uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;

// Precompute indices for quantization lookup tables
const uint qh_base = 2 * ib32;
const uint qs_base = 4 * ib32;
const uint sc_index = ib32 / 2;
const uint sc_shift = 6 * (ib32 & 1);

// Loop over rows in the superblock
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
// Load per-block scales and shift for quantization
const uint16_t[4] scales = data_a[ibi].scales;
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
const uint sc = data_a[ibi].scales[sc_index] >> sc_shift;

const uint sc = data_a[ibi].scales[ib32 / 2] >> (6 * (ib32 & 1));
// Temporary caches for decoding
FLOAT_TYPE dl_cache[4];
uint16_t gvf_cache[4];
float delta_cache[4];

// Precompute the multiplier and lookup values for 4 sub-blocks
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint qh = data_a[ibi].qh[2 * ib32 + l / 2] >> (4 * (l&1));
const uint qs = data_a[ibi].qs[4 * ib32 + l];
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1);
dl_cache[l] = FLOAT_TYPE(d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1));
const uint qh = data_a[ibi].qh[qh_base + l / 2] >> (4 * (l & 1));
const uint qs = data_a[ibi].qs[qs_base + l];
gvf_cache[l] = iq1s_grid[qs | ((qh & 7) << 8)];
delta_cache[l] = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
}

const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
// Loop over columns of the output
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
// Compute base index for matrix B
const uint base_b_idx = (j * p.batch_stride_b + b_offset + y_idx) / 4;
vec4 b_vals[8];

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);

FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int k = 0; k < 4; ++k) {
sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta,
fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum));
}
temp[j][n] = fma(dl, sum, temp[j][n]);
// Load 8 vec4 values from matrix B
[[unroll]] for (int idx = 0; idx < 8; ++idx) {
b_vals[idx] = vec4(data_b_v4[base_b_idx + idx]);
}

FLOAT_TYPE col_sum = FLOAT_TYPE(0.0);

// Loop over sub-blocks
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint16_t grid = gvf_cache[l];
const float dl = dl_cache[l];

// Decode 8 2-bit fbits from gvf_cache
float f0 = float(bitfieldExtract(grid, 0, 2));
float f1 = float(bitfieldExtract(grid, 2, 2));
float f2 = float(bitfieldExtract(grid, 4, 2));
float f3 = float(bitfieldExtract(grid, 6, 2));
float f4 = float(bitfieldExtract(grid, 8, 2));
float f5 = float(bitfieldExtract(grid, 10, 2));
float f6 = float(bitfieldExtract(grid, 12, 2));
float f7 = float(bitfieldExtract(grid, 14, 2));

// Pack into vec4 for vectorized FMA
const vec4 fbits_v0 = vec4(f0, f1, f2, f3);
const vec4 fbits_v1 = vec4(f4, f5, f6, f7);
const vec4 delta_v = vec4(delta_cache[l]);

// Vectorized fused multiply-add
vec4 sum_v = fma(b_vals[2*l + 0], fbits_v0 + delta_v, vec4(0.0));
sum_v = fma(b_vals[2*l + 1], fbits_v1 + delta_v, sum_v);

// Horizontal add to get scalar sum
FLOAT_TYPE sum = sum_v.x + sum_v.y + sum_v.z + sum_v.w;

// Accumulate to column sum
col_sum = fma(dl, sum, col_sum);
}
// Write result to temporary buffer
temp[j][n] += col_sum;
}
ibi += num_blocks_per_row;
}

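A scalar sketch of the per-sub-block dot product that the vectorized shader above computes: each 16-bit grid entry packs eight 2-bit weight levels (extracted unsigned, matching the updated uint16_t path), every level gets the same per-sub-block delta, and the partial sum is scaled by dl. The inputs are illustrative, not real IQ1_M table entries:

#include <cstdint>

static float iq1m_subblock_dot(uint16_t grid, float delta, float dl, const float b[8]) {
    float sum = 0.0f;
    for (int k = 0; k < 8; ++k) {
        const int level = (grid >> (2 * k)) & 0x3; // bitfieldExtract(grid, 2*k, 2)
        sum += b[k] * (float(level) + delta);
    }
    return dl * sum; // dl = d * (2 * scale_bits + 1), as in dl_cache above
}
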
@@ -8,6 +8,7 @@ layout (push_constant) uniform parameter
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
uint misalign_offsets;
uint circular;

uint lp0; uint rp0;
uint lp1; uint rp1;
@@ -18,6 +19,10 @@ layout (push_constant) uniform parameter
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }

uint wrap_around(int coord, uint size) {
return (uint(coord + int(size))) % size; // add size to avoid issues with negative coordinates
}

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

@@ -40,10 +45,20 @@ void main() {
const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00;
const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10;

const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 &&
i1 >= p.lp1 && i1 < p.ne11 - p.rp1 &&
i2 >= p.lp2 && i2 < p.ne12 - p.rp2 &&
i3 >= p.lp3 && i3 < p.ne13 - p.rp3;
if (p.circular != 0u) {
const uint ci0 = wrap_around(int(i0) - int(p.lp0), p.ne00);
const uint ci1 = wrap_around(int(i1) - int(p.lp1), p.ne01);
const uint ci2 = wrap_around(int(i2) - int(p.lp2), p.ne02);
const uint ci3 = wrap_around(int(i3) - int(p.lp3), p.ne03);
const uint circular_src_idx = ci3*p.nb03 + ci2*p.nb02 + ci1*p.nb01 + ci0*p.nb00;
data_d[get_doffset() + dst_idx] = D_TYPE(data_a[get_aoffset() + circular_src_idx]);
} else {
const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 &&
i1 >= p.lp1 && i1 < p.ne11 - p.rp1 &&
i2 >= p.lp2 && i2 < p.ne12 - p.rp2 &&
i3 >= p.lp3 && i3 < p.ne13 - p.rp3;
data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
}

data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
}
ggml/src/ggml-zendnn/CMakeLists.txt (new file, 92 lines)
@@ -0,0 +1,92 @@
ggml_add_backend_library(ggml-zendnn
ggml-zendnn.cpp)

# Get ZenDNN path
if (NOT DEFINED ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "")
set(ZENDNN_ROOT "$ENV{ZENDNN_ROOT}")
endif()

# Check if path is still empty or OFF
if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
message(STATUS "ZENDNN_ROOT not set. Automatically downloading and building ZenDNN...")
message(STATUS "This will take several minutes on first build...")

include(ExternalProject)

set(ZENDNN_PREFIX ${CMAKE_BINARY_DIR}/_deps/zendnn-prefix)
set(ZENDNN_SOURCE_DIR ${ZENDNN_PREFIX}/src/zendnn)
set(ZENDNN_BUILD_DIR ${ZENDNN_PREFIX}/build)
set(ZENDNN_INSTALL_DIR ${ZENDNN_BUILD_DIR}/install)

ExternalProject_Add(
zendnn
GIT_REPOSITORY https://github.com/amd/ZenDNN.git
GIT_TAG zendnnl
PREFIX ${ZENDNN_PREFIX}
SOURCE_DIR ${ZENDNN_SOURCE_DIR}
BINARY_DIR ${ZENDNN_BUILD_DIR}
CMAKE_ARGS
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_INSTALL_PREFIX=${ZENDNN_INSTALL_DIR}
-DZENDNNL_BUILD_EXAMPLES=OFF
-DZENDNNL_BUILD_DOXYGEN=OFF
-DZENDNNL_BUILD_GTEST=OFF
-DZENDNNL_BUILD_BENCHDNN=OFF
# Enable ALL matmul algorithm backends
-DZENDNNL_DEPENDS_AOCLDLP=ON
-DZENDNNL_DEPENDS_ONEDNN=ON
-DZENDNNL_DEPENDS_LIBXSMM=ON
BUILD_COMMAND ${CMAKE_COMMAND} --build ${ZENDNN_BUILD_DIR} --target zendnnl
INSTALL_COMMAND ${CMAKE_COMMAND} --build ${ZENDNN_BUILD_DIR} --target install
BUILD_ALWAYS OFF
LOG_DOWNLOAD ON
LOG_CONFIGURE ON
LOG_BUILD ON
LOG_INSTALL ON
)

# Add dependency so ZenDNN builds before our library
add_dependencies(ggml-zendnn zendnn)

# Set ZENDNN_ROOT to the installation directory
set(ZENDNN_ROOT ${ZENDNN_INSTALL_DIR})

message(STATUS "ZenDNN will be built to: ${ZENDNN_ROOT}")
else()
message(STATUS "Using custom ZenDNN installation at: ${ZENDNN_ROOT}")
endif()

# ZenDNN headers + libs
target_include_directories(ggml-zendnn PRIVATE
${ZENDNN_ROOT}/zendnnl/include
${ZENDNN_ROOT}/deps/aocldlp/include
${ZENDNN_ROOT}/deps/aoclutils/include
${ZENDNN_ROOT}/deps/json/include
${ZENDNN_ROOT}/deps/libxsmm/include
${ZENDNN_ROOT}/deps/onednn/include
)

target_link_directories(ggml-zendnn PRIVATE
${ZENDNN_ROOT}/zendnnl/lib
${ZENDNN_ROOT}/deps/aocldlp/lib
${ZENDNN_ROOT}/deps/aoclutils/lib
${ZENDNN_ROOT}/deps/libxsmm/lib
${ZENDNN_ROOT}/deps/onednn/lib
)

target_link_libraries(ggml-zendnn PRIVATE
zendnnl_archive # ZenDNN main
aocl-dlp # AOCL libraries
aoclutils
au_cpuid
dnnl # OneDNN
xsmm # libxsmm small matrix math
xsmmext
xsmmnoblas
m
pthread
)

if (GGML_OPENMP)
target_link_libraries(ggml-zendnn PRIVATE OpenMP::OpenMP_CXX)
endif()
ggml/src/ggml-zendnn/ggml-zendnn.cpp (new file, 466 lines)
@@ -0,0 +1,466 @@
#include "ggml-zendnn.h"

#include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "ggml-cpu.h"
#include "zendnnl.hpp"

#include <cstring>

struct ggml_backend_zendnn_context {
int n_threads = GGML_DEFAULT_N_THREADS;
std::unique_ptr<char[]> work_data;
size_t work_size = 0;
};

template<typename T>
zendnnl::common::data_type_t ggml_to_zendnn_type() {
if constexpr (std::is_same_v<T, float>) {
return zendnnl::common::data_type_t::f32;
} else if constexpr (std::is_same_v<T, ggml_bf16_t>) {
return zendnnl::common::data_type_t::bf16;
} else {
return zendnnl::common::data_type_t::none;
}
}

/**
* ZenDNN matmul: computes C = B * A.
*
* - A: weights, shape (k, m), column-major (each column is a weight vector for one output).
* - B: input, shape (n, k), row-major (each row is an input sample).
* - C: output, shape (n, m), row-major.
*
* Dimensions:
* m = output features (columns of C, columns of A)
* n = batch size (rows of C, rows of B)
* k = inner dimension (columns of B, rows of A)
*/
template <typename TA, typename TB, typename TC>
static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k,
const TA * A, int64_t lda, const TB * B, int64_t ldb, TC * C,
int64_t ldc) {

zendnnl::lowoha::lowoha_params params;
params.dtypes.src = ggml_to_zendnn_type<TB>();
params.dtypes.wei = ggml_to_zendnn_type<TA>();
params.dtypes.dst = ggml_to_zendnn_type<TC>();
params.num_threads = ctx->n_threads;

zendnnl::lowoha::status_t status = zendnnl::lowoha::matmul_direct(
'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major)
n, // M: rows of B and C
m, // N: cols of A^T and C
k, // K: cols of B, rows of A
1.0f, // alpha
B, ldb, // src: B[n,k]
A, lda, // weight: A[k,m] column-major (transposed)
nullptr, // bias
0.0f, // beta
C, ldc, // output C[n,m]
true, // is_weights_const
{}, // batch_params
params // params
);

if (status != zendnnl::lowoha::status_t::success) {
GGML_LOG_ERROR("%s, ZenDNN matmul failed: status=%d\n", __func__, static_cast<int>(status));
return false;
}
return true;
}

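To pin down the index conventions documented above, here is a naive reference of the same mapping (A is (k, m) column-major weights with lda elements between columns, B is (n, k) row-major input, C is (n, m) row-major output). This is a sketch of the math only, not of the zendnnl API:

#include <cstdint>

static void matmul_ref(int64_t m, int64_t n, int64_t k,
                       const float * A, int64_t lda,
                       const float * B, int64_t ldb,
                       float * C, int64_t ldc) {
    for (int64_t i = 0; i < n; ++i) {     // rows of C / rows of B
        for (int64_t j = 0; j < m; ++j) { // cols of C / cols of A
            float acc = 0.0f;
            for (int64_t l = 0; l < k; ++l) {
                acc += B[i * ldb + l] * A[j * lda + l]; // column j of A is contiguous
            }
            C[i * ldc + j] = acc;
        }
    }
}
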
static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k,
const void * A, int64_t lda, const void * B, int64_t ldb, void * C,
int64_t ldc, int Atype, int Btype, int Ctype) {

assert(m >= 0);
assert(n >= 0);
assert(k >= 0);
assert(lda >= k);
assert(ldb >= k);
assert(ldc >= m);

// categorize types
switch (Atype) {
case GGML_TYPE_F32:
if (Btype != GGML_TYPE_F32 || Ctype != GGML_TYPE_F32)
return false;
return ggml_zendnn_matmul<float, float, float>(
ctx, m, n, k,
(const float *)A, lda,
(const float *)B, ldb,
(float *)C, ldc);
case GGML_TYPE_BF16:
if (Btype != GGML_TYPE_BF16)
return false;
if (Ctype == GGML_TYPE_BF16)
return ggml_zendnn_matmul<ggml_bf16_t, ggml_bf16_t, ggml_bf16_t>(
ctx, m, n, k,
(const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb,
(ggml_bf16_t *)C, ldc);
if (Ctype == GGML_TYPE_F32)
return ggml_zendnn_matmul<ggml_bf16_t, ggml_bf16_t, float>(
ctx, m, n, k,
(const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb,
(float *)C, ldc);
return false;
default:
return false; // unsupported type
}
}

static void ggml_zendnn_compute_forward_mul_mat(
ggml_backend_zendnn_context * ctx,
ggml_tensor * dst) {

const ggml_tensor * src0 = dst->src[0]; // weights
const ggml_tensor * src1 = dst->src[1]; // inputs

GGML_TENSOR_BINARY_OP_LOCALS

ggml_type const vec_dot_type = ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
ggml_from_float_t const from_float = ggml_get_type_traits_cpu(vec_dot_type)->from_float;

GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);

// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(src0->type));
GGML_ASSERT(nb10 == ggml_type_size(src1->type));

// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);

// broadcast factors
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;

void * work_data = ctx->work_data.get();
if (src1->type != vec_dot_type) {
const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
const size_t nbw2 = nbw1 * ne11;
const size_t nbw3 = nbw2 * ne12;
const size_t desired_wsize = ne13 * nbw3;
if (ctx->work_size < desired_wsize) {
ctx->work_data.reset(new char[desired_wsize]);
ctx->work_size = desired_wsize;
}
work_data = ctx->work_data.get();

// #pragma omp parallel for num_threads(ctx->n_threads)
#pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static)
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
const float * src1_f32 = (float *)((char *)src1->data + i11*nb11 + i12*nb12 + i13*nb13);
void * src1_conv = (char *)work_data + i11*nbw1 + i12*nbw2 + i13*nbw3;
from_float(src1_f32, src1_conv, ne10);
}
}
}
}

for (int64_t i13 = 0; i13 < ne13; i13++) {
for (int64_t i12 = 0; i12 < ne12; i12++) {
const void* wdata = src1->type == vec_dot_type ? src1->data : work_data;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
if (!ggml_zendnn_sgemm(ctx,
ne01, // m
ne11, // n
ne10, // k
static_cast<const char *>(src0->data) + (i12/r2)*nb02 + (i13/r3)*nb03,
ne00, // lda
static_cast<const char *>(wdata) + (i12*ne11 + i13*ne12*ne11)*row_size,
ne10, // ldb
static_cast<char *>(dst->data) + i12*nb2 + i13*nb3,
ne01, // ldc
src0->type,
vec_dot_type,
dst->type))
GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__);
}
}
}

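A small sketch of the broadcast indexing used in the loop above: when src1 has more batch planes than src0 (r2 = ne12/ne02, r3 = ne13/ne03), every r2 consecutive src1 planes along dim 2 (and every r3 along dim 3) reuse the same src0 plane, which is why the weight pointer uses i12/r2 and i13/r3. The helper below is illustrative:

#include <cstdint>

// Flattened src0 plane index for a given src1 plane (i12, i13).
static int64_t src0_plane(int64_t i12, int64_t i13,
                          int64_t r2, int64_t r3, int64_t ne02) {
    const int64_t i02 = i12 / r2; // src0 plane along dim 2
    const int64_t i03 = i13 / r3; // src0 plane along dim 3
    return i03 * ne02 + i02;
}
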
// backend interface

static const char * ggml_backend_zendnn_get_name(ggml_backend_t backend) {
return "ZenDNN";

GGML_UNUSED(backend);
}

static void ggml_backend_zendnn_free(ggml_backend_t backend) {
ggml_backend_zendnn_context * ctx = (ggml_backend_zendnn_context *)backend->context;
delete ctx;
delete backend;
}

static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_zendnn_context * ctx = (ggml_backend_zendnn_context *)backend->context;

for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_tensor * node = cgraph->nodes[i];

switch (node->op) {
case GGML_OP_MUL_MAT:
ggml_zendnn_compute_forward_mul_mat(ctx, node);
break;
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
break;

default:
GGML_ABORT("%s: unsupported op %s\n", __func__, ggml_op_desc(node));
}
}

return GGML_STATUS_SUCCESS;

GGML_UNUSED(backend);
}

static struct ggml_backend_i ggml_backend_zendnn_i = {
/* .get_name = */ ggml_backend_zendnn_get_name,
/* .free = */ ggml_backend_zendnn_free,
/* .set_tensor_async = */ NULL,
/* .get_tensor_async = */ NULL,
/* .cpy_tensor_async = */ NULL,
/* .synchronize = */ NULL,
/* .graph_plan_create = */ NULL,
/* .graph_plan_free = */ NULL,
/* .graph_plan_update = */ NULL,
/* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_zendnn_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .graph_optimize = */ NULL,
};

static ggml_guid_t ggml_backend_zendnn_guid(void) {
static const char * guid_str = "AMD-ZENDNN-ACCEL";
return reinterpret_cast<ggml_guid_t>(const_cast<char*>(guid_str));
}

ggml_backend_t ggml_backend_zendnn_init(void) {
ggml_backend_zendnn_context * ctx = new ggml_backend_zendnn_context;

ggml_backend_t backend = new ggml_backend {
/* .guid = */ ggml_backend_zendnn_guid(),
/* .iface = */ ggml_backend_zendnn_i,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_zendnn_reg(), 0),
/* .context = */ ctx,
};

return backend;
}

bool ggml_backend_is_zendnn(ggml_backend_t backend) {
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_zendnn_guid());
}

void ggml_backend_zendnn_set_n_threads(ggml_backend_t backend_zendnn, int n_threads) {
GGML_ASSERT(ggml_backend_is_zendnn(backend_zendnn));

ggml_backend_zendnn_context * ctx = (ggml_backend_zendnn_context *)backend_zendnn->context;
ctx->n_threads = n_threads;
}

// device interface
static const char * ggml_backend_zendnn_device_get_name(ggml_backend_dev_t dev) {
return "ZenDNN";

GGML_UNUSED(dev);
}
/**
* ZenDNN is AMD's performance library providing optimized primitives and implementations
* for deep learning workloads on AMD CPUs. It targets improved performance for common
* neural network operations on AMD architectures. For more information, see:
* https://www.amd.com/en/developer/zendnn.html
*/
static const char * ggml_backend_zendnn_device_get_description(ggml_backend_dev_t dev) {
return "ZenDNN: AMD optimized primitives backend for GGML (optimized for AMD CPUs)";

GGML_UNUSED(dev);
}

static void ggml_backend_zendnn_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
*free = 0;
*total = 0;

GGML_UNUSED(dev);
}

static enum ggml_backend_dev_type ggml_backend_zendnn_device_get_type(ggml_backend_dev_t dev) {
return GGML_BACKEND_DEVICE_TYPE_ACCEL;

GGML_UNUSED(dev);
}

static void ggml_backend_zendnn_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
props->name = ggml_backend_zendnn_device_get_name(dev);
props->description = ggml_backend_zendnn_device_get_description(dev);
props->type = ggml_backend_zendnn_device_get_type(dev);
ggml_backend_zendnn_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = {
/* .async = */ false,
/* .host_buffer = */ false,
/* .buffer_from_host_ptr = */ true,
/* .events = */ false
};
}

static ggml_backend_t ggml_backend_zendnn_device_init_backend(ggml_backend_dev_t dev, const char * params) {
ggml_backend_t backend = ggml_backend_zendnn_init();
if (backend == NULL) {
GGML_LOG_ERROR("%s: error: failed to initialize ZenDNN backend\n", __func__);
return NULL;
}

return backend;

GGML_UNUSED(dev);
GGML_UNUSED(params);
}

static ggml_backend_buffer_type_t ggml_backend_zendnn_device_get_buffer_type(ggml_backend_dev_t dev) {
return ggml_backend_cpu_buffer_type();

GGML_UNUSED(dev);
}

static ggml_backend_buffer_t ggml_backend_zendnn_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
return ggml_backend_cpu_buffer_from_ptr(ptr, size);

GGML_UNUSED(dev);
GGML_UNUSED(max_tensor_size);
}

static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
switch (op->op) {
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
return true;

case GGML_OP_MUL_MAT:
{
const ggml_tensor * weights = op->src[0];
const ggml_tensor * inputs = op->src[1];

const int64_t ne10 = inputs->ne[0];
const int64_t ne0 = op->ne[0];
const int64_t ne1 = op->ne[1];

const int64_t min_batch = 1;
if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs) ||
ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {
return false;
}
switch (weights->type) {
case GGML_TYPE_F32:
case GGML_TYPE_BF16:
return true;
default:
return false;
}
} break;

default:
return false;
}

GGML_UNUSED(dev);
}

static bool ggml_backend_zendnn_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
return ggml_backend_buft_is_host(buft);

GGML_UNUSED(dev);
}

static const struct ggml_backend_device_i ggml_backend_zendnn_device_i = {
/* .get_name = */ ggml_backend_zendnn_device_get_name,
/* .get_description = */ ggml_backend_zendnn_device_get_description,
/* .get_memory = */ ggml_backend_zendnn_device_get_memory,
/* .get_type = */ ggml_backend_zendnn_device_get_type,
/* .get_props = */ ggml_backend_zendnn_device_get_props,
/* .init_backend = */ ggml_backend_zendnn_device_init_backend,
/* .get_buffer_type = */ ggml_backend_zendnn_device_get_buffer_type,
/* .get_host_buffer_type = */ NULL,
/* .buffer_from_host_ptr = */ ggml_backend_zendnn_device_buffer_from_host_ptr,
/* .supports_op = */ ggml_backend_zendnn_device_supports_op,
/* .supports_buft = */ ggml_backend_zendnn_device_supports_buft,
/* .offload_op = */ NULL,
/* .event_new = */ NULL,
/* .event_free = */ NULL,
/* .event_synchronize = */ NULL,
};

// backend reg interface
static const char * ggml_backend_zendnn_reg_get_name(ggml_backend_reg_t reg) {
return "ZenDNN";

GGML_UNUSED(reg);
}

static size_t ggml_backend_zendnn_reg_get_device_count(ggml_backend_reg_t reg) {
return 1;

GGML_UNUSED(reg);
}

static ggml_backend_dev_t ggml_backend_zendnn_reg_get_device(ggml_backend_reg_t reg, size_t index) {
GGML_ASSERT(index == 0);

static ggml_backend_device ggml_backend_zendnn_device = {
/* .iface = */ ggml_backend_zendnn_device_i,
/* .reg = */ reg,
/* .context = */ nullptr,
};

return &ggml_backend_zendnn_device;
}

static void * ggml_backend_zendnn_get_proc_address(ggml_backend_reg_t reg, const char * name) {
if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
return (void *) ggml_backend_zendnn_set_n_threads;
}
return NULL;

GGML_UNUSED(reg);
GGML_UNUSED(name);
}

static const struct ggml_backend_reg_i ggml_backend_zendnn_reg_i = {
/* .get_name = */ ggml_backend_zendnn_reg_get_name,
/* .get_device_count = */ ggml_backend_zendnn_reg_get_device_count,
/* .get_device = */ ggml_backend_zendnn_reg_get_device,
/* .get_proc_address = */ ggml_backend_zendnn_get_proc_address,
};

ggml_backend_reg_t ggml_backend_zendnn_reg(void) {
static struct ggml_backend_reg ggml_backend_zendnn_reg = {
/* .api_version = */ GGML_BACKEND_API_VERSION,
/* .iface = */ ggml_backend_zendnn_reg_i,
/* .context = */ NULL,
};

return &ggml_backend_zendnn_reg;
}

GGML_BACKEND_DL_IMPL(ggml_backend_zendnn_reg)

@@ -4947,6 +4947,18 @@ struct ggml_tensor * ggml_pad(
return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
}

// ggml_pad_circular

struct ggml_tensor * ggml_pad_circular(
struct ggml_context * ctx,
struct ggml_tensor * a,
int p0,
int p1,
int p2,
int p3) {
return ggml_pad_ext_circular(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
}

struct ggml_tensor * ggml_pad_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -4973,6 +4985,7 @@ struct ggml_tensor * ggml_pad_ext(
ggml_set_op_params_i32(result, 5, rp2);
ggml_set_op_params_i32(result, 6, lp3);
ggml_set_op_params_i32(result, 7, rp3);
ggml_set_op_params_i32(result, 8, 0); // not circular by default

result->op = GGML_OP_PAD;
@@ -4981,6 +4994,25 @@ struct ggml_tensor * ggml_pad_ext(
return result;
}

// ggml_pad_ext_circular

struct ggml_tensor * ggml_pad_ext_circular(
struct ggml_context * ctx,
struct ggml_tensor * a,
int lp0,
int rp0,
int lp1,
int rp1,
int lp2,
int rp2,
int lp3,
int rp3
) {
struct ggml_tensor * result = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
ggml_set_op_params_i32(result, 8, 1); // circular
return result;
}
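
A minimal usage sketch of the new API added above: circular padding wraps the tensor's own contents into the pad region instead of inserting zeros. Context setup and error handling are omitted, and the padding amounts are illustrative:

#include "ggml.h"

// Pad a tensor by (2, 1) in dims 0 and 1; equivalent to
// ggml_pad(ctx, a, 2, 1, 0, 0) except op_params[8] is set to 1,
// so out-of-range coordinates wrap around the source tensor.
static ggml_tensor * pad_on_torus(ggml_context * ctx, ggml_tensor * a) {
    return ggml_pad_circular(ctx, a, /*p0=*/2, /*p1=*/1, /*p2=*/0, /*p3=*/0);
}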

// ggml_pad_reflect_1d

struct ggml_tensor * ggml_pad_reflect_1d(

@@ -1628,6 +1628,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
}
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);

// (optional) temperature tuning - used by mistral-large
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false);

switch (hparams.n_layer) {
case 27: type = LLM_TYPE_16B; break;
case 60: type = LLM_TYPE_236B; break;

@@ -666,7 +666,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::

std::map<int, std::string> mapped;
int blk_id = 0;
int pruned_attention_w = 0;

// make a list of weights
std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
@@ -674,11 +673,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
for (const auto & it : ml.weights_map) {
const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
if (remapped_name.empty()) {
if (it.first.find("attn_v.weight") != std::string::npos ||
it.first.find("attn_qkv.weight") != std::string::npos ||
it.first.find("attn_kv_b.weight") != std::string::npos) {
pruned_attention_w++;
}
LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
continue;
}
@@ -703,7 +697,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
});
}

bool is_clip_model = false;
for (const auto * it : tensors) {
const struct ggml_tensor * tensor = it->tensor;

@@ -717,30 +710,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
} else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
qs.has_output = true;
}

is_clip_model |= name.rfind("mm.", 0) == 0; // check the "mm." prefix
}

qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;

// sanity checks for models that have attention layers
if (qs.n_attention_wv != 0 && !is_clip_model)
{
int32_t n_layer_all = model.hparams.n_layer;
if (llama_model_has_encoder(&model)) {
// now n_layer_all is the number of attention layers in the encoder
// for each decoder block, there are 2 attention layers
n_layer_all += 2 * model.hparams.dec_n_layer;
}

// note: for linear-attention models (such as Qwen3 Next) this is the number of linear layers
const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true);

LLAMA_LOG_INFO("%s: n_layer_all = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_all, n_layer_recr, pruned_attention_w);

GGML_ASSERT((qs.n_attention_wv == n_layer_all - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected");
}

size_t total_size_org = 0;
size_t total_size_new = 0;

@@ -30,6 +30,12 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
// {n_embd, n_tokens}
inpL = build_inp_embd(model.tok_embd);

// (optional) temperature tuning - used by mistral-large
ggml_tensor * inp_attn_scale = nullptr;
if (hparams.f_attn_temp_scale != 0.0f) {
inp_attn_scale = build_inp_attn_scale();
}

// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

@@ -128,6 +134,12 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
ggml_tensor * Vcur = kv_cmpr;
cb(Vcur, "Vcur", il);

if (inp_attn_scale) {
// apply llama 4 temperature scaling
Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
cb(Qcur, "Qcur_attn_temp_scaled", il);
}

// note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group)
cur = build_attn(inp_attn,
model.layers[il].wo, NULL,
@@ -160,6 +172,12 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0);
cb(Kcur, "Kcur", il);

if (inp_attn_scale) {
// apply llama 4 temperature scaling
Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
cb(Qcur, "Qcur_attn_temp_scaled", il);
}

// note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups)
cur = build_attn(inp_attn,
model.layers[il].wo, NULL,

@@ -5604,21 +5604,24 @@ struct test_pad : public test_case {
|
||||
const std::array<int64_t, 4> ne_a;
|
||||
const int pad_0;
|
||||
const int pad_1;
|
||||
const bool circular;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
|
||||
return VARS_TO_STR5(type, ne_a, pad_0, pad_1, circular);
|
||||
}
|
||||
|
||||
test_pad(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne_a = {512, 512, 1, 1},
|
||||
int pad_0 = 1, int pad_1 = 1)
|
||||
: type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1) {}
|
||||
int pad_0 = 1, int pad_1 = 1, bool circular = false)
|
||||
: type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1), circular(circular) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
ggml_tensor * out = ggml_pad(ctx, a, pad_0, pad_1, 0, 0);
|
||||
ggml_tensor * out = circular
|
||||
? ggml_pad_circular(ctx, a, pad_0, pad_1, 0, 0)
|
||||
: ggml_pad(ctx, a, pad_0, pad_1, 0, 0);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
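
For orientation, circular padding wraps values around the tensor edges instead of filling with zeros. A minimal sketch of the intended semantics, using NumPy's `wrap` mode as a stand-in for `ggml_pad_circular` (an assumption for illustration, not taken from this patch):

```python
# Assumed semantics of circular padding: padded cells are wrapped copies of
# the source tensor rather than zeros (NumPy's "wrap" mode).
import numpy as np

a = np.arange(6).reshape(2, 3)                  # [[0 1 2], [3 4 5]]
out = np.pad(a, ((0, 1), (0, 2)), mode="wrap")  # pad 1 row and 2 columns
print(out)
# [[0 1 2 0 1]
#  [3 4 5 3 4]
#  [0 1 2 0 1]]
```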
@@ -5638,17 +5641,19 @@ struct test_pad_ext : public test_case {
const int lp3;
const int rp3;
const bool v;
const bool circular;

std::string vars() override {
return VARS_TO_STR11(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, v);
return VARS_TO_STR12(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, v, circular);
}

test_pad_ext(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_a = {512, 512, 3, 1},
int lp0 = 1, int rp0 = 1, int lp1 = 1, int rp1 = 1,
int lp2 = 1, int rp2 = 1, int lp3 = 1, int rp3 = 1,
bool v = false)
: type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3), v(v) {}
bool v = false, bool circular = false)
: type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3),
v(v), circular(circular) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());

@@ -5659,7 +5664,9 @@ struct test_pad_ext : public test_case {
ggml_set_name(a, "view of a");
}

ggml_tensor * out = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
ggml_tensor * out = circular
? ggml_pad_ext_circular(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3)
: ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
ggml_set_name(out, "out");

return out;

@@ -7782,6 +7789,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {9, 9, 1280, 1}));
test_cases.emplace_back(new test_acc());
test_cases.emplace_back(new test_pad());
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {33, 17, 2, 1}, 4, 3, true)); // circular
test_cases.emplace_back(new test_pad_ext());
test_cases.emplace_back(new test_pad_reflect_1d());
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));

@@ -7829,8 +7837,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 300, 64, 4, 4 }));

for (bool v : {false, true}) {
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v));
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
for (bool circular : {false, true}) {
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v, circular));
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v, circular));
}
}

for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {

@@ -1737,7 +1737,8 @@ struct markdown_printer : public printer {
fields.emplace_back("params");
fields.emplace_back("backend");
bool is_cpu_backend = test::get_backend().find("CPU") != std::string::npos ||
test::get_backend().find("BLAS") != std::string::npos;
test::get_backend().find("BLAS") != std::string::npos ||
test::get_backend().find("ZenDNN") != std::string::npos;
if (!is_cpu_backend) {
fields.emplace_back("n_gpu_layers");
}

@@ -493,6 +493,8 @@ Note for `multimodal_data` in JSON object prompts. This should be an array of st
`n_keep`: Specify the number of tokens from the prompt to retain when the context size is exceeded and tokens need to be discarded. The number excludes the BOS token.
By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to retain all tokens from the prompt.

`n_cmpl`: Number of completions to generate from the current prompt. If the input contains multiple prompts, the output will have one entry per prompt per completion, i.e. the number of prompts times `n_cmpl`.

`stream`: Allows receiving each predicted token in real-time instead of waiting for the completion to finish (uses a different response format). To enable this, set to `true`.

`stop`: Specify a JSON array of stopping strings.

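As an illustration of these fields, a minimal request sketch in Python; the host, port, prompt, and parameter values are assumptions for the example, not part of the documentation above:

```python
# Minimal sketch: request two completions of one prompt from llama-server.
# Assumes a server listening on localhost:8080.
import requests

res = requests.post(
    "http://localhost:8080/completion",
    json={
        "prompt": "Once upon a time",
        "n_predict": 16,
        "n_keep": 8,       # keep the first 8 prompt tokens when the context shifts
        "n_cmpl": 2,       # two completions from the same prompt
        "stop": ["\n\n"],  # stopping strings
    },
    timeout=60,
)
res.raise_for_status()
print(res.json())  # with n_cmpl > 1, expect one result entry per completion
```
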
Binary file not shown.

@@ -494,6 +494,18 @@ int32_t server_tokens::process_chunk(
return 0;
}

server_tokens server_tokens::clone() const {
server_tokens res;
res.has_mtmd = has_mtmd;
res.tokens = tokens;
for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
size_t idx = it->first;
const mtmd::input_chunk_ptr & chunk = it->second;
res.map_idx_to_media[idx] = mtmd::input_chunk_ptr(mtmd_input_chunk_copy(chunk.get()));
}
return res;
}

//
// tokenizer and input processing utils
//

@@ -745,12 +757,6 @@ json oaicompat_completion_params_parse(const json & body) {
llama_params["stop"] = json_value(body, "stop", json::array());
}

// Handle "n" field
int n_choices = json_value(body, "n", 1);
if (n_choices != 1) {
throw std::runtime_error("Only one completion choice is allowed");
}

// Handle "echo" field
if (json_value(body, "echo", false)) {
throw std::runtime_error("Only no echo is supported");

@@ -1049,12 +1055,6 @@ json oaicompat_chat_params_parse(
llama_params["chat_parser"] = chat_params.parser;
}

// Handle "n" field
int n_choices = json_value(body, "n", 1);
if (n_choices != 1) {
throw std::invalid_argument("Only one completion choice is allowed");
}

// Handle "logprobs" field
// TODO: the response format of this option is not yet OAI-compatible, but it seems no one is really using it; we may need to fix it in the future
if (json_value(body, "logprobs", false)) {

@@ -215,6 +215,8 @@ public:
llama_pos pos,
int32_t seq_id,
size_t & n_tokens_out) const;

server_tokens clone() const;
};

@@ -35,7 +35,8 @@ constexpr int HTTP_POLLING_SECONDS = 1;
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
enum slot_state {
SLOT_STATE_IDLE,
SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt
SLOT_STATE_STARTED, // after assigning a task and about to process prompt
SLOT_STATE_PROCESSING_PROMPT,
SLOT_STATE_DONE_PROMPT,
SLOT_STATE_GENERATING,

@@ -254,6 +255,15 @@ struct server_slot {
generated_token_probs.push_back(token);
}

// note: a slot can also be either a parent or a child
bool is_parent() const {
return is_processing() && task->n_children > 0;
}

bool is_child() const {
return is_processing() && task->id_parent >= 0;
}

void release() {
if (is_processing()) {
GGML_ASSERT(task);

@@ -383,6 +393,17 @@ struct server_slot {

return res;
}

void copy_state_to(server_slot & other) const {
llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1);
other.n_decoded = n_decoded;
other.n_remaining = n_remaining;
other.i_batch = i_batch;
other.n_prompt_tokens_cache = n_prompt_tokens_cache;
other.n_prompt_tokens_processed = n_prompt_tokens_processed;
other.prompt = prompt.clone();
}
};

@@ -1022,7 +1043,9 @@ struct server_context_impl {

slot.task = std::make_unique<const server_task>(std::move(task));

slot.state = SLOT_STATE_STARTED;
slot.state = slot.is_child()
? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
: SLOT_STATE_STARTED;

SLT_INF(slot, "%s", "processing task\n");

@@ -1684,6 +1707,12 @@ struct server_context_impl {
GGML_ABORT("not supported by multimodal");
}

if (slot.is_parent() || slot.is_child()) {
send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER);
slot.release();
continue;
}

// Shift context
int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;

@@ -2308,6 +2337,26 @@ struct server_context_impl {
n_batch = llama_n_batch(ctx);

for (auto & slot : slots) {
// may need to copy state to other slots
if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
std::vector<server_slot *> child_slots;
for (auto & other : slots) {
if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
child_slots.push_back(&other);
}
}

// we can only proceed if all child slots have the correct tasks
if (child_slots.size() == slot.task->n_children) {
// copy state to the child slots
for (auto & child : child_slots) {
SLT_INF(slot, "copying state to child %d\n", child->id);
slot.copy_state_to(*child);
child->state = SLOT_STATE_DONE_PROMPT;
}
}
}

// optionally send prompt processing progress
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
if (slot.task->params.stream && slot.task->params.return_progress) {

@@ -2593,11 +2642,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
}
tasks.reserve(inputs.size());
states.reserve(inputs.size());
int idx = 0;
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);

task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
task.index = idx++;

task.tokens = std::move(inputs[i]);
task.params = server_task::params_from_json_cmpl(

@@ -2612,6 +2662,18 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
task.params.oaicompat_model = ctx_server.model_name;
states.push_back(task.params.oaicompat_chat_syntax);

if (task.params.n_cmpl > 1) {
task.n_children = task.params.n_cmpl - 1;
for (size_t j = 0; j < task.n_children; j++) {
server_task child = task.create_child(
task.id,
ctx_server.queue_tasks.get_new_id(),
idx++);
states.push_back(child.params.oaicompat_chat_syntax);
tasks.push_back(std::move(child));
}
}

tasks.push_back(std::move(task));
}

@@ -2638,8 +2700,21 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
arr.push_back(res->to_json());
}
// if single request, return single object instead of array
res->ok(arr.size() == 1 ? arr[0] : arr);
GGML_ASSERT(!arr.empty() && "empty results");
if (arr.size() == 1) {
// if single request, return single object instead of array
res->ok(arr[0]);
} else if (res_type == TASK_RESPONSE_TYPE_OAI_CHAT || res_type == TASK_RESPONSE_TYPE_OAI_CMPL) {
// if multiple results in OAI format, we need to re-format them
json & choices = arr[0]["choices"];
for (size_t i = 1; i < arr.size(); i++) {
choices.push_back(std::move(arr[i]["choices"][0]));
}
res->ok(arr[0]);
} else {
// multi-results, non-OAI compat
res->ok(arr);
}
}
} else {
// in streaming mode, the first error must be treated as non-stream response

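For reference, a sketch of the merged shape this branch produces for OAI-compatible multi-choice results; the field values are hypothetical:

```python
# Hypothetical merged result: the per-task "choices" are folded into the first
# result's array, with "index" distinguishing the parallel completions.
merged = {
    "object": "chat.completion",
    "choices": [
        {"index": 0, "finish_reason": "length", "message": {"role": "assistant", "content": "..."}},
        {"index": 1, "finish_reason": "length", "message": {"role": "assistant", "content": "..."}},
    ],
}
assert [c["index"] for c in merged["choices"]] == [0, 1]
```
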
@@ -175,6 +175,7 @@ task_params server_task::params_from_json_cmpl(
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
params.n_cmpl = json_value(data, "n_cmpl", json_value(data, "n", 1));
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
params.response_fields = json_value(data, "response_fields", std::vector<std::string>());

@@ -453,6 +454,10 @@ task_params server_task::params_from_json_cmpl(
}
}

if (params.n_cmpl > params_base.n_parallel) {
throw std::runtime_error("n_cmpl cannot be greater than the number of slots, please increase -np");
}

return params;
}

@@ -664,7 +669,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() {

json choice {
{"finish_reason", finish_reason},
{"index", 0},
{"index", index},
{"message", msg.to_json_oaicompat<json>()},
};

@@ -1064,7 +1069,7 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() {
{"choices", json::array({
json {
{"finish_reason", nullptr},
{"index", 0},
{"index", index},
{"delta", delta},
},
})},

@@ -53,6 +53,7 @@ struct task_params {
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
int32_t n_cmpl = 1; // number of completions to generate from this prompt

int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

@@ -89,6 +90,10 @@ struct server_task {
int id_target = -1;
int id_slot = -1;

// used by parallel sampling (multiple completions from same prompt)
size_t n_children = 0; // number of tasks reusing this prompt
int id_parent = -1;

// used by SERVER_TASK_TYPE_INFERENCE
task_params params;
server_tokens tokens;

@@ -130,6 +135,17 @@ struct server_task {
}
return ids;
}

server_task create_child(int id_parent, int id_child, int idx) const {
server_task copy;
copy.id = id_child;
copy.index = idx;
copy.id_parent = id_parent;
copy.params = params;
copy.type = type;
copy.tokens = tokens.clone();
return copy;
}
};

struct result_timings {

@@ -466,6 +482,14 @@ struct server_prompt {
int n_tokens() const {
return tokens.size();
}

server_prompt clone() const {
return server_prompt {
tokens.clone(),
data,
checkpoints
};
}
};

struct server_prompt_cache {

@@ -477,3 +477,22 @@ def test_return_progress(n_batch, batch_count, reuse_cache):
assert last_progress["total"] > 0
assert last_progress["processed"] == last_progress["total"]
assert total_batch_count == batch_count


def test_chat_completions_multiple_choices():
global server
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 8,
"n": 2,
"messages": [
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
})
assert res.status_code == 200
assert len(res.body["choices"]) == 2
for choice in res.body["choices"]:
assert "assistant" == choice["message"]["role"]
assert match_regex("Suddenly", choice["message"]["content"])
assert choice["finish_reason"] == "length"

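The same path can also be exercised with the `openai` Python client pointed at a local llama-server; the base URL, API key, and model name below are illustrative assumptions:

```python
# Sketch: request two chat completions via the OpenAI-compatible endpoint.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="sk-no-key-required")
res = client.chat.completions.create(
    model="default",  # placeholder model name (assumption)
    messages=[{"role": "user", "content": "What is the best book?"}],
    n=2,              # maps to n_cmpl; needs enough parallel slots (-np >= 2)
    max_tokens=8,
)
for choice in res.choices:
    print(choice.index, choice.message.content)
```
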
@@ -3,6 +3,7 @@
import { copyToClipboard, isIMEComposing } from '$lib/utils';
import ChatMessageAssistant from './ChatMessageAssistant.svelte';
import ChatMessageUser from './ChatMessageUser.svelte';
import ChatMessageSystem from './ChatMessageSystem.svelte';

interface Props {
class?: string;

@@ -140,8 +141,7 @@
}

function handleSaveEdit() {
if (message.role === 'user') {
// For user messages, trim to avoid accidental whitespace
if (message.role === 'user' || message.role === 'system') {
onEditWithBranching?.(message, editedContent.trim());
} else {
// For assistant messages, preserve exact content including trailing whitespace

@@ -167,7 +167,28 @@
}
</script>

{#if message.role === 'user'}
{#if message.role === 'system'}
<ChatMessageSystem
bind:textareaElement
class={className}
{deletionInfo}
{editedContent}
{isEditing}
{message}
onCancelEdit={handleCancelEdit}
onConfirmDelete={handleConfirmDelete}
onCopy={handleCopy}
onDelete={handleDelete}
onEdit={handleEdit}
onEditKeydown={handleEditKeydown}
onEditedContentChange={handleEditedContentChange}
{onNavigateToSibling}
onSaveEdit={handleSaveEdit}
onShowDeleteDialogChange={handleShowDeleteDialogChange}
{showDeleteDialog}
{siblingInfo}
/>
{:else if message.role === 'user'}
<ChatMessageUser
bind:textareaElement
class={className}

@@ -0,0 +1,216 @@
<script lang="ts">
import { Check, X } from '@lucide/svelte';
import { Card } from '$lib/components/ui/card';
import { Button } from '$lib/components/ui/button';
import { MarkdownContent } from '$lib/components/app';
import { INPUT_CLASSES } from '$lib/constants/input-classes';
import { config } from '$lib/stores/settings.svelte';
import ChatMessageActions from './ChatMessageActions.svelte';

interface Props {
class?: string;
message: DatabaseMessage;
isEditing: boolean;
editedContent: string;
siblingInfo?: ChatMessageSiblingInfo | null;
showDeleteDialog: boolean;
deletionInfo: {
totalCount: number;
userMessages: number;
assistantMessages: number;
messageTypes: string[];
} | null;
onCancelEdit: () => void;
onSaveEdit: () => void;
onEditKeydown: (event: KeyboardEvent) => void;
onEditedContentChange: (content: string) => void;
onCopy: () => void;
onEdit: () => void;
onDelete: () => void;
onConfirmDelete: () => void;
onNavigateToSibling?: (siblingId: string) => void;
onShowDeleteDialogChange: (show: boolean) => void;
textareaElement?: HTMLTextAreaElement;
}

let {
class: className = '',
message,
isEditing,
editedContent,
siblingInfo = null,
showDeleteDialog,
deletionInfo,
onCancelEdit,
onSaveEdit,
onEditKeydown,
onEditedContentChange,
onCopy,
onEdit,
onDelete,
onConfirmDelete,
onNavigateToSibling,
onShowDeleteDialogChange,
textareaElement = $bindable()
}: Props = $props();

let isMultiline = $state(false);
let messageElement: HTMLElement | undefined = $state();
let isExpanded = $state(false);
let contentHeight = $state(0);
const MAX_HEIGHT = 200; // pixels
const currentConfig = config();

let showExpandButton = $derived(contentHeight > MAX_HEIGHT);

$effect(() => {
if (!messageElement || !message.content.trim()) return;

if (message.content.includes('\n')) {
isMultiline = true;
}

const resizeObserver = new ResizeObserver((entries) => {
for (const entry of entries) {
const element = entry.target as HTMLElement;
const estimatedSingleLineHeight = 24;

isMultiline = element.offsetHeight > estimatedSingleLineHeight * 1.5;
contentHeight = element.scrollHeight;
}
});

resizeObserver.observe(messageElement);

return () => {
resizeObserver.disconnect();
};
});

function toggleExpand() {
isExpanded = !isExpanded;
}
</script>

<div
aria-label="System message with actions"
class="group flex flex-col items-end gap-3 md:gap-2 {className}"
role="group"
>
{#if isEditing}
<div class="w-full max-w-[80%]">
<textarea
bind:this={textareaElement}
bind:value={editedContent}
class="min-h-[60px] w-full resize-none rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}"
onkeydown={onEditKeydown}
oninput={(e) => onEditedContentChange(e.currentTarget.value)}
placeholder="Edit system message..."
></textarea>

<div class="mt-2 flex justify-end gap-2">
<Button class="h-8 px-3" onclick={onCancelEdit} size="sm" variant="outline">
<X class="mr-1 h-3 w-3" />
Cancel
</Button>

<Button class="h-8 px-3" onclick={onSaveEdit} disabled={!editedContent.trim()} size="sm">
<Check class="mr-1 h-3 w-3" />
Send
</Button>
</div>
</div>
{:else}
{#if message.content.trim()}
<div class="relative max-w-[80%]">
<button
class="group/expand w-full text-left {!isExpanded && showExpandButton
? 'cursor-pointer'
: 'cursor-auto'}"
onclick={showExpandButton && !isExpanded ? toggleExpand : undefined}
type="button"
>
<Card
class="rounded-[1.125rem] !border-2 !border-dashed !border-border/50 bg-muted px-3.75 py-1.5 data-[multiline]:py-2.5"
data-multiline={isMultiline ? '' : undefined}
style="border: 2px dashed hsl(var(--border));"
>
<div
class="relative overflow-hidden transition-all duration-300 {isExpanded
? 'cursor-text select-text'
: 'select-none'}"
style={!isExpanded && showExpandButton
? `max-height: ${MAX_HEIGHT}px;`
: 'max-height: none;'}
>
{#if currentConfig.renderUserContentAsMarkdown}
<div bind:this={messageElement} class="text-md {isExpanded ? 'cursor-text' : ''}">
<MarkdownContent class="markdown-system-content" content={message.content} />
</div>
{:else}
<span
bind:this={messageElement}
class="text-md whitespace-pre-wrap {isExpanded ? 'cursor-text' : ''}"
>
{message.content}
</span>
{/if}

{#if !isExpanded && showExpandButton}
<div
class="pointer-events-none absolute right-0 bottom-0 left-0 h-48 bg-gradient-to-t from-muted to-transparent"
></div>
<div
class="pointer-events-none absolute right-0 bottom-4 left-0 flex justify-center opacity-0 transition-opacity group-hover/expand:opacity-100"
>
<Button
class="rounded-full px-4 py-1.5 text-xs shadow-md"
size="sm"
variant="outline"
>
Show full system message
</Button>
</div>
{/if}
</div>

{#if isExpanded && showExpandButton}
<div class="mb-2 flex justify-center">
<Button
class="rounded-full px-4 py-1.5 text-xs"
onclick={(e) => {
e.stopPropagation();
toggleExpand();
}}
size="sm"
variant="outline"
>
Collapse System Message
</Button>
</div>
{/if}
</Card>
</button>
</div>
{/if}

{#if message.timestamp}
<div class="max-w-[80%]">
<ChatMessageActions
actionsPosition="right"
{deletionInfo}
justify="end"
{onConfirmDelete}
{onCopy}
{onDelete}
{onEdit}
{onNavigateToSibling}
{onShowDeleteDialogChange}
{siblingInfo}
{showDeleteDialog}
role="user"
/>
</div>
{/if}
{/if}
</div>

@@ -145,7 +145,7 @@

{#if message.content.trim()}
<Card
class="max-w-[80%] rounded-[1.125rem] bg-primary px-3.75 py-1.5 text-primary-foreground data-[multiline]:py-2.5"
class="max-w-[80%] rounded-[1.125rem] border-none bg-primary px-3.75 py-1.5 text-primary-foreground data-[multiline]:py-2.5"
data-multiline={isMultiline ? '' : undefined}
>
{#if currentConfig.renderUserContentAsMarkdown}

@@ -2,6 +2,7 @@
import { ChatMessage } from '$lib/components/app';
import { chatStore } from '$lib/stores/chat.svelte';
import { conversationsStore, activeConversation } from '$lib/stores/conversations.svelte';
import { config } from '$lib/stores/settings.svelte';
import { getMessageSiblings } from '$lib/utils';

interface Props {

@@ -13,6 +14,7 @@
let { class: className, messages = [], onUserAction }: Props = $props();

let allConversationMessages = $state<DatabaseMessage[]>([]);
const currentConfig = config();

function refreshAllMessages() {
const conversation = activeConversation();

@@ -40,7 +42,12 @@
return [];
}

return messages.map((message) => {
// Filter out system messages if showSystemMessage is false
const filteredMessages = currentConfig.showSystemMessage
? messages
: messages.filter((msg) => msg.type !== 'system');

return filteredMessages.map((message) => {
const siblingInfo = getMessageSiblings(allConversationMessages, message.id);

return {

@@ -36,12 +36,6 @@
title: 'General',
icon: Settings,
fields: [
{ key: 'apiKey', label: 'API Key', type: 'input' },
{
key: 'systemMessage',
label: 'System Message (will be disabled if left empty)',
type: 'textarea'
},
{
key: 'theme',
label: 'Theme',

@@ -52,6 +46,12 @@
{ value: 'dark', label: 'Dark', icon: Moon }
]
},
{ key: 'apiKey', label: 'API Key', type: 'input' },
{
key: 'systemMessage',
label: 'System Message',
type: 'textarea'
},
{
key: 'pasteLongTextToFileLen',
label: 'Paste long text to file length',

@@ -95,7 +95,7 @@
</div>
{#if field.help || SETTING_CONFIG_INFO[field.key]}
<p class="mt-1 text-xs text-muted-foreground">
{field.help || SETTING_CONFIG_INFO[field.key]}
{@html field.help || SETTING_CONFIG_INFO[field.key]}
</p>
{/if}
{:else if field.type === 'textarea'}

@@ -112,13 +112,28 @@
value={String(localConfig[field.key] ?? '')}
onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
class="min-h-[100px] w-full md:max-w-2xl"
class="min-h-[10rem] w-full md:max-w-2xl"
/>

{#if field.help || SETTING_CONFIG_INFO[field.key]}
<p class="mt-1 text-xs text-muted-foreground">
{field.help || SETTING_CONFIG_INFO[field.key]}
</p>
{/if}

{#if field.key === 'systemMessage'}
<div class="mt-3 flex items-center gap-2">
<Checkbox
id="showSystemMessage"
checked={Boolean(localConfig.showSystemMessage ?? true)}
onCheckedChange={(checked) => onConfigChange('showSystemMessage', Boolean(checked))}
/>

<Label for="showSystemMessage" class="cursor-pointer text-sm font-normal">
Show system message in conversations
</Label>
</div>
{/if}
{:else if field.type === 'select'}
{@const selectedOption = field.options?.find(
(opt: { value: string; label: string; icon?: Component }) =>

@@ -8,6 +8,7 @@
import * as AlertDialog from '$lib/components/ui/alert-dialog';
import Input from '$lib/components/ui/input/input.svelte';
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import ChatSidebarActions from './ChatSidebarActions.svelte';

const sidebar = Sidebar.useSidebar();

@@ -98,6 +99,10 @@

await goto(`#/chat/${id}`);
}

function handleStopGeneration(id: string) {
chatStore.stopGenerationForChat(id);
}
</script>

<ScrollArea class="h-[100vh]">

@@ -132,6 +137,7 @@
onSelect={selectConversation}
onEdit={handleEditConversation}
onDelete={handleDeleteConversation}
onStop={handleStopGeneration}
/>
</Sidebar.MenuItem>
{/each}

@@ -1,6 +1,7 @@
<script lang="ts">
import { Trash2, Pencil, MoreHorizontal, Download, Loader2 } from '@lucide/svelte';
import { Trash2, Pencil, MoreHorizontal, Download, Loader2, Square } from '@lucide/svelte';
import { ActionDropdown } from '$lib/components/app';
import * as Tooltip from '$lib/components/ui/tooltip';
import { getAllLoadingChats } from '$lib/stores/chat.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { onMount } from 'svelte';

@@ -12,6 +13,7 @@
onDelete?: (id: string) => void;
onEdit?: (id: string) => void;
onSelect?: (id: string) => void;
onStop?: (id: string) => void;
}

let {

@@ -20,6 +22,7 @@
onDelete,
onEdit,
onSelect,
onStop,
isActive = false
}: Props = $props();

@@ -38,8 +41,14 @@
onDelete?.(conversation.id);
}

function handleStop(event: Event) {
event.stopPropagation();
onStop?.(conversation.id);
}

function handleGlobalEditEvent(event: Event) {
const customEvent = event as CustomEvent<{ conversationId: string }>;

if (customEvent.detail.conversationId === conversation.id && isActive) {
handleEdit(event);
}

@@ -88,8 +97,28 @@
>
<div class="flex min-w-0 flex-1 items-center gap-2">
{#if isLoading}
<Loader2 class="h-3.5 w-3.5 shrink-0 animate-spin text-muted-foreground" />
<Tooltip.Root>
<Tooltip.Trigger>
<div
class="stop-button flex h-4 w-4 shrink-0 cursor-pointer items-center justify-center rounded text-muted-foreground transition-colors hover:text-foreground"
onclick={handleStop}
onkeydown={(e) => e.key === 'Enter' && handleStop(e)}
role="button"
tabindex="0"
aria-label="Stop generation"
>
<Loader2 class="loading-icon h-3.5 w-3.5 animate-spin" />

<Square class="stop-icon hidden h-3 w-3 fill-current text-destructive" />
</div>
</Tooltip.Trigger>

<Tooltip.Content>
<p>Stop generation</p>
</Tooltip.Content>
</Tooltip.Root>
{/if}

<!-- svelte-ignore a11y_click_events_have_key_events -->
<!-- svelte-ignore a11y_no_static_element_interactions -->
<span class="truncate text-sm font-medium" onclick={handleMobileSidebarItemClick}>

@@ -147,5 +176,25 @@
opacity: 1 !important;
}
}

.stop-button {
:global(.stop-icon) {
display: none;
}

:global(.loading-icon) {
display: block;
}
}

&:is(:hover) .stop-button {
:global(.stop-icon) {
display: block;
}

:global(.loading-icon) {
display: none;
}
}
}
</style>

@@ -19,8 +19,10 @@ export { default as ChatMessage } from './chat/ChatMessages/ChatMessage.svelte';
export { default as ChatMessageActions } from './chat/ChatMessages/ChatMessageActions.svelte';
export { default as ChatMessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte';
export { default as ChatMessageStatistics } from './chat/ChatMessages/ChatMessageStatistics.svelte';
export { default as ChatMessageSystem } from './chat/ChatMessages/ChatMessageSystem.svelte';
export { default as ChatMessageThinkingBlock } from './chat/ChatMessages/ChatMessageThinkingBlock.svelte';
export { default as ChatMessages } from './chat/ChatMessages/ChatMessages.svelte';
export { default as MessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte';

export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte';
export { default as ChatScreenHeader } from './chat/ChatScreen/ChatScreenHeader.svelte';

@@ -337,19 +337,23 @@
line-height: 1.75;
}

div :global(:is(h1, h2, h3, h4, h5, h6):first-child) {
margin-top: 0;
}

/* Headers with consistent spacing */
div :global(h1) {
font-size: 1.875rem;
font-weight: 700;
margin: 1.5rem 0 0.75rem 0;
line-height: 1.2;
margin: 1.5rem 0 0.75rem 0;
}

div :global(h2) {
font-size: 1.5rem;
font-weight: 600;
margin: 1.25rem 0 0.5rem 0;
line-height: 1.3;
margin: 1.25rem 0 0.5rem 0;
}

div :global(h3) {

@@ -3,6 +3,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
// Do not use nested objects, keep it single level. Prefix the key if you need to group them.
apiKey: '',
systemMessage: '',
showSystemMessage: true,
theme: 'system',
showThoughtInProgress: false,
showToolCalls: false,

@@ -42,8 +43,9 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
};

export const SETTING_CONFIG_INFO: Record<string, string> = {
apiKey: 'Set the API Key if you are using --api-key option for the server.',
apiKey: 'Set the API Key if you are using <code>--api-key</code> option for the server.',
systemMessage: 'The starting message that defines how the model should behave.',
showSystemMessage: 'Display the system message at the top of each conversation.',
theme:
'Choose the color theme for the interface. You can choose between System (follows your device settings), Light, or Dark.',
pasteLongTextToFileLen:

@@ -89,7 +89,6 @@ export class ChatService {
custom,
timings_per_token,
// Config options
systemMessage,
disableReasoningFormat
} = options;

@@ -103,6 +102,7 @@ export class ChatService {
}
})
.filter((msg) => {
// Filter out empty system messages
if (msg.role === 'system') {
const content = typeof msg.content === 'string' ? msg.content : '';

@@ -112,10 +112,8 @@ export class ChatService {
return true;
});

const processedMessages = ChatService.injectSystemMessage(normalizedMessages, systemMessage);

const requestBody: ApiChatCompletionRequest = {
messages: processedMessages.map((msg: ApiChatMessageData) => ({
messages: normalizedMessages.map((msg: ApiChatMessageData) => ({
role: msg.role,
content: msg.content
})),

@@ -677,46 +675,6 @@ export class ChatService {
// Utilities
// ─────────────────────────────────────────────────────────────────────────────

/**
* Injects a system message at the beginning of the conversation if provided.
* Checks for existing system messages to avoid duplication.
*
* @param messages - Array of chat messages to process
* @param systemMessage - Optional system message to inject
* @returns Array of messages with system message injected at the beginning if provided
* @private
*/
private static injectSystemMessage(
messages: ApiChatMessageData[],
systemMessage?: string
): ApiChatMessageData[] {
const trimmedSystemMessage = systemMessage?.trim();

if (!trimmedSystemMessage) {
return messages;
}

if (messages.length > 0 && messages[0].role === 'system') {
if (messages[0].content !== trimmedSystemMessage) {
const updatedMessages = [...messages];
updatedMessages[0] = {
role: 'system',
content: trimmedSystemMessage
};
return updatedMessages;
}

return messages;
}

const systemMsg: ApiChatMessageData = {
role: 'system',
content: trimmedSystemMessage
};

return [systemMsg, ...messages];
}

/**
* Parses error response and creates appropriate error with context information
* @param response - HTTP response object

@@ -166,6 +166,49 @@ export class DatabaseService {
return rootMessage.id;
}

/**
* Creates a system prompt message for a conversation.
*
* @param convId - Conversation ID
* @param systemPrompt - The system prompt content (must be non-empty)
* @param parentId - Parent message ID (typically the root message)
* @returns The created system message
* @throws Error if systemPrompt is empty
*/
static async createSystemMessage(
convId: string,
systemPrompt: string,
parentId: string
): Promise<DatabaseMessage> {
const trimmedPrompt = systemPrompt.trim();
if (!trimmedPrompt) {
throw new Error('Cannot create system message with empty content');
}

const systemMessage: DatabaseMessage = {
id: uuid(),
convId,
type: 'system',
timestamp: Date.now(),
role: 'system',
content: trimmedPrompt,
parent: parentId,
thinking: '',
children: []
};

await db.messages.add(systemMessage);

const parentMessage = await db.messages.get(parentId);
if (parentMessage) {
await db.messages.update(parentId, {
children: [...parentMessage.children, systemMessage.id]
});
}

return systemMessage;
}

/**
* Deletes a conversation and all its messages.
*

@@ -2,7 +2,11 @@ import { DatabaseService, ChatService } from '$lib/services';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { config } from '$lib/stores/settings.svelte';
import { contextSize, isRouterMode } from '$lib/stores/server.svelte';
import { selectedModelName, modelsStore } from '$lib/stores/models.svelte';
import {
selectedModelName,
modelsStore,
selectedModelContextSize
} from '$lib/stores/models.svelte';
import {
normalizeModelName,
filterByLeafNodeId,

@@ -261,6 +265,13 @@ class ChatStore {
return activeState.contextTotal;
}

if (isRouterMode()) {
const modelContextSize = selectedModelContextSize();
if (modelContextSize && modelContextSize > 0) {
return modelContextSize;
}
}

const propsContextSize = contextSize();
if (propsContextSize && propsContextSize > 0) {
return propsContextSize;

@@ -458,6 +469,14 @@ class ChatStore {
onError?: (error: Error) => void,
modelOverride?: string | null
): Promise<void> {
// Ensure model props are cached before streaming (for correct n_ctx in processing info)
if (isRouterMode()) {
const modelName = modelOverride || selectedModelName();
if (modelName && !modelsStore.getModelProps(modelName)) {
await modelsStore.fetchModelProps(modelName);
}
}

let streamedContent = '';
let streamedReasoningContent = '';
let streamedToolCallContent = '';

@@ -624,6 +643,22 @@ class ChatStore {
this.clearChatStreaming(currentConv.id);

try {
if (isNewConversation) {
const rootId = await DatabaseService.createRootMessage(currentConv.id);
const currentConfig = config();
const systemPrompt = currentConfig.systemMessage?.toString().trim();

if (systemPrompt) {
const systemMessage = await DatabaseService.createSystemMessage(
currentConv.id,
systemPrompt,
rootId
);

conversationsStore.addMessageToActive(systemMessage);
}
}

const userMessage = await this.addMessage('user', content, 'text', '-1', extras);
if (!userMessage) throw new Error('Failed to add user message');
if (isNewConversation && content)

@@ -666,13 +701,17 @@ class ChatStore {

if (!activeConv) return;

await this.savePartialResponseIfNeeded(activeConv.id);
await this.stopGenerationForChat(activeConv.id);
}

async stopGenerationForChat(convId: string): Promise<void> {
await this.savePartialResponseIfNeeded(convId);

this.stopStreaming();
this.abortRequest(activeConv.id);
this.setChatLoading(activeConv.id, false);
this.clearChatStreaming(activeConv.id);
this.clearProcessingState(activeConv.id);
this.abortRequest(convId);
this.setChatLoading(convId, false);
this.clearChatStreaming(convId);
this.clearProcessingState(convId);
}

/**

@@ -999,14 +1038,20 @@ class ChatStore {
const activeConv = conversationsStore.activeConversation;
if (!activeConv || this.isLoading) return;

const result = this.getMessageByIdWithRole(messageId, 'user');
let result = this.getMessageByIdWithRole(messageId, 'user');

if (!result) {
result = this.getMessageByIdWithRole(messageId, 'system');
}

if (!result) return;
const { message: msg } = result;

try {
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
const isFirstUserMessage = rootMessage && msg.parent === rootMessage.id;
const isFirstUserMessage =
msg.role === 'user' && rootMessage && msg.parent === rootMessage.id;

const parentId = msg.parent || rootMessage?.id;
if (!parentId) return;

@@ -1037,7 +1082,10 @@ class ChatStore {
);
}
await conversationsStore.refreshActiveMessages();
await this.generateResponseForMessage(newMessage.id);

if (msg.role === 'user') {
await this.generateResponseForMessage(newMessage.id);
}
} catch (error) {
console.error('Failed to edit message with branching:', error);
}

@@ -158,6 +158,22 @@ class ModelsStore {
return this.modelPropsCache.get(modelId) ?? null;
}

/**
* Get context size (n_ctx) for a specific model from cached props
*/
getModelContextSize(modelId: string): number | null {
const props = this.modelPropsCache.get(modelId);
return props?.default_generation_settings?.n_ctx ?? null;
}

/**
* Get context size for the currently selected model or null if no model is selected
*/
get selectedModelContextSize(): number | null {
if (!this.selectedModelName) return null;
return this.getModelContextSize(this.selectedModelName);
}

/**
* Check if props are being fetched for a model
*/

@@ -579,3 +595,4 @@ export const loadedModelIds = () => modelsStore.loadedModelIds;
export const loadingModelIds = () => modelsStore.loadingModelIds;
export const propsCacheVersion = () => modelsStore.propsCacheVersion;
export const singleModelName = () => modelsStore.singleModelName;
export const selectedModelContextSize = () => modelsStore.selectedModelContextSize;

2 tools/server/webui/src/lib/types/chat.d.ts vendored
@@ -1,4 +1,4 @@
export type ChatMessageType = 'root' | 'text' | 'think';
export type ChatMessageType = 'root' | 'text' | 'think' | 'system';
export type ChatRole = 'user' | 'assistant' | 'system';

export interface ChatUploadedFile {