Compare commits

...

20 Commits

Author SHA1 Message Date
Ruben Ortlam
d5344395d0 benchmark 2026-04-08 18:26:50 +02:00
Martin Klacer
5c4aae66e1 devops: kleidiai: provide KleidiAI-Enabled ARM Release Artifact (#21259)
* Unified macOS release setup with strategy-matrix block
 * Added KleidiAI arm64 macOS release definition


Change-Id: I05520889ffc646488a178d06817a17f29274465a

Signed-off-by: Martin Klacer <martin.klacer@arm.com>
2026-04-08 13:06:12 +08:00
Aman Gupta
c5ce4bc227 CUDA: make cuda graphs props check faster (#21472)
* CUDA: compute fast hash instead of expensive props check

* use seen node

* use memcp
2026-04-08 09:05:51 +08:00
iacopPBK
66c4f9ded0 ggml-cuda: ds_read_b128 for q4_0 and q4_1 mmq kernels (#21168)
* ds_read_b128 for q4_0 and q4_1 mmq kernels

     Current for loop generates ds_read_b32 instructions with hip compiler, the new solution generates ds_read_b128 instructions for the same operation, saving some LDS bandwidth. Tested on MI50 and RX6800XT, its faster on both.

* Vectorized lds load update: used ggml_cuda_get_max_cpy_bytes and ggml_cuda_memcpy_1 functions for generic implementation

* Explicit for loop in mmq, renamed vec into tmp

* Fixed max_cpy usage in the loading loop

* Fixed typo in q4_1 kernel

* Update ggml/src/ggml-cuda/mmq.cuh

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/mmq.cuh

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/mmq.cuh

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Renoved trailing white line 500

* Update mmq.cuh removed other whitelines

* Remove trailing whitespaces

---------

Co-authored-by: iacopPBK <iacopPBK@users.noreply.github.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Co-authored-by: iacopPBK <iacop@deneb.com>
2026-04-07 21:47:42 +02:00
Daniel Bevenius
93bdc61563 gguf-py : fix missing comma after bad merge in tensor-mapping (#21558)
This commit adds a missing comma in the vision encoder attention qkv
block.

The motivation for this change is that without the comma there will be
a string concatenation of the Kimi-K2.5 and the Nemotron Nano v2 VL
tensor mappings which will be broken.
2026-04-07 21:24:25 +02:00
Georgi Gerganov
4eb19514dd kv-cache : support attention rotation for heterogeneous iSWA (#21513)
* kv-cache : support attention rotation for heterogeneous iSWA

* cont : remove assert
2026-04-07 20:31:28 +03:00
Reese Levine
957d717ce5 ggml-webgpu: parameterize submission size and add iOS specific limits (#21533)
* Work towards removing bitcast

* Move rest of existing types over

* Add timeout back to wait and remove synchronous set_tensor/memset_tensor

* move to unpackf16 for wider compatibility

* cleanup

* Remove deadlock condition in free_bufs

* Start work on removing parameter buffer pools

* Simplify and optimize further

* simplify profile futures

* Fix stride

* Try using a single command buffer per batch

* formatting

* Add parameters for different browsers in-flight submissions

* Update handling of batch size too

* Throttle ios as much as possible

* Increase timeout for llvm-pipe testing
2026-04-07 20:30:01 +03:00
Aman Gupta
de1aa6fa73 CUDA: check for buffer overlap before fusing (#21566)
* CUDA: check for buffer overlap before fusing

* use ggml_cuda_check_fusion_memory_ranges
2026-04-08 00:57:04 +08:00
Aaron Teo
69c28f1547 llama-server: fix model params not propagated (#21509)
Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
2026-04-07 21:39:41 +08:00
Son H. Nguyen
0d049d6a92 unicode : add custom Qwen2 regex handler to fix segfault on long input (#21257)
* unicode : add custom Qwen2 regex handler to fix segfault on long input

std::regex uses recursive backtracking internally, which causes a stack
overflow (segfault) when tokenizing long sequences of repeated characters
(e.g. 43K 'A's). The Qwen2 tokenizer regex differs from Llama3 only in
the digit pattern (\p{N} vs \p{N}{1,3}), so it was falling through to
the std::regex fallback path instead of using a custom handler.

Add unicode_regex_split_custom_qwen2() following the established pattern
used by gpt2, llama3, kimi_k2, and afmoe custom handlers.

Closes: https://github.com/ggml-org/llama.cpp/issues/21113

* cont : remove TODO comment

* cont : update comment to reflect original regex

* use the correct regex in the comment this time... [no ci]

---------

Co-authored-by: Aldehir Rojas <hello@alde.dev>
2026-04-07 16:13:38 +03:00
Johannes Gäßler
a8ec0df461 llama: remove per-arch tensor name lists (#21531) 2026-04-07 15:02:03 +02:00
Georgi Gerganov
e8f5082697 server : fix restore for checkpoints with pos_min == 0 (#21510) 2026-04-07 15:29:17 +03:00
Georgi Gerganov
22fc79134e ggml : deprecate GGML_OP_ADD1 (#21363)
* ggml : deprecate GGML_OP_ADD1

* cont : remove tests

* cont : re-enable vulkan check
2026-04-07 15:28:27 +03:00
Tom Overlund
2a619f6fbc ggml: Vulkan build, Linux -- output error string for errno on fork failure (#20868) (#20904) 2026-04-07 13:54:55 +02:00
mkoker
edd4d9bca5 vulkan: add FA dequant for q4_1, q5_0, q5_1, iq4_nl (#21029)
Add dequantize4() implementations for Q4_1, Q5_0, Q5_1, and IQ4_NL
in the flash attention base shader. Register them in the shader
generator, pipeline creation, and enable in the scalar/coopmat1 FA
support check.
2026-04-07 13:41:29 +02:00
Aldehir Rojas
482192f12d webui : store reasoning_content so it is sent back in subsequent requests (#21249) 2026-04-07 13:32:44 +02:00
Antoine Viallon
71a81f6fcc ggml-cuda : fix CDNA2 compute capability constant for gfx90a (MI210) (#21519)
GGML_CUDA_CC_CDNA2 was set to 0x910
Fix by setting the constant to 0x90a to match the actual gfx90a ISA.
2026-04-07 12:18:55 +02:00
Aleksander Grygier
ecce0087da fix: Detect streaming state in reasoning content blocks (#21549) 2026-04-07 12:04:41 +02:00
Kabir08
d1f82e382d Fix rtl text rendering (#21382)
* Fix Arabic RTL text rendering in web UI

- Add dir='auto' attributes to markdown containers and blocks
- Implement post-processing to add dir='auto' to all text elements
- Replace directional CSS properties with logical properties for proper RTL list alignment
- Ensure bidirectional text support for mixed Arabic/English content

* Clean up commented duplicate function

Remove the commented-out duplicate transformMdastNode function
that was left over from refactoring.

* Fix Arabic RTL text rendering in web UI

- Add dir='auto' attributes to markdown containers and blocks
- Implement post-processing to add dir='auto' to all text elements
- Replace directional CSS properties with logical properties for proper RTL list alignment
- Minor code formatting improvements

This ensures bidirectional text support for mixed Arabic/English content in the llama.cpp web UI.

* Implement rehype plugin for comprehensive RTL text support

- Add rehypeRtlSupport plugin that applies dir='auto' to all elements with children
- Replace DOMParser-based approach with efficient HAST tree processing
- Remove hardcoded element lists for better maintainability
- Ensure proper bidirectional text rendering for mixed RTL/LTR content

* Fix RTL text rendering with rehype plugin and cleanup

* fix: prettier formatting
2026-04-07 11:37:20 +02:00
PMZFX
0988accf82 [SYCL] Add Q8_0 reorder optimization (~3x tg speedup on Intel Arc) (#21527)
Extend the existing reorder optimization to Q8_0. The reorder
separates scale factors from weight data for coalesced memory
access -- was implemented for Q4_0/Q4_K/Q6_K but Q8_0 was missing.

On Arc Pro B70 (Xe2), Q8_0 tg goes from 4.88 to 15.24 t/s (3.1x)
on Qwen3.5-27B. BW utilization: 21% -> 66%.

The key fix beyond the kernels: Q8_0 was missing from the type
check in ggml_backend_sycl_buffer_init_tensor() that allocates
the extra struct carrying the reorder flag -- so the optimization
was silently skipped.

AI (Claude) was used to assist with root cause investigation and
writing the kernel code. All code was human-reviewed and tested
on real hardware.

Fixes: #21517
2026-04-07 16:12:49 +08:00
39 changed files with 1660 additions and 2514 deletions

View File

@@ -36,8 +36,26 @@ env:
CMAKE_ARGS: "-DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_TOOLS=ON -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON"
jobs:
macOS-arm64:
runs-on: macos-14
macOS-cpu:
strategy:
matrix:
include:
- build: 'arm64'
arch: 'arm64'
os: macos-14
defines: "-DGGML_METAL_USE_BF16=ON -DGGML_METAL_EMBED_LIBRARY=ON"
- build: 'arm64-kleidiai'
arch: 'arm64'
os: macos-14
defines: "-DGGML_METAL_USE_BF16=ON -DGGML_METAL_EMBED_LIBRARY=ON -DGGML_CPU_KLEIDIAI=ON"
- build: 'x64'
arch: 'x64'
os: macos-15-intel
# Metal is disabled on x64 due to intermittent failures with Github runners not having a GPU:
# https://github.com/ggml-org/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313
defines: "-DGGML_METAL=OFF -DCMAKE_OSX_DEPLOYMENT_TARGET=13.3"
runs-on: ${{ matrix.os }}
steps:
- name: Clone
@@ -49,7 +67,7 @@ jobs:
- name: ccache
uses: ggml-org/ccache-action@v1.2.21
with:
key: macOS-latest-arm64
key: macOS-latest-${{ matrix.arch }}
evict-old-files: 1d
- name: Build
@@ -57,13 +75,11 @@ jobs:
run: |
sysctl -a
cmake -B build \
${{ matrix.defines }} \
-DCMAKE_INSTALL_RPATH='@loader_path' \
-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_BUILD_BORINGSSL=ON \
-DGGML_METAL_USE_BF16=ON \
-DGGML_METAL_EMBED_LIBRARY=ON \
-DGGML_RPC=ON \
${{ env.CMAKE_ARGS }}
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
@@ -75,61 +91,13 @@ jobs:
id: pack_artifacts
run: |
cp LICENSE ./build/bin/
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-${{ matrix.build }}.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts
uses: actions/upload-artifact@v6
with:
path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz
name: llama-bin-macos-arm64.tar.gz
macOS-x64:
runs-on: macos-15-intel
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: ccache
uses: ggml-org/ccache-action@v1.2.21
with:
key: macOS-latest-x64
evict-old-files: 1d
- name: Build
id: cmake_build
run: |
sysctl -a
# Metal is disabled due to intermittent failures with Github runners not having a GPU:
# https://github.com/ggml-org/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313
cmake -B build \
-DCMAKE_INSTALL_RPATH='@loader_path' \
-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_BUILD_BORINGSSL=ON \
-DGGML_METAL=OFF \
-DGGML_RPC=ON \
-DCMAKE_OSX_DEPLOYMENT_TARGET=13.3
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
- name: Determine tag name
id: tag
uses: ./.github/actions/get-tag-name
- name: Pack artifacts
id: pack_artifacts
run: |
cp LICENSE ./build/bin/
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts
uses: actions/upload-artifact@v6
with:
path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz
name: llama-bin-macos-x64.tar.gz
path: llama-${{ steps.tag.outputs.name }}-bin-macos-${{ matrix.build }}.tar.gz
name: llama-bin-macos-${{ matrix.build }}.tar.gz
ubuntu-cpu:
strategy:
@@ -1003,8 +971,7 @@ jobs:
- ubuntu-cpu
- ubuntu-vulkan
- ubuntu-24-openvino
- macOS-arm64
- macOS-x64
- macOS-cpu
- ios-xcode-build
- openEuler-cann
@@ -1079,6 +1046,7 @@ jobs:
**macOS/iOS:**
- [macOS Apple Silicon (arm64)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz)
- [macOS Apple Silicon (arm64, KleidiAI enabled)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-arm64-kleidiai.tar.gz)
- [macOS Intel (x64)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz)
- [iOS XCFramework](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-xcframework.zip)

View File

@@ -223,6 +223,7 @@ option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug ou
option(GGML_VULKAN_SHADER_DEBUG_INFO "ggml: enable Vulkan shader debug info" OFF)
option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF)
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
option(GGML_VULKAN_COPY_TESTS "ggml: run Vulkan cross-device copy benchmarks" OFF)
option(GGML_WEBGPU "ggml: use WebGPU" OFF)
option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF)
option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)" OFF)

View File

@@ -902,15 +902,17 @@ extern "C" {
struct ggml_tensor * b,
struct ggml_tensor * ids);
GGML_API struct ggml_tensor * ggml_add1(
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_add1(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * b),
"use ggml_add instead");
GGML_API struct ggml_tensor * ggml_add1_inplace(
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_add1_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * b),
"use ggml_add_inplace instead");
// dst = a
// view(dst, nb1, nb2, nb3, offset) += b

View File

@@ -65,7 +65,7 @@
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x90a) // MI210 (gfx90a), minimum acc register renaming
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300
// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32
@@ -1157,19 +1157,6 @@ struct ggml_tensor_extra_gpu {
#define USE_CUDA_GRAPH
#endif
struct ggml_cuda_graph_node_properties {
void * node_data;
ggml_op node_op;
enum ggml_type node_type;
int32_t flags;
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];
void * src_data[GGML_MAX_SRC];
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
};
static_assert(std::is_trivial<ggml_cuda_graph_node_properties>::value, "ggml_cuda_graph_node_properties must be trivial");
struct ggml_cuda_graph {
#ifdef USE_CUDA_GRAPH
~ggml_cuda_graph() {
@@ -1186,13 +1173,7 @@ struct ggml_cuda_graph {
std::vector<cudaGraphNode_t> nodes;
bool disable_due_to_gpu_arch = false;
bool warmup_complete = false;
std::vector<ggml_cuda_graph_node_properties> props;
// these are extra tensors (inputs) that participate in the ggml graph but are not nodes
// they properties also have to match in order to be able to safely reuse a CUDA graph
// ref: https://github.com/ggml-org/llama.cpp/pull/18583
// ref: https://github.com/ggml-org/llama.cpp/pull/19165
std::vector<ggml_cuda_graph_node_properties> extra;
std::vector<ggml_tensor> nodes_copy;
bool is_enabled() const {
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);

View File

@@ -82,7 +82,6 @@
#include <cstdlib>
#include <string>
#include <vector>
#include <unordered_set>
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
@@ -2969,74 +2968,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
return use_cuda_graph;
}
static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
props->node_data = node->data;
props->node_op = node->op;
props->node_type = node->type;
props->flags = node->flags;
for (int i = 0; i < GGML_MAX_DIMS; i++) {
props->ne[i] = node->ne[i];
props->nb[i] = node->nb[i];
}
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (!node->src[i]) {
continue;
}
props->src_data[i] = node->src[i]->data;
}
memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
}
static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
if (node->data != props->node_data && node->op != GGML_OP_VIEW) {
return false;
}
if (node->op != props->node_op) {
return false;
}
if (node->type != props->node_type) {
return false;
}
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (node->ne[i] != props->ne[i]) {
return false;
}
if (node->nb[i] != props->nb[i]) {
return false;
}
}
if (node->op != GGML_OP_VIEW) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (!node->src[i]) {
if (props->src_data[i] != nullptr) {
return false;
}
continue;
}
if (node->src[i]->data != props->src_data[i]) {
return false;
}
}
}
if (memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
return false;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) {
return false;
}
return true;
}
static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
return cgraph->nodes[0];
}
@@ -3048,52 +2979,18 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
// Check if the graph size has changed
if (graph->props.size() != (size_t)cgraph->n_nodes) {
if ((int)graph->nodes_copy.size() != cgraph->n_nodes) {
res = true;
graph->props.resize(cgraph->n_nodes);
graph->nodes_copy.resize(cgraph->n_nodes);
}
// Loop over nodes in GGML graph to determine if CUDA graph update is required
// and store properties to allow this comparison for the next token
std::unordered_set<ggml_tensor *> seen_node;
std::vector<ggml_tensor *> srcs_extra;
for (int i = 0; i < cgraph->n_nodes; i++) {
bool props_match = true;
seen_node.insert(cgraph->nodes[i]);
if (!res) {
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
}
if (!props_match) {
res = true;
}
ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
ggml_tensor * src = cgraph->nodes[i]->src[src_idx];
if (src && seen_node.find(src) == seen_node.end()) {
srcs_extra.push_back(src);
if (memcmp(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
res = true;
}
}
}
if (graph->extra.size() != (size_t) srcs_extra.size()) {
res = true;
graph->extra.resize(srcs_extra.size());
}
for (size_t i = 0; i < srcs_extra.size(); ++i) {
bool props_match = true;
if (!res) {
props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]);
}
if (!props_match) {
res = true;
}
ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]);
memcpy(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor));
}
return res;
@@ -3308,6 +3205,71 @@ static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int nod
return true;
}
// returns whether the write (out) nodes overwrite the read nodes in operation
static bool ggml_cuda_check_fusion_memory_ranges(const ggml_cgraph * cgraph,
const int node_idx,
const int node_count,
const int * out_nodes,
const int out_count,
const bool is_topk_moe = false) {
auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) {
const int64_t a_start = (int64_t) a->data;
const int64_t a_end = a_start + ggml_backend_buft_get_alloc_size(a->buffer->buft, a);
const int64_t b_start = (int64_t) b->data;
const int64_t b_end = b_start + ggml_backend_buft_get_alloc_size(b->buffer->buft, b);
if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) {
return true;
}
return false;
};
bool is_ok = true;
// exception for topk-moe, as each row is read entirely before writing
if (ggml_nrows(cgraph->nodes[node_idx]) == 1 && is_topk_moe) {
return true;
}
for (int i = 0; i < out_count; ++i) {
const ggml_tensor * dst = cgraph->nodes[out_nodes[i]];
for (int j = node_idx; j < node_idx + node_count; ++j) {
// Loop over all srcs of all nodes in the fusion. If the src overlaps
// the destination and the src is not an intermediate node that's being
// elided, then disable fusion.
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
const ggml_tensor * src = cgraph->nodes[j]->src[src_idx];
if (!src || src->op == GGML_OP_NONE) {
continue;
}
if (nodes_overlap(dst, src)) {
bool found = false;
for (int k = node_idx; k < j; ++k) {
if (cgraph->nodes[k] == src) {
found = true;
break;
}
}
if (!found) {
is_ok = false;
break;
}
}
}
}
}
return is_ok;
}
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
int node_idx,
std::initializer_list<enum ggml_op> ops,
@@ -3337,7 +3299,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
const ggml_tensor * glu = cgraph->nodes[node_idx + 4];
if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
return true;
int out_nodes[] = { node_idx + 4 };
return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1);
}
}
@@ -3348,7 +3311,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
return true;
int out_nodes[] = { node_idx + 2 };
return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1);
}
}
@@ -3474,69 +3438,6 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
return false;
}
// returns whether the write (out) nodes overwrite the read nodes in operation
static bool ggml_cuda_check_fusion_memory_ranges(ggml_cgraph * cgraph,
int node_idx,
int node_count,
int * out_nodes,
int out_count) {
auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) {
const int64_t a_start = (int64_t) a->data;
const int64_t a_end = a_start + ggml_nbytes(a);
const int64_t b_start = (int64_t) b->data;
const int64_t b_end = b_start + ggml_nbytes(b);
if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) {
return true;
}
return false;
};
bool is_ok = true;
// for nrows=1, all fusion operations correctly read the src before writing dst or do it elementwise, so we should be ok
if (ggml_nrows(cgraph->nodes[node_idx]) == 1) {
return true;
}
for (int i = 0; i < out_count; ++i) {
const ggml_tensor * dst = cgraph->nodes[out_nodes[i]];
for (int j = node_idx; j < node_idx + node_count; ++j) {
// Loop over all srcs of all nodes in the fusion. If the src overlaps
// the destination and the src is not an intermediate node that's being
// elided, then disable fusion.
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
const ggml_tensor * src = cgraph->nodes[j]->src[src_idx];
if (!src || src->op == GGML_OP_NONE) {
continue;
}
if (nodes_overlap(dst, src)) {
bool found = false;
for (int k = node_idx; k < j; ++k) {
if (cgraph->nodes[k] == src) {
found = true;
break;
}
}
if (!found) {
is_ok = false;
break;
}
}
}
}
}
return is_ok;
}
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
bool graph_evaluated_or_captured = false;
@@ -3734,7 +3635,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(node, logits, weights, ids) &&
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;
@@ -3750,7 +3651,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
int out_nodes[2] = { i + 1, i + 5 };
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) &&
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;

View File

@@ -386,17 +386,25 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;
const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
int u[2*VDR_Q4_0_Q8_1_MMQ];
#pragma unroll
for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes();
constexpr int mcpy_int = max_cpy / sizeof(int);
static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ");
int tmp0[4], tmp1[4];
#pragma unroll
for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) {
ggml_cuda_memcpy_1<max_cpy>(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] );
ggml_cuda_memcpy_1<max_cpy>(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0 + l0 * mcpy_int]);
}
u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3];
u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3];
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
(&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
@@ -489,17 +497,25 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;
const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
int u[2*VDR_Q4_1_Q8_1_MMQ];
#pragma unroll
for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes();
constexpr int mcpy_int = max_cpy / sizeof(int);
static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ");
int tmp0[4], tmp1[4];
#pragma unroll
for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) {
ggml_cuda_memcpy_1<max_cpy>(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] );
ggml_cuda_memcpy_1<max_cpy>(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_1 + l0 * mcpy_int]);
}
u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3];
u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3];
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
(&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
@@ -4170,3 +4186,4 @@ void ggml_cuda_op_mul_mat_q(
const int64_t src1_padded_row_size, cudaStream_t stream);
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);

View File

@@ -143,6 +143,22 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib,
#endif // GGML_SYCL_F16
}
static __dpct_inline__ void dequantize_q8_0_reorder(const void *d_ptr, const int64_t ib, const void *qs,
const int iqs, dfloat2 &v) {
const dfloat d = (const dfloat)*((const sycl::half*)d_ptr + ib);
v.x() = ((const int8_t *)qs)[iqs + 0];
v.y() = ((const int8_t *)qs)[iqs + 1];
#ifdef GGML_SYCL_F16
v.s0() *= d;
v.s1() *= d;
#else
v.x() *= d;
v.y() *= d;
#endif // GGML_SYCL_F16
}
static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib,
const int iqs, dfloat2 &v) {
const block_q8_0 * x = (const block_q8_0 *) vx;

View File

@@ -972,6 +972,103 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
}
}
static void dequantize_mul_mat_vec_q8_0_sycl_reorder(const void *vx, const dfloat *y,
float *dst, const int ncols,
const int nrows,
dpct::queue_ptr stream) {
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
// Q8_0 reorder layout: [all qs (ncols*nrows bytes)][all d values]
// Cannot reuse dequantize_mul_mat_vec_reorder template because it has
// Q4_0-specific constants hardcoded (d_ptr offset and qs stride).
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1);
if (row >= nrows) return;
const int tid = item_ct1.get_local_id(2);
const int iter_stride = 8*2*GGML_SYCL_DMMV_X;
const int vals_per_iter = iter_stride / WARP_SIZE;
const int ncols_left = ncols % (QK8_0*WARP_SIZE);
const int ncols_align = ncols - ncols_left;
#ifdef GGML_SYCL_F16
sycl::half2 tmp = {0.0f, 0.0f};
#else
float tmp = 0.0f;
#endif
const char *d_ptr = (const char*)vx + ncols*nrows; // d after all qs
int i = 0;
for (i = 0; i < ncols_align; i += iter_stride) {
const int col = i + vals_per_iter*tid;
const int ib = (row*ncols + col)/QK8_0;
const int iqs = col % QK8_0;
#pragma unroll
for (int j = 0; j < vals_per_iter; j += 2) {
dfloat2 v;
dequantize_q8_0_reorder((const void *)d_ptr, ib, (const void *)vx,
ib * QK8_0 + iqs + j, v);
#ifdef GGML_SYCL_F16
dfloat2 t1{y[col + j + 0], y[col + j + 1]};
tmp += v * t1;
#else
tmp += v.x() * y[col + j + 0];
tmp += v.y() * y[col + j + 1];
#endif
}
}
// handle remaining columns
for (; i < ncols; i += iter_stride) {
if (tid >= ncols_left/QK8_0) continue;
const int col = i + vals_per_iter*tid;
const int ib = (row*ncols + col)/QK8_0;
const int iqs = col % QK8_0;
#pragma unroll
for (int j = 0; j < vals_per_iter; j += 2) {
dfloat2 v;
dequantize_q8_0_reorder((const void *)d_ptr, ib, (const void *)vx,
ib * QK8_0 + iqs + j, v);
#ifdef GGML_SYCL_F16
dfloat2 t1{y[col + j + 0], y[col + j + 1]};
tmp += v * t1;
#else
tmp += v.x() * y[col + j + 0];
tmp += v.y() * y[col + j + 1];
#endif
}
}
// reduce
const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
for (int mask = mask_start; mask > 0; mask >>= 1) {
tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}
if (tid == 0) {
#ifdef GGML_SYCL_F16
dst[row] = tmp.x() + tmp.y();
#else
dst[row] = tmp;
#endif
}
});
}
}
static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
float *dst, const int ncols,
const int nrows,
@@ -1122,7 +1219,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q8_0:
dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
dequantize_mul_mat_vec_q8_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
} else {
dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
}
break;
case GGML_TYPE_Q2_K:
dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);

View File

@@ -411,7 +411,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
assert(tensor->view_src->buffer->buft == buffer->buft);
return GGML_STATUS_SUCCESS;
}
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
!g_ggml_sycl_disable_optimize) {
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
tensor->extra = extra;
@@ -3254,6 +3254,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
return true;
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q6_K:
@@ -3266,6 +3267,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
return true;
default:
return false;
@@ -3275,6 +3277,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q6_K:
return true;
@@ -3364,6 +3367,40 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr
sycl_ext_free(stream, tmp_buf);
}
static void reorder_qw_q8_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
dpct::queue_ptr stream) {
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
sycl::event copy_event;
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
if (!g_ggml_sycl_use_async_mem_op) {
copy_event.wait();
}
GGML_ASSERT((size % sizeof(block_q8_0) == 0));
GGML_ASSERT((offset % sizeof(block_q8_0) == 0));
int offset_blks = offset / sizeof(block_q8_0);
auto qs_ptr = data_device + offset_blks * QK8_0;
auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows) + offset_blks;
auto reorder_event = stream->parallel_for(
size / sizeof(block_q8_0),
[=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
const block_q8_0* x = (const block_q8_0*)tmp_buf;
const int ib = i;
for (int j = 0; j < QK8_0; j++)
{
*((int8_t*)qs_ptr + ib * QK8_0 + j) = x[ib].qs[j];
}
*(d_ptr + ib) = x[ib].d;
});
if (!g_ggml_sycl_use_async_mem_op) {
reorder_event.wait_and_throw();
}
sycl_ext_free(stream, tmp_buf);
}
static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
GGML_ASSERT(size % sizeof(block_q4_K) == 0);
GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
@@ -3460,6 +3497,9 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
case GGML_TYPE_Q4_0:
reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
break;
case GGML_TYPE_Q8_0:
reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream);
break;
case GGML_TYPE_Q4_K:
reorder_qw_q4_k(data_device, size, 0, stream);
break;

View File

@@ -679,6 +679,25 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
}
}
static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
const int nrows, dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK8_0 == 0);
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
constexpr size_t num_subgroups = 16;
GGML_ASSERT(block_num_y % num_subgroups == 0);
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0>>(vx, vy, dst, ncols, nrows,
nd_item);
});
});
}
static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
float *dst, const int ncols,
const int nrows,
@@ -1101,7 +1120,13 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
case GGML_TYPE_Q8_0:
mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n");
reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
} else {
mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
}
break;
case GGML_TYPE_Q2_K:
mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);

View File

@@ -105,6 +105,27 @@ template <> struct block_q_t<GGML_TYPE_Q6_K> {
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
};
template <> struct block_q_t<GGML_TYPE_Q8_0> {
struct traits {
static constexpr uint32_t qk = QK8_0; // 32
static constexpr uint32_t qi = QI8_0; // 8
static constexpr uint32_t qr = QR8_0; // 1
static constexpr uint32_t vdr_mmvq = 4;
};
// Q8_0 reorder layout: [qs0|qs1|...|qsN][d0|d1|...|dN]
// Each block has 32 int8 weights (32 bytes) followed by all scales
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
return { block_index * QK8_0, 0 };
}
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
return { (ncols * nrows) + block_index * sizeof(ggml_half), 0 };
}
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } // 1
};
} // namespace ggml_sycl_reordered
#endif // GGML_SYCL_QUANTS_HPP

View File

@@ -351,6 +351,46 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
};
};
template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0> {
static constexpr ggml_type gtype = GGML_TYPE_Q8_0;
using q8_0_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q8_0>;
using q8_0_traits = typename q8_0_block::traits;
__dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int * v, const int * u, const float & d8_0, const sycl::half2 & ds8) {
int sumi = 0;
#pragma unroll
for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) {
// Q8_0 values are signed int8, no nibble extraction needed
// Direct dp4a: each int packs 4 int8 values
sumi = dpct::dp4a(v[i], u[i], sumi);
}
const sycl::float2 ds8f = ds8.convert<float, sycl::rounding_mode::automatic>();
// Q8_0 has no bias term (values are signed), so just scale
return d8_0 * sumi * ds8f.x();
}
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
const sycl::half2 * q8_1_ds, const int & iqs) {
const int8_t * bq8_0 = static_cast<const int8_t *>(vbq) + ibx_offset.first;
const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset.first));
int v[q8_0_traits::vdr_mmvq];
int u[q8_0_traits::vdr_mmvq];
#pragma unroll
for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) {
v[i] = get_int_from_int8(bq8_0, iqs + i);
u[i] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i);
}
return vec_dot_q8_0_q8_1_impl(v, u, d, *q8_1_ds);
};
};
static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales,
const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1,
const int & iqs) {

View File

@@ -120,6 +120,10 @@ if (Vulkan_FOUND)
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
endif()
if (GGML_VULKAN_COPY_TESTS)
add_compile_definitions(GGML_VULKAN_COPY_TESTS)
endif()
# Set up toolchain for host compilation whether cross-compiling or not
if (CMAKE_CROSSCOMPILING)
if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN)

View File

@@ -1,9 +1,12 @@
#include "ggml-vulkan.h"
#include <vulkan/vulkan_core.h>
#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_CHECK_RESULTS)
#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_CHECK_RESULTS) || defined(GGML_VULKAN_COPY_TESTS)
#include <chrono>
#include "ggml-cpu.h"
#endif
#if defined(GGML_VULKAN_COPY_TESTS) && !defined(_WIN32)
#include <unistd.h>
#endif
// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
#define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1
@@ -591,6 +594,7 @@ struct vk_device_struct {
uint64_t suballocation_block_size;
uint64_t min_imported_host_pointer_alignment;
bool external_memory_host {};
bool external_semaphore_fd {};
bool fp16;
bool bf16;
bool pipeline_robustness;
@@ -1659,6 +1663,7 @@ struct ggml_vk_garbage_collector {
static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx);
static void ggml_vk_load_shaders(vk_device& device);
static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx);
static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size);
static bool vk_memory_logger_enabled = false;
@@ -3447,11 +3452,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
} else {
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
}
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat1_fa_support) {
@@ -3459,6 +3472,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT1, _cm1)
}
#endif
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@@ -4870,6 +4887,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->memory_priority = true;
} else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) {
device->external_memory_host = true;
} else if (strcmp("VK_KHR_external_semaphore_fd", properties.extensionName) == 0) {
device->external_semaphore_fd = true;
#if defined(VK_EXT_shader_64bit_indexing)
} else if (strcmp("VK_EXT_shader_64bit_indexing", properties.extensionName) == 0) {
device->shader_64b_indexing = true;
@@ -5169,6 +5188,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
device_extensions.push_back("VK_EXT_external_memory_host");
}
if (device->external_semaphore_fd) {
device_extensions.push_back("VK_KHR_external_semaphore_fd");
}
#if defined(VK_EXT_shader_64bit_indexing)
VkPhysicalDeviceShader64BitIndexingFeaturesEXT shader_64bit_indexing_features {};
shader_64bit_indexing_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_64_BIT_INDEXING_FEATURES_EXT;
@@ -12618,7 +12641,654 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
}
#endif
#ifdef GGML_VULKAN_COPY_TESTS
// Cross-device copy benchmark
// Tests different approaches to copying data between two Vulkan devices.
// Build with -DGGML_VULKAN_COPY_TESTS and run any llama.cpp command with >= 2 Vulkan devices.
// Helper: allocate shared staging buffer importable by both devices
struct vk_shared_staging {
void * host_ptr = nullptr;
vk_buffer buf_dev0;
vk_buffer buf_dev1;
size_t size = 0;
bool alloc(vk_device & dev0, vk_device & dev1, size_t sz) {
size_t align = std::max(dev0->min_imported_host_pointer_alignment,
dev1->min_imported_host_pointer_alignment);
size = (sz + align - 1) & ~(align - 1);
#ifdef _WIN32
host_ptr = _aligned_malloc(size, align);
#else
if (posix_memalign(&host_ptr, align, size) != 0) { host_ptr = nullptr; }
#endif
if (!host_ptr) return false;
buf_dev0 = ggml_vk_buffer_from_host_ptr(dev0, host_ptr, size);
buf_dev1 = ggml_vk_buffer_from_host_ptr(dev1, host_ptr, size);
return buf_dev0 && buf_dev1;
}
void free_resources() {
ggml_vk_destroy_buffer(buf_dev0);
ggml_vk_destroy_buffer(buf_dev1);
#ifdef _WIN32
_aligned_free(host_ptr);
#else
free(host_ptr);
#endif
host_ptr = nullptr;
}
};
// Helper: run a benchmark and print results
static void vk_bench_print(const char * name, std::vector<double> & times, size_t size) {
std::sort(times.begin(), times.end());
double median = times[times.size() / 2];
double bw = (size / (1024.0 * 1024.0 * 1024.0)) / (median / 1000.0);
std::cerr << " " << std::left << std::setw(22) << name << " : "
<< std::fixed << std::setprecision(3) << median << " ms "
<< std::setprecision(2) << bw << " GB/s" << std::endl;
}
// Results stored per (method, size) for table output
struct vk_copy_result {
std::string method;
double ms;
double gbps;
};
static void ggml_vk_bench_pair(
vk_device & dev0, vk_device & dev1,
const std::vector<size_t> & test_sizes,
std::map<std::string, std::vector<vk_copy_result>> & results) {
const size_t num_it = 20;
const size_t warmup = 3;
const size_t max_size = test_sizes.back();
// Allocate buffers
vk_buffer buf_src = ggml_vk_create_buffer_check(dev0, max_size, {vk::MemoryPropertyFlagBits::eDeviceLocal});
vk_buffer buf_dst = ggml_vk_create_buffer_check(dev1, max_size, {vk::MemoryPropertyFlagBits::eDeviceLocal});
vk_buffer staging_src = ggml_vk_create_buffer_check(dev0, max_size,
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
vk_buffer staging_dst = ggml_vk_create_buffer_check(dev1, max_size,
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
// Fill source
{
vk_context subctx = ggml_vk_create_temporary_context(dev0->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dev0, subctx);
subctx->s->buffer->buf.fillBuffer(buf_src->buffer, 0, max_size, 0xDEADBEEF);
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, dev0->fence);
VK_CHECK(dev0->device.waitForFences({ dev0->fence }, true, UINT64_MAX), "fill");
dev0->device.resetFences({ dev0->fence });
}
bool has_shared_staging = dev0->external_memory_host && dev1->external_memory_host;
bool has_syncfd = false;
#ifndef _WIN32
if (dev0->external_semaphore_fd && dev1->external_semaphore_fd) {
vk::PhysicalDeviceExternalSemaphoreInfo query{};
query.handleType = vk::ExternalSemaphoreHandleTypeFlagBits::eSyncFd;
auto p0 = dev0->physical_device.getExternalSemaphoreProperties(query);
auto p1 = dev1->physical_device.getExternalSemaphoreProperties(query);
has_syncfd =
(p0.externalSemaphoreFeatures & vk::ExternalSemaphoreFeatureFlagBits::eExportable) &&
(p0.compatibleHandleTypes & vk::ExternalSemaphoreHandleTypeFlagBits::eSyncFd) &&
(p1.externalSemaphoreFeatures & vk::ExternalSemaphoreFeatureFlagBits::eImportable) &&
(p1.compatibleHandleTypes & vk::ExternalSemaphoreHandleTypeFlagBits::eSyncFd);
}
#endif
// Helper to record a result
auto record = [&](const std::string & method, size_t size, std::vector<double> & times) {
std::sort(times.begin(), times.end());
double median = times[times.size() / 2];
double bw = (size / (1024.0 * 1024.0 * 1024.0)) / (median / 1000.0);
results[method].push_back({ method, median, bw });
};
// Helper to record a skipped size (sentinel: negative ms)
auto skip = [&](const std::string & method) {
results[method].push_back({ method, -1.0, -1.0 });
};
for (size_t size : test_sizes) {
// =================================================================
// 1. Baseline: current sync double-hop (separate staging buffers + memcpy)
// =================================================================
{
std::vector<double> times;
for (size_t i = 0; i < num_it + warmup; i++) {
auto begin = std::chrono::high_resolution_clock::now();
{
std::lock_guard<std::recursive_mutex> guard(dev0->mutex);
vk_context subctx = ggml_vk_create_temporary_context(dev0->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dev0, subctx);
ggml_vk_buffer_copy_async(subctx, staging_src, 0, buf_src, 0, size);
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, dev0->fence);
VK_CHECK(dev0->device.waitForFences({ dev0->fence }, true, UINT64_MAX), "baseline hop1");
dev0->device.resetFences({ dev0->fence });
}
memcpy(staging_dst->ptr, staging_src->ptr, size);
{
std::lock_guard<std::recursive_mutex> guard(dev1->mutex);
vk_context subctx = ggml_vk_create_temporary_context(dev1->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dev1, subctx);
ggml_vk_buffer_copy_async(subctx, buf_dst, 0, staging_dst, 0, size);
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, dev1->fence);
VK_CHECK(dev1->device.waitForFences({ dev1->fence }, true, UINT64_MAX), "baseline hop2");
dev1->device.resetFences({ dev1->fence });
}
auto end = std::chrono::high_resolution_clock::now();
if (i >= warmup) times.push_back(std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count() / 1000.0);
}
record("baseline", size, times);
}
// =================================================================
// 2. Diagnostics: individual hop timings
// =================================================================
{
std::vector<double> times;
for (size_t i = 0; i < num_it + warmup; i++) {
auto begin = std::chrono::high_resolution_clock::now();
{
std::lock_guard<std::recursive_mutex> guard(dev0->mutex);
vk_context subctx = ggml_vk_create_temporary_context(dev0->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dev0, subctx);
ggml_vk_buffer_copy_async(subctx, staging_src, 0, buf_src, 0, size);
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, dev0->fence);
VK_CHECK(dev0->device.waitForFences({ dev0->fence }, true, UINT64_MAX), "diag hop1");
dev0->device.resetFences({ dev0->fence });
}
auto end = std::chrono::high_resolution_clock::now();
if (i >= warmup) times.push_back(std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count() / 1000.0);
}
record("hop1_only", size, times);
}
{
std::vector<double> times;
for (size_t i = 0; i < num_it + warmup; i++) {
auto begin = std::chrono::high_resolution_clock::now();
{
std::lock_guard<std::recursive_mutex> guard(dev1->mutex);
vk_context subctx = ggml_vk_create_temporary_context(dev1->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dev1, subctx);
ggml_vk_buffer_copy_async(subctx, buf_dst, 0, staging_dst, 0, size);
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, dev1->fence);
VK_CHECK(dev1->device.waitForFences({ dev1->fence }, true, UINT64_MAX), "diag hop2");
dev1->device.resetFences({ dev1->fence });
}
auto end = std::chrono::high_resolution_clock::now();
if (i >= warmup) times.push_back(std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count() / 1000.0);
}
record("hop2_only", size, times);
}
// =================================================================
// 3. Shared staging: single host buffer imported into both devices
// =================================================================
if (has_shared_staging) {
vk_shared_staging stg;
if (stg.alloc(dev0, dev1, size)) {
std::vector<double> times;
for (size_t i = 0; i < num_it + warmup; i++) {
auto begin = std::chrono::high_resolution_clock::now();
{
std::lock_guard<std::recursive_mutex> guard(dev0->mutex);
vk_context subctx = ggml_vk_create_temporary_context(dev0->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dev0, subctx);
ggml_vk_buffer_copy_async(subctx, stg.buf_dev0, 0, buf_src, 0, size);
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, dev0->fence);
VK_CHECK(dev0->device.waitForFences({ dev0->fence }, true, UINT64_MAX), "shared hop1");
dev0->device.resetFences({ dev0->fence });
}
{
std::lock_guard<std::recursive_mutex> guard(dev1->mutex);
vk_context subctx = ggml_vk_create_temporary_context(dev1->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dev1, subctx);
ggml_vk_buffer_copy_async(subctx, buf_dst, 0, stg.buf_dev1, 0, size);
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, dev1->fence);
VK_CHECK(dev1->device.waitForFences({ dev1->fence }, true, UINT64_MAX), "shared hop2");
dev1->device.resetFences({ dev1->fence });
}
auto end = std::chrono::high_resolution_clock::now();
if (i >= warmup) times.push_back(std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count() / 1000.0);
}
record("shared_staging", size, times);
} else {
std::cerr << " shared_staging : SKIPPED (import failed)" << std::endl;
}
stg.free_resources();
}
// =================================================================
// 4. Chunked pipeline: split into N chunks, overlap hop1/hop2
// via full-duplex PCIe. Vary chunk count to find optimum.
// =================================================================
if (has_shared_staging) {
for (size_t n_chunks : { 2, 4, 8 }) {
char cname[32];
snprintf(cname, sizeof(cname), "chunked_%zu", n_chunks);
if (size < n_chunks * 4096) { skip(cname); continue; }
size_t align = std::max(dev0->min_imported_host_pointer_alignment,
dev1->min_imported_host_pointer_alignment);
size_t chunk_data = size / n_chunks;
size_t chunk_aligned = (chunk_data + align - 1) & ~(align - 1);
vk_shared_staging stg;
if (!stg.alloc(dev0, dev1, chunk_aligned * n_chunks)) {
std::cerr << " chunked_" << n_chunks << " : SKIPPED (import failed)" << std::endl;
stg.free_resources();
continue;
}
// Per-chunk timeline semaphores
std::vector<vk::Semaphore> chunk_sems(n_chunks);
std::vector<uint64_t> sem_vals(n_chunks, 0);
for (size_t c = 0; c < n_chunks; c++) {
vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
vk::SemaphoreCreateInfo sci{};
sci.setPNext(&tci);
chunk_sems[c] = dev0->device.createSemaphore(sci);
}
std::vector<double> times;
for (size_t iter = 0; iter < num_it + warmup; iter++) {
auto begin = std::chrono::high_resolution_clock::now();
// Submit all hop1s upfront
for (size_t c = 0; c < n_chunks; c++) {
size_t off_src = c * chunk_data;
size_t off_stg = c * chunk_aligned;
size_t csz = (c == n_chunks - 1) ? (size - c * chunk_data) : chunk_data;
vk_context subctx = ggml_vk_create_temporary_context(dev0->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dev0, subctx);
ggml_vk_buffer_copy_async(subctx, stg.buf_dev0, off_stg, buf_src, off_src, csz);
sem_vals[c]++;
subctx->s->signal_semaphores.push_back({ chunk_sems[c], sem_vals[c] });
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, {});
}
// Per-chunk: CPU wait hop1, submit hop2
for (size_t c = 0; c < n_chunks; c++) {
size_t off_dst = c * chunk_data;
size_t off_stg = c * chunk_aligned;
size_t csz = (c == n_chunks - 1) ? (size - c * chunk_data) : chunk_data;
vk::SemaphoreWaitInfo swi{vk::SemaphoreWaitFlags{}, chunk_sems[c], sem_vals[c]};
VK_CHECK(dev0->device.waitSemaphores(swi, UINT64_MAX), "chunked sem wait");
vk_context subctx = ggml_vk_create_temporary_context(dev1->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dev1, subctx);
ggml_vk_buffer_copy_async(subctx, buf_dst, off_dst, stg.buf_dev1, off_stg, csz);
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, (c == n_chunks - 1) ? dev1->fence : vk::Fence{});
}
VK_CHECK(dev1->device.waitForFences({ dev1->fence }, true, UINT64_MAX), "chunked final");
dev1->device.resetFences({ dev1->fence });
auto end = std::chrono::high_resolution_clock::now();
if (iter >= warmup) times.push_back(std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count() / 1000.0);
}
char name[32];
snprintf(name, sizeof(name), "chunked_%zu", n_chunks);
record(name, size, times);
for (size_t c = 0; c < n_chunks; c++) dev0->device.destroySemaphore(chunk_sems[c]);
stg.free_resources();
}
}
// =================================================================
// 5. sync_fd async: fully GPU-synchronised via Linux sync_file
// =================================================================
#ifndef _WIN32
if (has_shared_staging && has_syncfd) {
vk_shared_staging stg;
if (stg.alloc(dev0, dev1, size)) {
std::vector<double> times;
bool run_ok = true;
for (size_t i = 0; i < num_it + warmup && run_ok; i++) {
auto begin = std::chrono::high_resolution_clock::now();
vk::ExportSemaphoreCreateInfo esci{};
esci.handleTypes = vk::ExternalSemaphoreHandleTypeFlagBits::eSyncFd;
vk::SemaphoreCreateInfo sci{};
sci.setPNext(&esci);
vk::Semaphore sem_dev0 = dev0->device.createSemaphore(sci);
// Hop 1 + signal
{
vk_context subctx = ggml_vk_create_temporary_context(dev0->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dev0, subctx);
ggml_vk_buffer_copy_async(subctx, stg.buf_dev0, 0, buf_src, 0, size);
subctx->s->signal_semaphores.push_back({ sem_dev0, 0 });
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, {});
}
// Export + import sync_fd
int sync_fd = -1;
try {
vk::SemaphoreGetFdInfoKHR gi{};
gi.semaphore = sem_dev0;
gi.handleType = vk::ExternalSemaphoreHandleTypeFlagBits::eSyncFd;
sync_fd = dev0->device.getSemaphoreFdKHR(gi);
} catch (vk::SystemError& e) {
std::cerr << " syncfd_async : SKIPPED (export: " << e.what() << ")" << std::endl;
dev0->device.destroySemaphore(sem_dev0);
run_ok = false; break;
}
vk::Semaphore sem_dev1 = dev1->device.createSemaphore({});
try {
vk::ImportSemaphoreFdInfoKHR ii{};
ii.semaphore = sem_dev1;
ii.handleType = vk::ExternalSemaphoreHandleTypeFlagBits::eSyncFd;
ii.flags = vk::SemaphoreImportFlagBits::eTemporary;
ii.fd = sync_fd;
dev1->device.importSemaphoreFdKHR(ii);
} catch (vk::SystemError& e) {
std::cerr << " syncfd_async : SKIPPED (import: " << e.what() << ")" << std::endl;
dev0->device.destroySemaphore(sem_dev0);
dev1->device.destroySemaphore(sem_dev1);
close(sync_fd);
run_ok = false; break;
}
// Hop 2 with GPU-side wait
{
vk_context subctx = ggml_vk_create_temporary_context(dev1->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dev1, subctx);
subctx->s->wait_semaphores.push_back({ sem_dev1, 0 });
ggml_vk_buffer_copy_async(subctx, buf_dst, 0, stg.buf_dev1, 0, size);
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, dev1->fence);
VK_CHECK(dev1->device.waitForFences({ dev1->fence }, true, UINT64_MAX), "syncfd final");
dev1->device.resetFences({ dev1->fence });
}
dev0->device.destroySemaphore(sem_dev0);
dev1->device.destroySemaphore(sem_dev1);
auto end = std::chrono::high_resolution_clock::now();
if (i >= warmup) times.push_back(std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count() / 1000.0);
}
if (run_ok) record("syncfd_async", size, times);
} else {
std::cerr << " syncfd_async : SKIPPED (import failed)" << std::endl;
}
stg.free_resources();
}
// =================================================================
// 6. sync_fd chunked: chunked pipeline with GPU-side sync_fd
// between hops (no CPU waits between chunks)
// =================================================================
if (has_shared_staging && has_syncfd) {
for (size_t n_chunks : { 2, 4, 8 }) {
char scname[48];
snprintf(scname, sizeof(scname), "syncfd_chunked_%zu", n_chunks);
if (size < n_chunks * 4096) { skip(scname); continue; }
size_t align = std::max(dev0->min_imported_host_pointer_alignment,
dev1->min_imported_host_pointer_alignment);
size_t chunk_data = size / n_chunks;
size_t chunk_aligned = (chunk_data + align - 1) & ~(align - 1);
vk_shared_staging stg;
if (!stg.alloc(dev0, dev1, chunk_aligned * n_chunks)) {
std::cerr << " syncfd_chunked_" << n_chunks << " : SKIPPED (import failed)" << std::endl;
stg.free_resources();
continue;
}
std::vector<double> times;
bool run_ok = true;
for (size_t iter = 0; iter < num_it + warmup && run_ok; iter++) {
auto begin = std::chrono::high_resolution_clock::now();
// Create per-chunk exportable semaphores
std::vector<vk::Semaphore> sems_dev0(n_chunks);
for (size_t c = 0; c < n_chunks; c++) {
vk::ExportSemaphoreCreateInfo esci{};
esci.handleTypes = vk::ExternalSemaphoreHandleTypeFlagBits::eSyncFd;
vk::SemaphoreCreateInfo sci{};
sci.setPNext(&esci);
sems_dev0[c] = dev0->device.createSemaphore(sci);
}
// Submit all hop1s with per-chunk signal
for (size_t c = 0; c < n_chunks; c++) {
size_t off_src = c * chunk_data;
size_t off_stg = c * chunk_aligned;
size_t csz = (c == n_chunks - 1) ? (size - c * chunk_data) : chunk_data;
vk_context subctx = ggml_vk_create_temporary_context(dev0->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dev0, subctx);
ggml_vk_buffer_copy_async(subctx, stg.buf_dev0, off_stg, buf_src, off_src, csz);
subctx->s->signal_semaphores.push_back({ sems_dev0[c], 0 });
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, {});
}
// Export all sync_fds and import on dev1, submit hop2s
for (size_t c = 0; c < n_chunks && run_ok; c++) {
size_t off_dst = c * chunk_data;
size_t off_stg = c * chunk_aligned;
size_t csz = (c == n_chunks - 1) ? (size - c * chunk_data) : chunk_data;
int sync_fd = -1;
try {
vk::SemaphoreGetFdInfoKHR gi{};
gi.semaphore = sems_dev0[c];
gi.handleType = vk::ExternalSemaphoreHandleTypeFlagBits::eSyncFd;
sync_fd = dev0->device.getSemaphoreFdKHR(gi);
} catch (vk::SystemError& e) {
char nm[48]; snprintf(nm, sizeof(nm), "syncfd_chunked_%zu", n_chunks);
std::cerr << " " << nm << " : SKIPPED (export: " << e.what() << ")" << std::endl;
run_ok = false; break;
}
vk::Semaphore sem_dev1 = dev1->device.createSemaphore({});
try {
vk::ImportSemaphoreFdInfoKHR ii{};
ii.semaphore = sem_dev1;
ii.handleType = vk::ExternalSemaphoreHandleTypeFlagBits::eSyncFd;
ii.flags = vk::SemaphoreImportFlagBits::eTemporary;
ii.fd = sync_fd;
dev1->device.importSemaphoreFdKHR(ii);
} catch (vk::SystemError& e) {
char nm[48]; snprintf(nm, sizeof(nm), "syncfd_chunked_%zu", n_chunks);
std::cerr << " " << nm << " : SKIPPED (import: " << e.what() << ")" << std::endl;
dev1->device.destroySemaphore(sem_dev1);
close(sync_fd);
run_ok = false; break;
}
vk_context subctx = ggml_vk_create_temporary_context(dev1->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dev1, subctx);
subctx->s->wait_semaphores.push_back({ sem_dev1, 0 });
ggml_vk_buffer_copy_async(subctx, buf_dst, off_dst, stg.buf_dev1, off_stg, csz);
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, (c == n_chunks - 1) ? dev1->fence : vk::Fence{});
dev1->device.destroySemaphore(sem_dev1);
}
if (run_ok) {
VK_CHECK(dev1->device.waitForFences({ dev1->fence }, true, UINT64_MAX), "syncfd_chunked final");
dev1->device.resetFences({ dev1->fence });
}
for (size_t c = 0; c < n_chunks; c++) dev0->device.destroySemaphore(sems_dev0[c]);
auto end = std::chrono::high_resolution_clock::now();
if (run_ok && iter >= warmup) times.push_back(std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count() / 1000.0);
}
if (run_ok) {
char name[48];
snprintf(name, sizeof(name), "syncfd_chunked_%zu", n_chunks);
record(name, size, times);
}
stg.free_resources();
}
}
#endif
}
ggml_vk_destroy_buffer(buf_src);
ggml_vk_destroy_buffer(buf_dst);
ggml_vk_destroy_buffer(staging_src);
ggml_vk_destroy_buffer(staging_dst);
}
static void ggml_vk_test_cross_device_copy(ggml_backend_vk_context * ctx) {
ggml_vk_instance_init();
const size_t n_devices = vk_instance.device_indices.size();
if (n_devices < 2) {
std::cerr << "COPY TEST: Need at least 2 Vulkan devices, found " << n_devices << std::endl;
return;
}
// List devices
std::cerr << "\n=== Vulkan Devices ===" << std::endl;
std::vector<vk_device> devices(n_devices);
for (size_t i = 0; i < n_devices; i++) {
devices[i] = ggml_vk_get_device(i);
std::cerr << " [" << i << "] " << devices[i]->name << std::endl;
}
const std::vector<size_t> test_sizes = {
4096, // 4 KB
256 * 1024, // 256 KB
1 * 1024 * 1024, // 1 MB
16 * 1024 * 1024, // 16 MB
64 * 1024 * 1024, // 64 MB
256 * 1024 * 1024, // 256 MB
};
// Collect results: results[pair_label][method_name] = vector of vk_copy_result (one per size)
struct pair_results {
std::string label;
std::map<std::string, std::vector<vk_copy_result>> methods;
};
std::vector<pair_results> all_results;
// Run benchmarks for all ordered pairs
for (size_t i = 0; i < n_devices; i++) {
for (size_t j = 0; j < n_devices; j++) {
if (i == j) continue;
std::string label = devices[i]->name + " -> " + devices[j]->name;
std::cerr << "\n\n=== " << label << " ===" << std::endl;
pair_results pr;
pr.label = label;
ggml_vk_bench_pair(devices[i], devices[j], test_sizes, pr.methods);
all_results.push_back(std::move(pr));
}
}
// Output markdown tables: one table per method
// Collect all method names
std::vector<std::string> method_order;
if (!all_results.empty()) {
// Use first pair's method order as canonical
for (auto & [method, _] : all_results[0].methods) {
method_order.push_back(method);
}
// Add any methods from other pairs not in the first
for (auto & pr : all_results) {
for (auto & [method, _] : pr.methods) {
if (std::find(method_order.begin(), method_order.end(), method) == method_order.end()) {
method_order.push_back(method);
}
}
}
}
std::cerr << "\n\n# Cross-Device Copy Benchmark Results\n" << std::endl;
for (auto & method : method_order) {
std::cerr << "## " << method << "\n" << std::endl;
// Header: | Direction | 4KB | 256KB | ... |
std::cerr << "| Direction |";
for (size_t s : test_sizes) {
if (s < 1024 * 1024) {
std::cerr << " " << s / 1024 << " KB |";
} else {
std::cerr << " " << s / (1024 * 1024) << " MB |";
}
}
std::cerr << std::endl;
// Separator
std::cerr << "|---|";
for (size_t s = 0; s < test_sizes.size(); s++) {
std::cerr << "---|";
GGML_UNUSED(s);
}
std::cerr << std::endl;
// Data rows
for (auto & pr : all_results) {
std::cerr << "| " << pr.label << " |";
auto it = pr.methods.find(method);
if (it != pr.methods.end() && it->second.size() == test_sizes.size()) {
for (auto & r : it->second) {
if (r.ms < 0) {
std::cerr << " - |";
} else {
std::cerr << " " << std::fixed << std::setprecision(1) << r.ms << " ms (" << std::setprecision(1) << r.gbps << " GB/s) |";
}
}
} else {
for (size_t s = 0; s < test_sizes.size(); s++) {
std::cerr << " - |";
GGML_UNUSED(s);
}
}
std::cerr << std::endl;
}
std::cerr << std::endl;
}
GGML_ABORT("GGML_VULKAN_COPY_TESTS completed");
GGML_UNUSED(ctx);
}
#endif
static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx) {
#if defined(GGML_VULKAN_COPY_TESTS)
ggml_vk_test_cross_device_copy(ctx);
#endif
#if defined(GGML_VULKAN_RUN_TESTS)
const std::vector<size_t> vals {
512, 512, 128,
@@ -15331,11 +16001,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_F32:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
// supported in scalar and coopmat2 paths
break;
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_IQ4_NL:
// supported in scalar and coopmat2 paths
break;
// K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
//case GGML_TYPE_Q2_K:
//case GGML_TYPE_Q3_K:
@@ -15350,12 +16021,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
//case GGML_TYPE_IQ3_XXS:
//case GGML_TYPE_IQ3_S:
//case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
// currently supported only in coopmat2 path
if (!coopmat2) {
return false;
}
break;
default:
return false;
}

View File

@@ -110,6 +110,97 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18
#elif defined(DATA_A_Q4_1)
#define BLOCK_BYTE_SIZE 20
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
#ifdef DATA_A_Q4_1
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m);
#else
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f));
#endif
} else {
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
#ifdef DATA_A_Q4_1
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m);
#else
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f));
#endif
}
}
#endif
#if defined(DATA_A_Q5_0)
#define BLOCK_BYTE_SIZE 22
#elif defined(DATA_A_Q5_1)
#define BLOCK_BYTE_SIZE 24
#endif
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
#ifdef DATA_A_Q5_1
uint qh = k_packed.k_data_packed16[a_offset + ib].qh;
#else
uint qh = uint(k_packed.k_data_packed16[a_offset + ib].qh[0]) | (uint(k_packed.k_data_packed16[a_offset + ib].qh[1]) << 16);
#endif
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f);
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
#ifdef DATA_A_Q5_1
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m);
#else
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f));
#endif
} else {
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
#ifdef DATA_A_Q5_1
uint qh = v_packed.v_data_packed16[a_offset + ib].qh;
#else
uint qh = uint(v_packed.v_data_packed16[a_offset + ib].qh[0]) | (uint(v_packed.v_data_packed16[a_offset + ib].qh[1]) << 16);
#endif
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f);
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
#ifdef DATA_A_Q5_1
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m);
#else
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f));
#endif
}
}
#endif
#if defined(DATA_A_IQ4_NL)
#define BLOCK_BYTE_SIZE 18
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
@@ -119,7 +210,11 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
vui_lo >>= shift;
vui_hi >>= shift;
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(
kvalues_iq4nl[vui_lo & 0xF],
kvalues_iq4nl[(vui_lo >> 8) & 0xF],
kvalues_iq4nl[vui_hi & 0xF],
kvalues_iq4nl[(vui_hi >> 8) & 0xF]);
} else {
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
@@ -127,11 +222,14 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
vui_lo >>= shift;
vui_hi >>= shift;
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(
kvalues_iq4nl[vui_lo & 0xF],
kvalues_iq4nl[(vui_lo >> 8) & 0xF],
kvalues_iq4nl[vui_hi & 0xF],
kvalues_iq4nl[(vui_hi >> 8) & 0xF]);
}
}
#endif
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {

View File

@@ -137,6 +137,7 @@ void execute_command(std::vector<std::string>& command, std::string& stdout_str,
pid_t pid = fork();
if (pid < 0) {
std::cerr << strerror(errno) << "\n";
throw std::runtime_error("Failed to fork process");
}
@@ -655,7 +656,7 @@ void process_shaders() {
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
} else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
@@ -666,7 +667,7 @@ void process_shaders() {
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
} else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);

View File

@@ -16,7 +16,6 @@
#include <webgpu/webgpu_cpp.h>
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <cstring>
#ifdef GGML_WEBGPU_GPU_PROFILE
@@ -25,7 +24,6 @@
#if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE)
# include <iostream>
#endif
#include <map>
#include <memory>
#include <mutex>
#include <optional>
@@ -81,13 +79,13 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
/* Constants */
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u
#define WEBGPU_NUM_PARAM_SLOTS \
(WEBGPU_COMMAND_SUBMIT_BATCH_SIZE + 10) // a few extra for safety, since some operations may need multiple slots
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 100
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
#define WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE 32u
#define WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN 10u
#define WEBGPU_RUNTIME_WAIT_TIMEOUT_MS 30000u
#define WEBGPU_RUNTIME_WAIT_TIMEOUT_NS (WEBGPU_RUNTIME_WAIT_TIMEOUT_MS * 1e6)
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
// For operations which process a row in parallel, this seems like a reasonable
// default
@@ -252,6 +250,8 @@ struct webgpu_global_context_struct {
wgpu::Adapter adapter;
wgpu::Device device;
wgpu::Queue queue;
uint32_t command_submit_batch_size = WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE;
uint32_t max_inflight_batches = UINT32_MAX;
webgpu_capabilities capabilities;
// Shared buffer to move data from device to host
@@ -417,16 +417,72 @@ static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context &
}
#endif
template <typename T>
static void ggml_backend_webgpu_check_wait_status(wgpu::WaitStatus wait_status,
T callback_status,
T success_status,
const char * wait_name,
const char * failure_name,
const char * callback_message) {
if (wait_status == wgpu::WaitStatus::TimedOut) {
GGML_ABORT("ggml_webgpu: %s timed out after %u ms\n", wait_name, WEBGPU_RUNTIME_WAIT_TIMEOUT_MS);
}
if (wait_status == wgpu::WaitStatus::Error) {
GGML_ABORT("ggml_webgpu: %s failed\n", wait_name);
}
if (callback_status != success_status) {
GGML_ABORT("ggml_webgpu: %s failed with status %d: %s\n", failure_name, static_cast<int>(callback_status),
callback_message);
}
}
#ifdef __EMSCRIPTEN__
// iOS browsers seem to have very strict limits on the number of in-flight GPU commands, so we need to throttle to avoid failures.
EM_JS(int, ggml_webgpu_is_ios_browser, (), {
const ua = navigator.userAgent;
return (ua.includes('iPhone') || ua.includes('iPad')) ? 1 : 0;
});
#endif
static uint32_t ggml_backend_webgpu_get_max_inflight_batches(const wgpu::AdapterInfo & info) {
#ifdef __EMSCRIPTEN__
if (ggml_webgpu_is_ios_browser()) {
return 1;
}
#else
GGML_UNUSED(info);
#endif
return UINT32_MAX;
}
static uint32_t ggml_backend_webgpu_get_command_submit_batch_size(const wgpu::AdapterInfo & info) {
#ifdef __EMSCRIPTEN__
if (ggml_webgpu_is_ios_browser()) {
return 16;
}
#else
GGML_UNUSED(info);
#endif
return WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE;
}
static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) {
ctx->instance.WaitAny(
ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
std::string(message).c_str());
}
}),
UINT64_MAX);
wgpu::QueueWorkDoneStatus callback_status = wgpu::QueueWorkDoneStatus::Error;
std::string callback_message;
const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
[&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
callback_status = status;
callback_message = std::string(message);
}),
WEBGPU_RUNTIME_WAIT_TIMEOUT_NS);
ggml_backend_webgpu_check_wait_status(wait_status, callback_status, wgpu::QueueWorkDoneStatus::Success,
"Queue wait", "Queue work", callback_message.c_str());
}
static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
@@ -434,14 +490,31 @@ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
wgpu::MapMode mode,
size_t offset,
size_t size) {
ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::MapAsyncStatus status, wgpu::StringView message) {
if (status != wgpu::MapAsyncStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n",
message.data);
}
}),
UINT64_MAX);
wgpu::MapAsyncStatus callback_status = wgpu::MapAsyncStatus::Error;
std::string callback_message;
const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
[&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) {
callback_status = status;
callback_message = std::string(message);
}),
WEBGPU_RUNTIME_WAIT_TIMEOUT_NS);
ggml_backend_webgpu_check_wait_status(wait_status, callback_status, wgpu::MapAsyncStatus::Success,
"Buffer map wait", "Buffer map", callback_message.c_str());
}
static void ggml_backend_webgpu_submit_commands(webgpu_context & ctx,
const wgpu::CommandBuffer commands,
uint32_t & num_inflight_batches) {
if (num_inflight_batches >= ctx->global_ctx->max_inflight_batches) {
ggml_backend_webgpu_wait_queue(ctx->global_ctx);
num_inflight_batches = 0;
}
ctx->global_ctx->queue.Submit(1, &commands);
num_inflight_batches++;
}
#ifdef GGML_WEBGPU_DEBUG
@@ -2871,9 +2944,10 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
#ifdef GGML_WEBGPU_GPU_PROFILE
std::vector<wgpu::FutureWaitInfo> profile_futures;
#endif
uint32_t num_batched_kernels = 0;
bool contains_set_rows = false;
wgpu::CommandEncoder batch_encoder = ctx->global_ctx->device.CreateCommandEncoder();
uint32_t num_batched_kernels = 0;
uint32_t num_inflight_batches = 0;
bool contains_set_rows = false;
wgpu::CommandEncoder batch_encoder = ctx->global_ctx->device.CreateCommandEncoder();
for (int i = 0; i < cgraph->n_nodes; i++) {
if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
@@ -2884,10 +2958,10 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
num_batched_kernels += cmd.value().num_kernels;
}
if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
if (num_batched_kernels >= ctx->global_ctx->command_submit_batch_size) {
num_batched_kernels = 0;
wgpu::CommandBuffer batch_commands = batch_encoder.Finish();
ctx->global_ctx->queue.Submit(1, &batch_commands);
ggml_backend_webgpu_submit_commands(ctx, batch_commands, num_inflight_batches);
#ifdef GGML_WEBGPU_GPU_PROFILE
ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures);
#endif
@@ -2898,7 +2972,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
}
if (!commands.empty()) {
wgpu::CommandBuffer batch_commands = batch_encoder.Finish();
ctx->global_ctx->queue.Submit(1, &batch_commands);
ggml_backend_webgpu_submit_commands(ctx, batch_commands, num_inflight_batches);
#ifdef GGML_WEBGPU_GPU_PROFILE
ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures);
#endif
@@ -2912,7 +2986,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0,
ctx->set_rows_host_error_buf.GetSize());
wgpu::CommandBuffer set_rows_commands = encoder.Finish();
ctx->global_ctx->queue.Submit(1, &set_rows_commands);
ggml_backend_webgpu_submit_commands(ctx, set_rows_commands, num_inflight_batches);
}
ggml_backend_webgpu_wait_queue(ctx->global_ctx);
@@ -3363,6 +3437,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
}
#endif
ctx->webgpu_global_ctx->adapter.GetInfo(&info);
ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(info);
ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(info);
wgpu::SupportedFeatures features;
ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
// we require f16 support
@@ -3483,8 +3559,10 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
webgpu_ctx->param_arena.init(webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES, WEBGPU_NUM_PARAM_SLOTS,
webgpu_ctx->global_ctx->capabilities.limits.minUniformBufferOffsetAlignment);
webgpu_ctx->param_arena.init(
webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES,
webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN,
webgpu_ctx->global_ctx->capabilities.limits.minUniformBufferOffsetAlignment);
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf,
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf");

View File

@@ -1441,7 +1441,7 @@ class TensorNameMap:
"visual.blocks.{bid}.attn.qkv", # qwen3vl
"model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm
"model.vision_model.transformer.layers.{bid}.self_attn.qkv_proj", # Deepseek-OCR CLIP
"vision_tower.encoder.blocks.{bid}.wqkv" # Kimi-K2.5
"vision_tower.encoder.blocks.{bid}.wqkv", # Kimi-K2.5
"vision_model.radio_model.model.blocks.{bid}.attn.qkv", # Nemotron Nano v2 VL
),

File diff suppressed because it is too large Load Diff

View File

@@ -585,8 +585,6 @@ struct LLM_TN_IMPL {
const int bid;
const int xid;
const std::set<llm_tensor> model_tensors;
LLM_TN_IMPL(llm_arch arch, llm_tensor tensor, const char * suffix, int bid, int xid);
std::string str() const;

View File

@@ -511,6 +511,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
if (self_v_rot) {
mctx->get_base()->set_input_v_rot(self_v_rot);
}
if (self_k_rot_swa) {
mctx->get_swa()->set_input_k_rot(self_k_rot_swa);
}
if (self_v_rot_swa) {
mctx->get_swa()->set_input_v_rot(self_v_rot_swa);
}
}
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
@@ -681,6 +689,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
}
if (inp_attn->self_k_rot_swa) {
attn_ctx->get_swa()->set_input_k_rot(inp_attn->self_k_rot_swa);
}
if (inp_attn->self_v_rot_swa) {
attn_ctx->get_swa()->set_input_v_rot(inp_attn->self_v_rot_swa);
}
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (inp_rs->s_copy) {
@@ -2233,15 +2249,20 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * v_mla,
float kq_scale,
int il) const {
if (inp->self_k_rot) {
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
const bool is_swa = hparams.is_swa(il);
auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot;
auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot;
if (k_rot) {
q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot);
if (k_cur) {
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
k_cur = ggml_mul_mat_aux(ctx0, k_cur, k_rot);
}
}
if (inp->self_v_rot) {
if (v_rot) {
if (v_cur) {
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
v_cur = ggml_mul_mat_aux(ctx0, v_cur, v_rot);
}
}
@@ -2259,8 +2280,6 @@ ggml_tensor * llm_graph_context::build_attn(
const auto * mctx_iswa = inp->mctx;
const bool is_swa = hparams.is_swa(il);
const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
// optionally store to KV cache
@@ -2285,8 +2304,8 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
if (inp->self_v_rot) {
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
if (v_rot) {
cur = ggml_mul_mat_aux(ctx0, cur, v_rot);
}
if (wo) {
@@ -2388,6 +2407,9 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);
inp->self_k_rot_swa = mctx_cur->get_swa()->build_input_k_rot(ctx0);
inp->self_v_rot_swa = mctx_cur->get_swa()->build_input_v_rot(ctx0);
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
}

View File

@@ -308,7 +308,7 @@ public:
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
// note: assumes v_rot^ == I
// note: assumes v_rot^2 == I
ggml_tensor * self_k_rot = nullptr;
ggml_tensor * self_v_rot = nullptr;
@@ -388,10 +388,12 @@ public:
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
// note: using same rotation matrices for both base and swa cache
ggml_tensor * self_k_rot = nullptr;
ggml_tensor * self_v_rot = nullptr;
ggml_tensor * self_k_rot_swa = nullptr;
ggml_tensor * self_v_rot_swa = nullptr;
const llama_hparams hparams;
const llama_cparams cparams;

View File

@@ -169,6 +169,18 @@ llama_kv_cache::llama_kv_cache(
continue;
}
if (n_embd_head_k_all == 0) {
n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il);
} else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) {
n_embd_head_k_all = -1;
}
if (n_embd_head_v_all == 0) {
n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il);
} else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) {
n_embd_head_v_all = -1;
}
// [TAG_V_CACHE_VARIABLE]
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
@@ -276,23 +288,23 @@ llama_kv_cache::llama_kv_cache(
attn_rot_k =
!attn_rot_disable &&
n_embd_head_k_all > 0 &&
ggml_is_quantized(type_k) &&
!hparams.is_n_embd_k_gqa_variable() &&
hparams.n_embd_head_k() % 64 == 0;
attn_rot_v =
!attn_rot_disable &&
n_embd_head_v_all > 0 &&
ggml_is_quantized(type_v) &&
!hparams.is_n_embd_v_gqa_variable() &&
hparams.n_embd_head_v() % 64 == 0;
LLAMA_LOG_INFO("%s: attn_rot_k = %d\n", __func__, attn_rot_k);
LLAMA_LOG_INFO("%s: attn_rot_v = %d\n", __func__, attn_rot_v);
LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all);
LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all);
// pre-compute the haramard matrices and keep them in host memory
// TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers
if (attn_rot_k || attn_rot_v) {
for (int64_t n = 64; n <= std::max(hparams.n_embd_head_k(), hparams.n_embd_head_v()); n *= 2) {
for (int64_t n = 64; n <= std::max(n_embd_head_k_all, n_embd_head_v_all); n *= 2) {
attn_rot_hadamard[n] = std::vector<float>(n*n);
ggml_init_params params = {
@@ -1308,7 +1320,7 @@ ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const {
// ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088
do {
nrot *= 2;
} while (hparams.n_embd_head_k() % nrot == 0);
} while (n_embd_head_k_all % nrot == 0);
nrot /= 2;
res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);

View File

@@ -239,6 +239,11 @@ private:
bool attn_rot_k = false;
bool attn_rot_v = false;
// if all layers participating in the cache have constant head size, the value is stored here
// otherwise the value is -1
int32_t n_embd_head_k_all = 0;
int32_t n_embd_head_v_all = 0;
// pre-computed hadamard martrices
std::unordered_map<int64_t, std::vector<float>> attn_rot_hadamard;

View File

@@ -470,6 +470,141 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
return bpe_offsets;
}
// Qwen2 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
static std::vector<size_t> unicode_regex_split_custom_qwen2(const std::string & text, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
const auto cpts = unicode_cpts_from_utf8(text);
size_t start = 0;
for (auto offset : offsets) {
const size_t offset_ini = start;
const size_t offset_end = start + offset;
assert(offset_end <= cpts.size());
start = offset_end;
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
auto _get_cpt = [&] (const size_t pos) -> uint32_t {
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
};
size_t _prev_end = offset_ini;
auto _add_token = [&] (const size_t end) -> size_t {
assert(_prev_end <= end && end <= offset_end);
size_t len = end - _prev_end;
if (len > 0) {
bpe_offsets.push_back(len);
}
_prev_end = end;
//if (len > 0) {
// std::string s = "";
// for(size_t p = end-len; p < end; p++)
// s += unicode_cpt_to_utf8(cpts[p]);
// printf(">>> '%s'\n", s.c_str());
//}
return len;
};
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
const uint32_t cpt = _get_cpt(pos);
const auto flags = _get_flags(pos);
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
if (cpt == '\'' && pos+1 < offset_end) {
uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
pos += _add_token(pos+2);
continue;
}
if (pos+2 < offset_end) {
uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
(cpt_next == 'v' && cpt_next_next == 'e') ||
(cpt_next == 'l' && cpt_next_next == 'l')) {
pos += _add_token(pos+3);
continue;
}
}
}
// regex: [^\r\n\p{L}\p{N}]?\p{L}+
if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
pos++;
while (_get_flags(pos).is_letter) {
pos++;
}
_add_token(pos);
continue;
}
}
// regex: \p{N}
if (flags.is_number) {
pos++;
_add_token(pos);
continue;
}
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
pos += (cpt == ' ');
while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
flags2 = _get_flags(++pos);
}
uint32_t cpt2 = _get_cpt(pos);
while (cpt2 == '\r' || cpt2 == '\n') {
cpt2 = _get_cpt(++pos);
}
_add_token(pos);
continue;
}
size_t num_whitespaces = 0;
size_t last_end_r_or_n = 0;
while (_get_flags(pos+num_whitespaces).is_whitespace) {
uint32_t cpt2 = _get_cpt(pos+num_whitespaces);
if (cpt2 == '\r' || cpt2 == '\n') {
last_end_r_or_n = pos + num_whitespaces + 1;
}
num_whitespaces++;
}
// regex: \s*[\r\n]+
if (last_end_r_or_n > 0) {
pos = last_end_r_or_n;
_add_token(pos);
continue;
}
// regex: \s+(?!\S)
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
pos += num_whitespaces - 1;
_add_token(pos);
continue;
}
// regex: \s+
if (num_whitespaces > 0) {
pos += num_whitespaces;
_add_token(pos);
continue;
}
// no matches
_add_token(++pos);
}
}
return bpe_offsets;
}
template <typename CharT>
static std::vector<size_t> unicode_regex_split_stl(const std::basic_string<CharT> & text, const std::basic_string<CharT> & regex, const std::vector<size_t> & offsets) {
using BidirIt = typename std::basic_string<CharT>::const_iterator;
@@ -790,8 +925,10 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
} else if (
regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
} else if (
regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
bpe_offsets = unicode_regex_split_custom_qwen2(text, offsets);
} else if (regex_expr == "\\p{Han}+") {
// K2's first pattern - handle all K2 patterns together
bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);

View File

@@ -3129,39 +3129,6 @@ struct test_add_id : public test_case {
}
};
// GGML_OP_ADD1
struct test_add1 : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
std::string vars() override {
return VARS_TO_STR2(type, ne);
}
test_add1(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 5, 4, 3})
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * b = ggml_new_tensor_1d(ctx, type, 1);
// ggml_set_param(b); // TODO: implement
ggml_set_name(b, "b");
ggml_tensor * out = ggml_add1(ctx, a, b);
ggml_set_name(out, "out");
return out;
}
float grad_eps() override {
return 0.1f * ne[0]*ne[1]*ne[2]*ne[3];
}
};
// GGML_OP_SCALE
struct test_scale : public test_case {
const ggml_type type;
@@ -7886,8 +7853,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {2, 2, 2, 2}, 8));
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
test_cases.emplace_back(new test_add1());
test_cases.emplace_back(new test_add1(GGML_TYPE_F32, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_scale());
test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -18,7 +18,7 @@
<div style="display: contents">
<script>
{
__sveltekit_1trm5n9 = {
__sveltekit_10avopp = {
base: new URL('.', location).pathname.slice(0, -1)
};

View File

@@ -632,7 +632,7 @@ private:
// load the model and initialize llama_context
// this may also be called to resume from sleeping state
bool load_model(const common_params & params) {
bool load_model(common_params & params) {
bool is_resume = sleeping;
SRV_INF("loading model '%s'\n", params.model.path.c_str());
@@ -641,6 +641,9 @@ private:
llama_init = common_init_from_params(params_base);
// propagate model-metadata sampling defaults back to caller
params.sampling = params_base.sampling;
model = llama_init->model();
ctx = llama_init->context();
@@ -2404,7 +2407,7 @@ private:
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12,
func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold);
return cur.pos_min < pos_min_thold;
return cur.pos_min < pos_min_thold || cur.pos_min == 0;
}
);
@@ -2978,7 +2981,7 @@ private:
server_context::server_context() : impl(new server_context_impl()) {}
server_context::~server_context() = default;
bool server_context::load_model(const common_params & params) {
bool server_context::load_model(common_params & params) {
return impl->load_model(params);
}

View File

@@ -56,7 +56,7 @@ struct server_context {
// load the model and initialize llama_context
// returns true on success
bool load_model(const common_params & params);
bool load_model(common_params & params);
// this function will block main thread until termination
void start_loop();

View File

@@ -51,7 +51,6 @@
"eslint-config-prettier": "^10.0.1",
"eslint-plugin-storybook": "^10.2.4",
"eslint-plugin-svelte": "^3.0.0",
"fflate": "^0.8.2",
"globals": "^16.0.0",
"http-server": "^14.1.1",
"mdast": "^3.0.0",
@@ -5051,13 +5050,6 @@
}
}
},
"node_modules/fflate": {
"version": "0.8.2",
"resolved": "https://registry.npmjs.org/fflate/-/fflate-0.8.2.tgz",
"integrity": "sha512-cPJU47OaAoCbg0pBvzsgpTPhmhqI5eJjh/JIu8tPj5q+T7iLvW/JAYUqmE7KOB4R1ZyEhzBaIQpQpardBF5z8A==",
"dev": true,
"license": "MIT"
},
"node_modules/file-entry-cache": {
"version": "8.0.0",
"resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-8.0.0.tgz",

View File

@@ -33,7 +33,7 @@
const showToolCallInProgress = $derived(config().showToolCallInProgress as boolean);
const showThoughtInProgress = $derived(config().showThoughtInProgress as boolean);
const sections = $derived(deriveAgenticSections(message, toolMessages, []));
const sections = $derived(deriveAgenticSections(message, toolMessages, [], isStreaming));
// Parse tool results with images
const sectionsParsed = $derived(

View File

@@ -16,6 +16,7 @@
import { rehypeEnhanceLinks } from '$lib/markdown/enhance-links';
import { rehypeEnhanceCodeBlocks } from '$lib/markdown/enhance-code-blocks';
import { rehypeResolveAttachmentImages } from '$lib/markdown/resolve-attachment-images';
import { rehypeRtlSupport } from '$lib/markdown/rehype-rtl-support';
import { remarkLiteralHtml } from '$lib/markdown/literal-html';
import { copyCodeToClipboard, preprocessLaTeX, getImageErrorFallbackHtml } from '$lib/utils';
import {
@@ -101,6 +102,7 @@
.use(rehypeEnhanceLinks) // Add target="_blank" to links
.use(rehypeEnhanceCodeBlocks) // Wrap code blocks with header and actions
.use(rehypeResolveAttachmentImages, { attachments })
.use(rehypeRtlSupport) // Add bidirectional text support
.use(rehypeStringify, { allowDangerousHtml: true }); // Convert to HTML string
});
@@ -781,19 +783,19 @@
/* Lists */
div :global(ul) {
list-style-type: disc;
margin-left: 1.5rem;
margin-inline-start: 1.5rem;
margin-bottom: 1rem;
}
div :global(ol) {
list-style-type: decimal;
margin-left: 1.5rem;
margin-inline-start: 1.5rem;
margin-bottom: 1rem;
}
div :global(li) {
margin-bottom: 0.25rem;
padding-left: 0.5rem;
padding-inline-start: 0.5rem;
}
div :global(li::marker) {
@@ -816,8 +818,8 @@
/* Task lists */
div :global(.task-list-item) {
list-style: none;
margin-left: 0;
padding-left: 0;
margin-inline-start: 0;
padding-inline-start: 0;
}
div :global(.task-list-item-checkbox) {

View File

@@ -0,0 +1,28 @@
/**
* Rehype plugin to provide comprehensive RTL support by adding dir="auto"
* to all text-containing elements.
*
* This operates directly on the HAST tree, ensuring that all elements
* (including those not in a predefined list) receive the attribute.
*/
import type { Plugin } from 'unified';
import type { Root, Element } from 'hast';
import { visit } from 'unist-util-visit';
/**
* Rehype plugin to add dir="auto" to all elements that have children.
* This provides bidirectional text support for mixed RTL/LTR content.
*/
export const rehypeRtlSupport: Plugin<[], Root> = () => {
return (tree: Root) => {
visit(tree, 'element', (node: Element) => {
if (node.children && node.children.length > 0) {
node.properties = {
...node.properties,
dir: 'auto'
};
}
});
};
};

View File

@@ -474,6 +474,7 @@ class AgenticStore {
sessionMessages.push({
role: MessageRole.ASSISTANT,
content: turnContent || undefined,
reasoning_content: turnReasoningContent || undefined,
tool_calls: normalizedCalls
});

View File

@@ -41,6 +41,7 @@ export type AgenticMessage =
| {
role: MessageRole.ASSISTANT;
content?: string | ApiChatMessageContentPart[];
reasoning_content?: string;
tool_calls?: AgenticToolCallPayload[];
}
| {

View File

@@ -38,14 +38,19 @@ export type ToolResultLine = {
function deriveSingleTurnSections(
message: DatabaseMessage,
toolMessages: DatabaseMessage[] = [],
streamingToolCalls: ApiChatCompletionToolCall[] = []
streamingToolCalls: ApiChatCompletionToolCall[] = [],
isStreaming: boolean = false
): AgenticSection[] {
const sections: AgenticSection[] = [];
// 1. Reasoning content (from dedicated field)
if (message.reasoningContent) {
const toolCalls = parseToolCalls(message.toolCalls);
const hasContentAfterReasoning =
!!message.content?.trim() || toolCalls.length > 0 || streamingToolCalls.length > 0;
const isPending = isStreaming && !hasContentAfterReasoning;
sections.push({
type: AgenticSectionType.REASONING,
type: isPending ? AgenticSectionType.REASONING_PENDING : AgenticSectionType.REASONING,
content: message.reasoningContent
});
}
@@ -104,12 +109,13 @@ function deriveSingleTurnSections(
export function deriveAgenticSections(
message: DatabaseMessage,
toolMessages: DatabaseMessage[] = [],
streamingToolCalls: ApiChatCompletionToolCall[] = []
streamingToolCalls: ApiChatCompletionToolCall[] = [],
isStreaming: boolean = false
): AgenticSection[] {
const hasAssistantContinuations = toolMessages.some((m) => m.role === MessageRole.ASSISTANT);
if (!hasAssistantContinuations) {
return deriveSingleTurnSections(message, toolMessages, streamingToolCalls);
return deriveSingleTurnSections(message, toolMessages, streamingToolCalls, isStreaming);
}
const sections: AgenticSection[] = [];
@@ -127,7 +133,12 @@ export function deriveAgenticSections(
const isLastTurn = i + 1 + turnToolMsgs.length >= toolMessages.length;
sections.push(
...deriveSingleTurnSections(msg, turnToolMsgs, isLastTurn ? streamingToolCalls : [])
...deriveSingleTurnSections(
msg,
turnToolMsgs,
isLastTurn ? streamingToolCalls : [],
isLastTurn && isStreaming
)
);
i += 1 + turnToolMsgs.length;

View File

@@ -162,6 +162,36 @@ describe('deriveAgenticSections', () => {
expect(sections[4].content).toBe('Here is the analysis.');
});
it('returns REASONING_PENDING when streaming with only reasoning content', () => {
const msg = makeAssistant({
reasoningContent: 'Let me think about this...'
});
const sections = deriveAgenticSections(msg, [], [], true);
expect(sections).toHaveLength(1);
expect(sections[0].type).toBe(AgenticSectionType.REASONING_PENDING);
expect(sections[0].content).toBe('Let me think about this...');
});
it('returns REASONING (not pending) when streaming but text content has appeared', () => {
const msg = makeAssistant({
content: 'The answer is',
reasoningContent: 'Let me think...'
});
const sections = deriveAgenticSections(msg, [], [], true);
expect(sections).toHaveLength(2);
expect(sections[0].type).toBe(AgenticSectionType.REASONING);
expect(sections[1].type).toBe(AgenticSectionType.TEXT);
});
it('returns REASONING (not pending) when not streaming', () => {
const msg = makeAssistant({
reasoningContent: 'Let me think...'
});
const sections = deriveAgenticSections(msg, [], [], false);
expect(sections).toHaveLength(1);
expect(sections[0].type).toBe(AgenticSectionType.REASONING);
});
it('multi-turn: streaming tool calls on last turn', () => {
const assistant1 = makeAssistant({
toolCalls: JSON.stringify([