Compare commits

..

8 Commits
b9333 ... b9341

Author SHA1 Message Date
ghleg
dbe9c0c8ce convert : support Gemma4ForCausalLM architecture (#23682)
* convert : support Gemma4ForCausalLM architecture (#23674)

* fix indent

---------

Co-authored-by: Oleg Afonin <your.email@example.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-05-26 08:00:31 +03:00
Michael Wand
6fe90deffa models : Attach Mistral3 NVFP4 weight scales (#23629) 2026-05-26 07:59:59 +03:00
Alexey Kopytko
581d020b12 SYCL: implement ggml_sycl_pool_vmm (#22862)
* SYCL: implement ggml_sycl_pool_vmm

* Add an option to bypass VMM with GGML_SYCL_DISABLE_VMM

* Clean up debugging logging

* document GGML_SYCL_DISABLE_VMM

* Multi-stream MoE optimization

* Revert "Multi-stream MoE optimization"

This reverts commit 938929c3f1.

* Update common.hpp

Co-authored-by: Neo Zhang <zhang.jianyu@outlook.com>

* Flip GGML_SYCL_DISABLE_VMM to GGML_SYCL_ENABLE_VMM

* add logging for GGML_SYCL_ENABLE_VMM when extension is not available (SYCL_EXT_ONEAPI_VIRTUAL_MEM macro)

* Apply suggestions from code review

Co-authored-by: Alexey Kopytko <alexey@kopytko.com>

* Apply suggestion from @sanmai

* Apply suggestion from @sanmai

---------

Co-authored-by: Neo Zhang <zhang.jianyu@outlook.com>
2026-05-26 07:59:00 +03:00
Jeff Bolz
7623de11d9 tests: test-backend-ops -j <N> to run tests in parallel (#23637)
Create a pool of N threads that grab a chunk of up to 100 tests at a time to
iterate through. The number of tests at a time decreases as fewer remain.

Each thread uses its own dev and cpu backend, and set_n_threads_fn is not
called on the cpu backend.

Fix some TSAN issues that arose:
- In init_tensor_uniform, don't use static vector of generators.
- Replace gmtime with versions that don't use a global variable.
- Mutex calls to print_test_result.
2026-05-26 07:57:56 +03:00
Niklas Sheth
c9d98295a3 model : add support for talkie-1930-13b (#22596)
* initial talkie support, coherent

* reorder to follow convention

* absorb inverse rope

* stop folding scalars to improve quantization

* use broadcasting instead of duplication

* style cleanup

* add scaling support to LoraTorchTensor; use that path in conversion

* use layer_out_scale instead of embd_skip_scale
2026-05-26 07:57:38 +03:00
Masashi Yoshimura
1506d39e76 ggml-webgpu: Add MMVQ path for Q4/Q8/Q2_K/Q4_K and clean up legacy MUL_MAT pipeline (#23594)
* ggml-webgpu: Add MMVQ path for Q4/Q8/Q2_K/Q4_K

* Fix to editorconfig checking pass

* Remove mul-mat-legacy pipeline

* Fix to use vendor name as is and add dot_product/vendor to shader_lib_ctx
2026-05-25 20:42:49 -07:00
Nikhil Jain
54121f7325 [WebGPU] Check batch_compute_passes before sending passes when not doing GPU profiling (#23457)
* Only run webgpu CI on my fork

* Add webgpu only workflow

* refactor batch_compute_passes to a per-thread variable, and submit individual passes when it is set to false and no GPU profiling is enabled

* restore build.yml
2026-05-25 20:32:49 -07:00
Johannes Gäßler
192d8ae8b8 CUDA: missing PDL sync for FWHT, better fallback (#23690) 2026-05-26 11:05:51 +08:00
31 changed files with 1329 additions and 1032 deletions

View File

@@ -74,6 +74,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"Gemma3nForCausalLM": "gemma",
"Gemma3nForConditionalGeneration": "gemma",
"Gemma4ForConditionalGeneration": "gemma",
"Gemma4ForCausalLM": "gemma",
"GemmaForCausalLM": "gemma",
"Glm4ForCausalLM": "glm",
"Glm4MoeForCausalLM": "glm",
@@ -215,6 +216,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"T5EncoderModel": "t5",
"T5ForConditionalGeneration": "t5",
"T5WithLMHeadModel": "t5",
"TalkieForCausalLM": "talkie",
"UMT5ForConditionalGeneration": "t5",
"UMT5Model": "t5",
"UltravoxModel": "ultravox",

View File

@@ -1622,6 +1622,9 @@ class TextModel(ModelBase):
if chkhsh == "62f6fb0a6fd5098caeabb19b07a5c1099cafc8b9c40eab6ea89ece4ec02fbc57":
# ref: https://huggingface.co/sarvamai/sarvam-30b
res = "sarvam-moe"
if chkhsh == "f728162c1315c26e40249849799b4ba3fe584c32084b4795b03eb295e63cb5af":
# ref: https://huggingface.co/lewtun/talkie-1930-13b-it-hf
res = "talkie"
if res is None:
logger.warning("\n")

View File

@@ -614,7 +614,7 @@ class Gemma3NModel(Gemma3Model):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Gemma4ForConditionalGeneration")
@ModelBase.register("Gemma4ForConditionalGeneration", "Gemma4ForCausalLM")
class Gemma4Model(Gemma3Model):
model_arch = gguf.MODEL_ARCH.GEMMA4

53
conversion/talkie.py Normal file
View File

@@ -0,0 +1,53 @@
from __future__ import annotations
from typing import Iterable, TYPE_CHECKING
import torch
if TYPE_CHECKING:
from torch import Tensor
from .base import LazyTorchTensor, ModelBase, TextModel, gguf
@ModelBase.register("TalkieForCausalLM")
class TalkieModel(TextModel):
    """Converter for `TalkieForCausalLM` checkpoints (GGUF arch `TALKIE`)."""

    model_arch = gguf.MODEL_ARCH.TALKIE

    def set_gguf_parameters(self):
        """Write the base text-model parameters, then the RMS-norm epsilon."""
        super().set_gguf_parameters()
        # Talkie used F.rms_norm without an explicit eps,
        # so the float32 machine epsilon is written instead.
        fp32_eps = torch.finfo(torch.float32).eps
        self.gguf_writer.add_layer_norm_rms_eps(fp32_eps)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Map one checkpoint tensor onto zero or more GGUF tensors.

        Gain tensors are emitted as `.scale` companions of the attention-output
        and FFN-down tensors, the LM-head gain is stored as the logit-scale
        metadata value (no tensor is emitted), and everything else is handed
        to the base class after the Talkie-specific adjustments below.
        """
        # Name relative to the owning block; unchanged for non-block tensors.
        block_prefix = f"model.blocks.{bid}." if bid is not None else ""
        local_name = name.removeprefix(block_prefix)

        # Per-block gains that become ".scale" tensors.
        scale_targets = {
            "attn_gain.a_g": gguf.MODEL_TENSOR.ATTN_OUT,
            "mlp_gain.a_g": gguf.MODEL_TENSOR.FFN_DOWN,
        }
        if local_name in scale_targets:
            yield self.format_tensor_name(scale_targets[local_name], bid, ".scale"), data_torch
            return

        if local_name == "lm_head_gain.w_g":
            # Scalar gain: materialize it and record it as metadata only.
            self.gguf_writer.add_logit_scale(LazyTorchTensor.to_eager(data_torch).item())
            return

        if local_name in ("attn.attn_query.weight", "attn.attn_key.weight"):
            # absorb inverse rope: negate the second half of every head's rows
            head_dim = self.hparams["head_dim"]
            orig_shape = data_torch.shape
            per_head = torch.reshape(data_torch, (-1, head_dim, orig_shape[-1]))
            flip = torch.ones((1, head_dim, 1), dtype=per_head.dtype)
            flip[:, head_dim // 2 :, :] = -1
            if self.lazy:
                flip = LazyTorchTensor.from_eager(flip)
            # (n_head, head_dim, n_in) -> (n_out, n_in)
            data_torch = torch.reshape(per_head * flip, orig_shape)
        elif local_name == "attn.head_gain.head_g":
            # allow head gain to broadcast
            data_torch = data_torch.unsqueeze(-1)

        # The base-class mapping is fed names ending in ".weight".
        if not name.endswith(".weight"):
            name += ".weight"

        yield from super().modify_tensors(data_torch, name, bid)

View File

@@ -156,6 +156,7 @@ models = [
{"name": "kanana2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/kakaocorp/kanana-2-30b-a3b-instruct-2601", },
{"name": "f2llmv2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/codefuse-ai/F2LLM-v2-4B", },
{"name": "sarvam-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sarvamai/sarvam-30b", },
{"name": "talkie", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/lewtun/talkie-1930-13b-it-hf", },
]
# some models are known to be broken upstream, so we will skip them as exceptions

View File

@@ -208,6 +208,16 @@ class LoraTorchTensor:
def to(self, *args, **kwargs):
return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))
def __mul__(self, other) -> LoraTorchTensor:
    """Scale the represented weight `W = B @ A` on its output side.

    Since `M_out * W == (M_out * B) @ A`, only the B factor is multiplied.
    `other` may be a Python int/float, or an object with a `shape` that is
    either empty or has a trailing dimension of 1; any other shape raises
    `NotImplementedError`.
    """
    if not isinstance(other, (int, float)):
        dims = other.shape
        # A non-unit trailing dimension is not an output-side factor.
        if dims and dims[-1] != 1:
            raise NotImplementedError
    scaled_b = self._lora_B * other
    return LoraTorchTensor(self._lora_A, scaled_b)
def __rmul__(self, other) -> LoraTorchTensor:
    # Reflected multiplication (`other * self`): evaluated as `self * other`,
    # so it shares the restrictions of the regular multiplication path.
    return self * other
@classmethod
def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
del types # unused

View File

@@ -743,6 +743,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through the SYCL Graphs feature. Disabled by default because SYCL Graph is still in development and does not yet improve performance. |
| GGML_SYCL_ENABLE_LEVEL_ZERO | 1 (default) or 0 | Use Level Zero API for device memory allocation instead of SYCL. Reduces system RAM usage on Intel dGPUs by avoiding DMA-buf/TTM host memory staging. Requires GGML_SYCL_SUPPORT_LEVEL_ZERO=ON at build time. |
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
| GGML_SYCL_ENABLE_VMM | 0 or 1 (default) | Enable the virtual-memory device pool. |
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Enable querying the GPU's free memory via sycl::aspect::ext_intel_free_memory.<br>Recommended when --split-mode = layer |
| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Allow SYCL/Unified Runtime Level Zero device allocations larger than 4 GiB. llama.cpp's direct Level Zero allocation path requests the relaxed maximum-size limit itself when GGML_SYCL_ENABLE_LEVEL_ZERO=1. |
@@ -753,6 +754,7 @@ Pass these via `CXXFLAGS` or add a one-off `#define` to enable a flag on the spo
| Name | Function |
|-----------------|----------------------------------------------------------------------------------|
| DEBUG_SYCL_POOL | Enable device memory pool logging on teardown. Useful for profiling allocations. |
| DEBUG_SYCL_MALLOC | Enable verbose per-call logging of device pool alloc/free operations. |
## Design Rule

View File

@@ -19,6 +19,7 @@ __global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows,
float reg[el_w];
const int lane = threadIdx.x;
ggml_cuda_pdl_sync();
#pragma unroll
for (int i = 0; i < el_w; ++i) {
reg[i] = src[i * warp_size + lane] * scale;
@@ -57,10 +58,11 @@ __global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows,
}
}
void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) {
bool ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) {
GGML_ASSERT(ggml_are_same_shape(src, dst));
GGML_ASSERT(ggml_is_contiguous(src));
GGML_ASSERT(ggml_is_contiguous(dst));
if (!ggml_is_contiguous(src) || !ggml_is_contiguous(dst)) {
return false;
}
const int n = src->ne[0];
const int64_t rows = ggml_nrows(src);
@@ -68,7 +70,6 @@ void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src,
float * dst_d = (float *) dst->data;
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
GGML_ASSERT(n % warp_size == 0);
const int rows_per_block = 4;
const int64_t num_blocks = (rows + rows_per_block - 1) / rows_per_block;
@@ -83,26 +84,18 @@ void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src,
switch (n) {
case 64:
{
ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale);
break;
}
ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale);
return true;
case 128:
{
ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale);
break;
}
ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale);
return true;
case 256:
{
ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale);
break;
}
ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale);
return true;
case 512:
{
ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale);
break;
}
ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale);
return true;
default:
GGML_ABORT("fatal error");
return false;
}
}

View File

@@ -1,3 +1,4 @@
#include "common.cuh"
void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst);
// Returns whether the Fast Walsh-Hadamard transform could be used.
bool ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst);

View File

@@ -2596,9 +2596,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
const int32_t hint = ggml_get_op_params_i32(dst, 1);
if (hint == GGML_HINT_SRC0_IS_HADAMARD) {
GGML_ASSERT(!split);
ggml_cuda_op_fwht(ctx, src1, dst);
if (hint == GGML_HINT_SRC0_IS_HADAMARD && !split && ggml_cuda_op_fwht(ctx, src1, dst)) {
return;
}

View File

@@ -224,6 +224,7 @@ struct sycl_device_info {
int max_wg_per_cu; // max work groups per compute unit - refer to
// cudaOccupancyMaxActiveBlocksPerMultiprocessor
bool vmm; // virtual memory support
size_t vmm_granularity; // granularity of virtual memory
size_t total_vram;
sycl_hw_info hw_info;
optimize_feature opt_feature;
@@ -244,6 +245,8 @@ struct ggml_sycl_device_info {
const ggml_sycl_device_info & ggml_sycl_info();
static constexpr size_t SYCL_BUFFER_ALIGNMENT = 128;
struct ggml_sycl_pool {
virtual ~ggml_sycl_pool() = default;

View File

@@ -19,6 +19,7 @@
#include <cstdlib>
#include <float.h>
#include <limits>
#include <optional>
#include <stdint.h>
#include <stdio.h>
#include <vector>
@@ -37,6 +38,11 @@
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
# include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp>
#endif
#if SYCL_EXT_ONEAPI_VIRTUAL_MEM
# include <sycl/ext/oneapi/virtual_mem/physical_mem.hpp>
# include <sycl/ext/oneapi/virtual_mem/virtual_mem.hpp>
# define GGML_SYCL_USE_VMM
#endif
#include <sycl/half_type.hpp>
#include "ggml.h"
@@ -70,6 +76,7 @@ int g_ggml_sycl_debug = 0;
int g_ggml_sycl_disable_optimize = 0;
int g_ggml_sycl_disable_graph = 0;
int g_ggml_sycl_disable_dnn = 0;
int g_ggml_sycl_enable_vmm = 1;
int g_ggml_sycl_prioritize_dmmv = 0;
int g_ggml_sycl_use_async_mem_op = 0;
int g_ggml_sycl_use_async_mem_op_requested = 1;
@@ -96,13 +103,30 @@ static ggml_sycl_device_info ggml_sycl_init() {
// GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__);
// #endif
for (int i = 0; i < info.device_count; ++i) {
info.devices[i].vmm = 0;
dpct::device_info prop;
auto & device = dpct::dev_mgr::instance().get_device(i);
SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
prop, device)));
#if !defined(GGML_SYCL_USE_VMM)
info.devices[i].vmm = 0;
#else
info.devices[i].vmm = device.has(sycl::aspect::ext_oneapi_virtual_mem);
if (info.devices[i].vmm) {
// NB: SYCL's get_mem_granularity always returns the _minimum_ granularity,
// but the L0 API requires a larger page size for allocs above 2 MiB and
// rejects non-multiples with UR_RESULT_ERROR_INVALID_VALUE [sic].
// Here we clamp it to 2 MiB for simplicity, but other devices may require
// calling zeVirtualMemQueryPageSize or yet unexposed public API.
const size_t physical_page = 2ull << 20; // 2 MiB
info.devices[i].vmm_granularity = std::max<size_t>(
sycl::ext::oneapi::experimental::get_mem_granularity(
device, sycl::context(device)),
physical_page);
}
#endif
info.default_tensor_split[i] = total_vram;
total_vram += prop.get_global_mem_size();
@@ -234,6 +258,7 @@ static void ggml_check_sycl() try {
g_ggml_sycl_disable_optimize = get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
g_ggml_sycl_enable_vmm = get_sycl_env("GGML_SYCL_ENABLE_VMM", 1);
g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO
g_ggml_sycl_enable_level_zero = get_sycl_env("GGML_SYCL_ENABLE_LEVEL_ZERO", ggml_sycl_info().ext_oneapi_level_zero);
@@ -275,6 +300,11 @@ static void ggml_check_sycl() try {
#else
GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO: no\n");
#endif
#if defined(GGML_SYCL_USE_VMM)
GGML_LOG_INFO(" GGML_SYCL_USE_VMM: yes\n");
#else
GGML_LOG_INFO(" GGML_SYCL_USE_VMM: no\n");
#endif
GGML_LOG_INFO("Running with Environment Variables:\n");
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
@@ -293,6 +323,11 @@ static void ggml_check_sycl() try {
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
#else
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
#endif
#if defined(GGML_SYCL_USE_VMM)
GGML_LOG_INFO(" GGML_SYCL_ENABLE_VMM: %d\n", g_ggml_sycl_enable_vmm);
#else
GGML_LOG_INFO(" GGML_SYCL_ENABLE_VMM: virtual memory extension is not available\n");
#endif
GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
g_ggml_sycl_use_async_mem_op_requested = get_sycl_env("GGML_SYCL_USE_ASYNC_MEM_OP", 1);
@@ -754,7 +789,7 @@ catch (sycl::exception const &exc) {
}
static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
return 128;
return SYCL_BUFFER_ALIGNMENT;
GGML_UNUSED(buft);
}
@@ -1177,7 +1212,7 @@ static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(gg
}
static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
return 128;
return SYCL_BUFFER_ALIGNMENT;
GGML_UNUSED(buft);
}
@@ -1462,6 +1497,121 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
}
};
// pool with virtual memory management
#if defined(GGML_SYCL_USE_VMM)
// Device memory pool backed by the SYCL virtual-memory extension.
//
// On first growth a virtual address range of SYCL_POOL_VMM_MAX_SIZE bytes is
// reserved; physical memory is then mapped at the end of the already-mapped
// region whenever an allocation does not fit. Allocations are carved out
// linearly (pool_addr + pool_used) and must be released in reverse order.
// free() only moves the high-water mark back; mappings are released in the
// destructor.
struct ggml_sycl_pool_vmm : public ggml_sycl_pool {
    // size of the reserved virtual address range, i.e. the upper bound for pool_size
    static const size_t SYCL_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB

    int device;
    sycl::context ctx;
    sycl::device dev;
    uintptr_t pool_addr = 0;  // base of the reserved range; 0 until the first growth
    size_t pool_used = 0;     // bytes currently handed out
    size_t pool_size = 0;     // bytes currently backed by mapped physical memory
    size_t granularity;       // growth step, taken from the per-device info in the constructor

    // physical_mem owns the commits (unlike cuMemMap)
    struct mapping {
        sycl::ext::oneapi::experimental::physical_mem phys;
        void * map_ptr;  // address returned by phys.map(), needed again for unmap()
    };
    // one entry per growth step, in mapping order
    std::vector<mapping> mappings;

    // Captures the context and device of the given queue and looks up the
    // VMM granularity recorded for `device_`. Nothing is reserved or mapped here.
    explicit ggml_sycl_pool_vmm(queue_ptr qptr_, int device_) :
        device(device_),
        ctx(qptr_->get_context()),
        dev(qptr_->get_device()),
        granularity(ggml_sycl_info().devices[device_].vmm_granularity) {
    }

    // Unmaps every mapping and releases the reserved virtual range.
    ~ggml_sycl_pool_vmm() {
        // nothing was ever reserved, so there is nothing to unmap or free
        if (pool_addr == 0) {
            return;
        }

        // Per spec, unmap must (a) match the exact (ptr, size) of an earlier
        // physical_mem::map() call and (b) precede destruction of the
        // physical_mem objects (their dtors won't unmap).
        for (auto & m : mappings) {
            SYCL_CHECK(CHECK_TRY_ERROR(sycl::ext::oneapi::experimental::unmap(
                m.map_ptr, m.phys.size(), ctx)));
        }

        SYCL_CHECK(CHECK_TRY_ERROR(sycl::ext::oneapi::experimental::free_virtual_mem(
            pool_addr, SYCL_POOL_VMM_MAX_SIZE, ctx)));
    }

    // Returns a pointer to `size` bytes (rounded up to SYCL_BUFFER_ALIGNMENT)
    // at the current end of the pool, growing the mapped region first if needed.
    // The rounded size is stored in *actual_size; the same value is what
    // free() subtracts again.
    void * alloc(size_t size, size_t * actual_size) override {
        // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
        size = GGML_PAD(size, SYCL_BUFFER_ALIGNMENT);

        size_t avail = pool_size - pool_used;

        if (size > avail) {
            // round up to the next multiple of the granularity
            size_t reserve_size = GGML_PAD(size - avail, granularity);

            // growing past the reserved virtual range is a hard error
            GGML_ASSERT(pool_size + reserve_size <= SYCL_POOL_VMM_MAX_SIZE);

            // allocate more physical memory
            // (std::optional lets the object be constructed inside the checked expression)
            std::optional<sycl::ext::oneapi::experimental::physical_mem> phys;
            SYCL_CHECK(CHECK_TRY_ERROR(phys.emplace(dev, ctx, reserve_size)));

            // reserve virtual address space (if not already reserved)
            if (pool_addr == 0) {
                SYCL_CHECK(CHECK_TRY_ERROR(
                    pool_addr = sycl::ext::oneapi::experimental::reserve_virtual_mem(
                        SYCL_POOL_VMM_MAX_SIZE, ctx)));
            }

            // map at the end of the pool
            void * map_ptr = nullptr;
            SYCL_CHECK(CHECK_TRY_ERROR(
                map_ptr = phys->map(pool_addr + pool_size, reserve_size,
                    sycl::ext::oneapi::experimental::address_access_mode::read_write)));

            // stash these so we could unmap this exact range in dtor
            mappings.push_back({
                std::move(*phys),
                map_ptr,
            });

            // add to the pool
            pool_size += reserve_size;

#ifdef DEBUG_SYCL_MALLOC
            GGML_LOG_INFO("sycl pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
                device, (unsigned long long) (pool_size/1024/1024),
                (unsigned long long) (reserve_size/1024/1024));
#endif
        }

        GGML_ASSERT(pool_addr != 0);

        // hand out the next free bytes and advance the high-water mark
        void * ptr = reinterpret_cast<void *>(pool_addr + pool_used);
        *actual_size = size;
        pool_used += size;

#ifdef DEBUG_SYCL_MALLOC
        GGML_LOG_INFO("sycl pool[%d]: allocated %llu bytes at %p\n", device, (unsigned long long) size, ptr);
#endif

        return ptr;
    }

    // Releases the most recent allocation. `size` is subtracted from pool_used,
    // and `ptr` must then equal the new end of the used region.
    // Mapped physical memory is kept for reuse.
    void free(void * ptr, size_t size) override {
#ifdef DEBUG_SYCL_MALLOC
        GGML_LOG_INFO("sycl pool[%d]: freed %llu bytes at %p\n", device, (unsigned long long) size, ptr);
#endif

        pool_used -= size;

        // all deallocations must be in reverse order of the allocations
        GGML_ASSERT(ptr == reinterpret_cast<void *>(pool_addr + pool_used));
    }
};
#endif // defined(GGML_SYCL_USE_VMM)
struct ggml_sycl_pool_host : public ggml_sycl_pool {
queue_ptr qptr;
int device;
@@ -1542,20 +1692,19 @@ std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(que
}
std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
// TBD: NO VMM support
// if (ggml_sycl_info().devices[device].vmm) {
// return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_vmm(device));
// }
return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device));
#if defined(GGML_SYCL_USE_VMM)
if (g_ggml_sycl_enable_vmm && ggml_sycl_info().devices[device].vmm) {
return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_vmm(qptr, device));
}
#endif // defined(GGML_SYCL_USE_VMM)
return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device));
}
std::unique_ptr<ggml_sycl_fattn_kv_buffers> ggml_backend_sycl_context::new_fattn_kv_buffers(queue_ptr qptr, int device) {
return std::unique_ptr<ggml_sycl_fattn_kv_buffers>(new ggml_sycl_fattn_kv_buffers(qptr, device));
}
// TBD pool with virtual memory management
// struct ggml_sycl_pool_vmm : public ggml_sycl_pool
/// kernels
typedef void (*ggml_sycl_op_mul_mat_t)(
ggml_backend_sycl_context & ctx,

View File

@@ -52,7 +52,7 @@
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4
// default size for legacy matrix multiplication
// default size for reg-tile matrix multiplication
#define WEBGPU_MUL_MAT_WG_SIZE 256
// Same hash combine function as in boost
@@ -93,6 +93,8 @@ struct ggml_webgpu_shader_lib_context {
uint32_t sg_mat_k = 0;
uint32_t min_subgroup_size = 0;
uint32_t max_subgroup_size = 0;
bool supports_dot_product = false;
std::string vendor;
};
struct webgpu_pipeline {
@@ -850,31 +852,15 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
/** Matrix Multiplication **/
struct ggml_webgpu_legacy_mul_mat_pipeline_key {
ggml_type src0_type;
ggml_type src1_type;
bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const {
return src0_type == other.src0_type && src1_type == other.src1_type;
}
};
struct ggml_webgpu_legacy_mul_mat_pipeline_key_hash {
size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.src0_type);
ggml_webgpu_hash_combine(seed, key.src1_type);
return seed;
}
};
struct ggml_webgpu_mul_mat_vec_pipeline_key {
ggml_type src0_type;
ggml_type src1_type;
int vectorized;
bool use_mmvq;
bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized;
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
use_mmvq == other.use_mmvq;
}
};
@@ -884,6 +870,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.src0_type);
ggml_webgpu_hash_combine(seed, key.src1_type);
ggml_webgpu_hash_combine(seed, key.vectorized);
ggml_webgpu_hash_combine(seed, key.use_mmvq);
return seed;
}
};
@@ -894,6 +881,20 @@ struct ggml_webgpu_mul_mat_vec_shader_decisions {
uint32_t vec_size;
};
// Cache key for quantize_q8 pipelines: pipelines are distinguished by src0 type only.
struct ggml_webgpu_quantize_q8_pipeline_key {
    ggml_type src0_type;

    // Two keys are equal when their src0 types match.
    bool operator==(const ggml_webgpu_quantize_q8_pipeline_key & rhs) const {
        return rhs.src0_type == src0_type;
    }
};
// Hash functor for ggml_webgpu_quantize_q8_pipeline_key, for use with unordered containers.
struct ggml_webgpu_quantize_q8_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_quantize_q8_pipeline_key & key) const {
        // Start from zero and fold in the key's only field.
        size_t hash = 0;
        ggml_webgpu_hash_combine(hash, key.src0_type);
        return hash;
    }
};
struct ggml_webgpu_mul_mat_pipeline_key {
ggml_type src0_type;
ggml_type src1_type;
@@ -1051,6 +1052,36 @@ struct ggml_webgpu_soft_max_pipeline_key_hash {
}
};
/** MMVQ **/
// Decides whether the MMVQ mat-vec path may be used for src0 * src1.
//
// All of the following must hold:
//   - src1 has a single column (ne[1] == 1),
//   - the vendor string is exactly "amd", "intel" or "nvidia" and the device
//     reports dot-product support,
//   - src1 is F32 and src0 is one of Q4_0, Q4_1, Q8_0, Q2_K, Q4_K,
//   - src0->ne[0] is a multiple of 4.
inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
                                     const ggml_tensor * src1,
                                     bool                supports_dot_product,
                                     const std::string & vendor) {
    if (src1->ne[1] != 1) {
        return false;
    }

    const bool vendor_has_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
    if (!vendor_has_dp4a || !supports_dot_product) {
        return false;
    }

    if (src1->type != GGML_TYPE_F32) {
        return false;
    }

    switch (src0->type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q4_K:
            return src0->ne[0] % 4 == 0;
        default:
            return false;
    }
}
class ggml_webgpu_shader_lib {
wgpu::Device device;
pre_wgsl::Preprocessor preprocessor;
@@ -1099,14 +1130,12 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline,
ggml_webgpu_flash_attn_blk_pipeline_key_hash>
flash_attn_blk_pipelines;
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
webgpu_pipeline,
ggml_webgpu_legacy_mul_mat_pipeline_key_hash>
mul_mat_legacy_pipelines; // legacy mul_mat (non-subgroup/non-regtile/non-vec)
std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash>
mul_mat_vec_pipelines; // fast mat-vec (n==1)
std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash>
mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
std::unordered_map<ggml_webgpu_quantize_q8_pipeline_key, webgpu_pipeline, ggml_webgpu_quantize_q8_pipeline_key_hash>
quantize_q8_pipelines;
std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed
std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
mul_mat_id_pipelines; // src0_type/src1_type
@@ -1631,7 +1660,7 @@ class ggml_webgpu_shader_lib {
key.type = context.dst->type;
key.d_state = (int) context.src0->ne[0];
key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
ggml_webgpu_tensor_overlap(context.src1, context.src5);
ggml_webgpu_tensor_overlap(context.src1, context.src5);
auto it = ssm_scan_pipelines.find(key);
if (it != ssm_scan_pipelines.end()) {
@@ -1744,6 +1773,44 @@ class ggml_webgpu_shader_lib {
return pad_pipelines[key];
}
// Returns the quantize_q8 pipeline for context.src0's type, building and
// caching it on first use. The shader variant also depends on whether
// subgroup reduction is available (context.supports_subgroups).
webgpu_pipeline get_quantize_q8_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_quantize_q8_pipeline_key key = {};
    key.src0_type                            = context.src0->type;

    // Reuse a previously built pipeline when there is one.
    auto cached = quantize_q8_pipelines.find(key);
    if (cached != quantize_q8_pipelines.end()) {
        return cached->second;
    }

    const uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;

    // Type name as reported by ggml, plus an upper-cased copy for the define.
    const std::string src0_name  = ggml_get_type_traits(context.src0->type)->type_name;
    std::string       src0_upper = src0_name;
    std::transform(src0_upper.begin(), src0_upper.end(), src0_upper.begin(), ::toupper);

    const bool use_subgroups = context.supports_subgroups;

    // Preprocessor defines, in the same order the shader has always received them.
    std::vector<std::string> defines;
    defines.push_back("SRC1_INNER_TYPE=f32");
    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
    defines.push_back("MUL_ACC_" + src0_upper);
    defines.push_back("Q8_1_T");
    if (use_subgroups) {
        defines.push_back("USE_SUBGROUP_REDUCTION");
    } else {
        defines.push_back("USE_WORKGROUP_REDUCTION");
    }

    // Variant label: quantize_q8_<type>_<sg|wg>_reduce
    std::string variant = "quantize_q8";
    variant += "_" + src0_name;
    variant += use_subgroups ? "_sg_reduce" : "_wg_reduce";

    auto processed = preprocessor.preprocess(wgsl_quantize_q8, defines);

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = wg_size;

    webgpu_pipeline pipeline   = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context           = decisions;
    quantize_q8_pipelines[key] = pipeline;
    return quantize_q8_pipelines[key];
}
webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_mul_mat_vec_pipeline_key key = {};
key.src0_type = context.src0->type;
@@ -1752,6 +1819,8 @@ class ggml_webgpu_shader_lib {
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
key.use_mmvq =
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
auto it = mul_mat_vec_pipelines.find(key);
if (it != mul_mat_vec_pipelines.end()) {
@@ -1788,6 +1857,19 @@ class ggml_webgpu_shader_lib {
defines.push_back("U32_DEQUANT_HELPERS");
defines.push_back("SRC0_INNER_TYPE=u32");
switch (context.src0->type) {
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
if (key.use_mmvq) {
defines.push_back("LEGACY_QUANTS");
}
break;
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q4_K:
if (key.use_mmvq) {
defines.push_back("K_QUANTS");
}
break;
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ2_S:
@@ -1840,6 +1922,11 @@ class ggml_webgpu_shader_lib {
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
}
if (key.use_mmvq) {
defines.push_back("MMVQ");
defines.push_back("Q8_1_T");
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
@@ -2018,100 +2105,6 @@ class ggml_webgpu_shader_lib {
return mul_mat_fast_pipelines[key];
}
webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_legacy_mul_mat_pipeline_key key = {};
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
auto it = mul_mat_legacy_pipelines.find(key);
if (it != mul_mat_legacy_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "mul_mat";
switch (context.src1->type) {
case GGML_TYPE_F32:
defines.push_back("SRC1_TYPE=f32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("SRC1_TYPE=f16");
variant += "_f16";
break;
default:
GGML_ABORT("Unsupported src1 type for mul_mat legacy shader");
}
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
const char * src0_name = src0_traits->type_name;
switch (context.src0->type) {
case GGML_TYPE_F32:
defines.push_back("SRC0_TYPE=f32");
defines.push_back("FLOAT");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("SRC0_TYPE=f16");
defines.push_back("FLOAT");
variant += "_f16";
break;
default:
{
std::string type_upper = src0_name;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
switch (context.src0->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
{
// Quantized types using u32 buffers for portability.
defines.push_back("SRC0_TYPE=u32");
defines.push_back("U32_DEQUANT_HELPERS");
break;
}
default:
{
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
}
}
defines.push_back("BYTE_HELPERS");
defines.push_back(type_upper + "_T");
defines.push_back(type_upper);
defines.push_back(type_upper + "_SCALE_MIN");
defines.push_back(type_upper + "_TABLES");
defines.push_back(type_upper + "_GRID");
variant += std::string("_") + src0_name;
break;
}
}
auto processed = preprocessor.preprocess(wgsl_mul_mat, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
mul_mat_legacy_pipelines[key] = pipeline;
return mul_mat_legacy_pipelines[key];
}
webgpu_pipeline get_mul_mat_id_gather_pipeline(const ggml_webgpu_shader_lib_context & context) {
auto it = mul_mat_id_gather_pipelines.find(1);
if (it != mul_mat_id_gather_pipelines.end()) {

View File

@@ -181,6 +181,7 @@ struct webgpu_capabilities {
wgpu::Limits limits;
bool supports_subgroups = false;
bool supports_subgroup_matrix = false;
bool supports_dot_product = false;
uint32_t sg_mat_m = 0;
uint32_t sg_mat_n = 0;
@@ -210,6 +211,8 @@ struct webgpu_global_context_struct {
wgpu::Buffer memset_params_buf;
webgpu_pipeline memset_pipeline;
std::string vendor;
// TODO: We should rework the CPU profiling time handling to make it more useful. ref: https://github.com/ggml-org/llama.cpp/pull/22050
#ifdef GGML_WEBGPU_CPU_PROFILE
// Profiling: labeled CPU time in ms (total)
@@ -259,6 +262,7 @@ struct webgpu_context_struct {
wgpu::Buffer set_rows_host_error_buf;
wgpu::CommandEncoder active_command_encoder;
wgpu::ComputePassEncoder active_compute_pass;
bool batch_compute_passes = true;
size_t memset_bytes_per_thread;
@@ -590,9 +594,18 @@ static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context &
}
#else
for (size_t i = 0; i < dispatches.size(); i++) {
ctx->active_compute_pass.SetPipeline(dispatches[i].pipeline.pipeline);
ctx->active_compute_pass.SetBindGroup(0, bind_groups[i]);
ctx->active_compute_pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1);
if (ctx->batch_compute_passes) {
ctx->active_compute_pass.SetPipeline(dispatches[i].pipeline.pipeline);
ctx->active_compute_pass.SetBindGroup(0, bind_groups[i]);
ctx->active_compute_pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second,
1);
} else {
wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass();
pass.SetPipeline(dispatches[i].pipeline.pipeline);
pass.SetBindGroup(0, bind_groups[i]);
pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1);
pass.End();
}
}
#endif
@@ -1384,6 +1397,58 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx,
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
// Appends to `dispatches` the pass that quantizes src1 into q8_1 blocks.
//
// The quantized data is bound inside dst's own buffer, starting at the first offset
// after the dst data that satisfies minStorageBufferOffsetAlignment. Nothing is
// encoded here; the caller is expected to submit the collected dispatches.
//
//   ctx        - backend context (shader library and device capabilities)
//   src0       - weight tensor; only ne[0] is used, to size the workgroup grid
//   src1       - activation tensor that gets quantized
//   dst        - output tensor whose buffer also hosts the q8_1 scratch region
//   dispatches - receives one new entry describing the quantize dispatch
static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &                    ctx,
                                             ggml_tensor *                       src0,
                                             ggml_tensor *                       src1,
                                             ggml_tensor *                       dst,
                                             std::vector<webgpu_dispatch_desc> & dispatches) {
    const auto & caps = ctx->global_ctx->capabilities;

    // Fetch the quantize_q8 pipeline from the shader library.
    ggml_webgpu_shader_lib_context lib_ctx = {};
    lib_ctx.src0                           = src0;
    lib_ctx.src1                           = src1;
    lib_ctx.dst                            = dst;
    lib_ctx.max_wg_size                    = caps.limits.maxComputeInvocationsPerWorkgroup;
    lib_ctx.supports_subgroups             = caps.supports_subgroups;
    webgpu_pipeline quant_pipeline         = ctx->shader_lib->get_quantize_q8_pipeline(lib_ctx);

    // Scratch region for the quantized src1: aligned offset right behind dst's data.
    const size_t dst_offset     = ggml_webgpu_tensor_offset(dst);
    const size_t scratch_offset =
        ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), caps.limits.minStorageBufferOffsetAlignment);
    const size_t scratch_size =
        ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
                     WEBGPU_STORAGE_BUF_BINDING_MULT);

    // Shader parameters: src1 offset and strides in elements, followed by its extents.
    const size_t          src1_elem_size = ggml_type_size(src1->type);
    std::vector<uint32_t> quant_params   = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / src1_elem_size),
        (uint32_t) (src1->nb[2] / src1_elem_size),
        (uint32_t) (src1->nb[3] / src1_elem_size),
        (uint32_t) src1->ne[0],
        (uint32_t) src1->ne[2],
        (uint32_t) src1->ne[3],
    };

    // Binding 0: src1 input. Binding 1: scratch region inside dst's buffer.
    std::vector<wgpu::BindGroupEntry> quant_entries = {
        ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1),
        ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), scratch_offset, scratch_size)
    };

    // Workgroup grid: each invocation covers 4 elements of a row, one set of workgroups per (ne[2], ne[3]) slice.
    auto     quant_decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(quant_pipeline.context.get());
    uint32_t wg_size         = quant_decisions->wg_size;

    const uint32_t wgs_per_row = (src0->ne[0] / 4 + (wg_size - 1)) / wg_size;
    const uint32_t wgs_total   = src1->ne[2] * src1->ne[3] * wgs_per_row;
    const uint32_t wgs_dim_max = caps.limits.maxComputeWorkgroupsPerDimension;

    uint32_t grid_x = 1;
    uint32_t grid_y = 1;
    compute_2d_workgroups(wgs_total, wgs_dim_max, grid_x, grid_y);

    dispatches.push_back({
        quant_pipeline, std::move(quant_params), std::move(quant_entries), { grid_x, grid_y }
    });
}
static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
@@ -1391,47 +1456,9 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
// Determine if this is a mat-vec operation
bool is_vec = (dst->ne[1] == 1);
// Determine if we should use fast path
bool use_fast = false;
switch (src1->type) {
case GGML_TYPE_F16:
use_fast = (src0->type == GGML_TYPE_F16);
break;
case GGML_TYPE_F32:
// TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K
switch (src0->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q1_0:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_MXFP4:
use_fast = true;
break;
default:
break;
}
break;
default:
break;
}
// use MMVQ path for mat-vec
bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product,
ctx->global_ctx->vendor);
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
@@ -1446,16 +1473,20 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size;
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
shader_lib_ctx.supports_dot_product = ctx->global_ctx->capabilities.supports_dot_product;
shader_lib_ctx.vendor = ctx->global_ctx->vendor;
// Get or create pipeline
webgpu_pipeline pipeline;
webgpu_pipeline pipeline;
std::vector<webgpu_dispatch_desc> dispatches;
if (use_fast && is_vec) {
if (is_vec) {
if (use_mmvq) {
ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches);
}
pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx);
} else if (use_fast) {
pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx);
} else {
pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx);
pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx);
}
// Build params
@@ -1479,25 +1510,31 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
};
// Build bind group entries
std::vector<wgpu::BindGroupEntry> entries = {
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
};
std::vector<wgpu::BindGroupEntry> entries = {};
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0));
if (use_mmvq) {
auto & mmvq_qq8_entry = dispatches[0].bind_group_entries[1];
entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), mmvq_qq8_entry.offset,
mmvq_qq8_entry.size));
} else {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1));
}
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
// Calculate workgroup dimensions
uint32_t wg_x = 1;
uint32_t wg_y = 1;
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
if (use_fast && is_vec) {
if (is_vec) {
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
uint32_t batches = dst->ne[2] * dst->ne[3];
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
uint32_t total_wg = output_groups * batches;
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
} else if (use_fast) {
} else {
auto * decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
// Fast-path tiled/subgroup calculations
@@ -1518,15 +1555,13 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
}
uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3];
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
} else { // legacy
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_size = decisions->wg_size;
uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
}
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
dispatches.push_back({
pipeline, std::move(params), std::move(entries), { wg_x, wg_y }
});
return ggml_backend_webgpu_build_multi(ctx, dispatches);
}
static webgpu_encoded_op ggml_webgpu_mul_mat_id_vec(webgpu_context & ctx,
@@ -1956,10 +1991,10 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
std::vector<wgpu::BindGroupEntry> reduce_entries;
if (use_vec_reduce) {
const uint32_t reduce_sg_size = ctx->global_ctx->capabilities.max_subgroup_size;
const uint32_t reduce_wg_size =
std::max(reduce_sg_size, (uint32_t) std::min<uint64_t>(
(uint64_t) nwg * reduce_sg_size,
ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
const uint32_t reduce_wg_size = std::max(
reduce_sg_size,
(uint32_t) std::min<uint64_t>((uint64_t) nwg * reduce_sg_size,
ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx;
reduce_shader_ctx.max_wg_size = reduce_wg_size;
reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx);
@@ -3110,18 +3145,16 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
uint32_t num_batched_kernels = 0;
uint32_t num_inflight_batches = 0;
bool contains_set_rows = false;
bool batch_compute_passes = true;
int num_encoded_ops = 1;
int node_idx = 0;
#ifdef GGML_WEBGPU_GPU_PROFILE
ctx->profile_timestamp_query_count = 0;
batch_compute_passes = false;
std::vector<std::string> profile_pipeline_names;
#endif
ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder();
if (batch_compute_passes) {
if (ctx->batch_compute_passes) {
ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass();
}
@@ -3148,7 +3181,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
// reset state for next batch
ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder();
if (batch_compute_passes) {
if (ctx->batch_compute_passes) {
ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass();
}
ctx->param_arena.reset();
@@ -3548,8 +3581,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
const uint32_t kv_tile = decisions.kv_tile;
const uint32_t vec_nwg_cap = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
uint32_t nwg = 1u;
const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile);
uint32_t nwg = 1u;
const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile);
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
nwg <<= 1;
}
@@ -3582,6 +3615,22 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
}
}
break;
case GGML_OP_MUL_MAT:
{
const ggml_tensor * src0 = tensor->src[0];
const ggml_tensor * src1 = tensor->src[1];
bool use_mmvq =
ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product,
ctx->webgpu_global_ctx->vendor);
if (use_mmvq) {
const size_t q8_src1_size =
src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
res = ROUNDUP_POW2(res + q8_src1_size +
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
WEBGPU_STORAGE_BUF_BINDING_MULT);
}
}
break;
case GGML_OP_MUL_MAT_ID:
{
const ggml_tensor * src0 = tensor->src[0];
@@ -3707,12 +3756,16 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
ctx->webgpu_global_ctx->adapter.GetInfo(&info);
ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size();
ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches();
ctx->webgpu_global_ctx->vendor = info.vendor;
wgpu::SupportedFeatures features;
ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
// we require f16 support
GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
ctx->webgpu_global_ctx->capabilities.supports_subgroups =
ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);
// for dot4I8packed
ctx->webgpu_global_ctx->capabilities.supports_dot_product = ctx->webgpu_global_ctx->instance.HasWGSLLanguageFeature(
wgpu::WGSLLanguageFeatureName::Packed4x8IntegerDotProduct);
bool valid_subgroup_matrix_config = false;
#ifndef __EMSCRIPTEN__
@@ -3839,6 +3892,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
#ifdef GGML_WEBGPU_GPU_PROFILE
webgpu_ctx->batch_compute_passes = false;
ggml_webgpu_create_buffer(
webgpu_ctx->global_ctx->device, webgpu_ctx->profile_timestamp_dev_buf, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, "profile_timestamp_dev_buf");

View File

@@ -95,11 +95,10 @@ struct q5_1 {
};
#endif
#ifdef Q8_1_T
struct q8_1 {
d: f16,
m: f16,
s: f16, // d * sum(qs[i])
qs: array<u32, 8>
};
#endif

View File

@@ -1,747 +0,0 @@
enable f16;
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"
// BLOCK_SIZE: number of src1 elements covered by one src0 block for the selected type.
// It matches the per-block src1 stride used by the multiply_add variants in this file
// (1 for plain floats, 32 for the 32-element quant blocks, 256 for the 256-element ones).
// NOTE(review): no branch here defines BLOCK_SIZE for MXFP4 - confirm that type never
// reaches this shader or is handled elsewhere.
#ifdef FLOAT
const BLOCK_SIZE = 1u;
#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL)
const BLOCK_SIZE = 32u;
#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS)
const BLOCK_SIZE = 256u;
#endif
#ifdef FLOAT
// FLOAT: product of a single src0 element and the matching src1 element.
// `offset` is the element index relative to both base indices.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let lhs = f32(src0[src0_idx_base + offset]);
    let rhs = f32(src1[src1_idx_base + offset]);
    return lhs * rhs;
}
#endif
#ifdef Q4_0
// Q4_0: dot product of one 18-byte src0 block with the 32 src1 values it covers.
// Bytes 0-1 hold the f16 scale; bytes 2-17 hold 16 bytes of packed 4-bit values.
// The low nibble of byte i maps to element i, the high nibble to element i + 16;
// each value is (nibble - 8) * scale.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_start = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
    let scale = load_f16_as_f32_at_src0(block_start);
    let src1_block = src1_idx_base + offset * 32;
    var acc: f32 = 0.0;
    for (var word: u32 = 0; word < 4; word++) {
        let packed = load_u32_at_src0(block_start + 2 + word * 4);
        for (var b: u32 = 0; b < 4; b++) {
            let nibbles = get_byte(packed, b);
            let hi = (f32((nibbles >> 4) & 0xF) - 8.0f) * scale;
            let lo = (f32(nibbles & 0xF) - 8.0f) * scale;
            let idx = src1_block + word * 4 + b;
            acc += lo * f32(src1[idx]);
            acc += hi * f32(src1[idx + 16]);
        }
    }
    return acc;
}
#endif
#ifdef Q4_1
// Q4_1: dot product of one src0 block with the 32 src1 values it covers.
// Each 4-bit value dequantizes to nibble * d + m. The low nibble of byte i maps to
// element i, the high nibble to element i + 16.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    let bias = f32(blk.m);
    let src1_block = src1_idx_base + offset * 32;
    var acc: f32 = 0.0;
    for (var word: u32 = 0; word < 4; word++) {
        let packed = blk.qs[word];
        for (var b: u32 = 0; b < 4; b++) {
            let nibbles = get_byte(packed, b);
            let hi = f32((nibbles >> 4) & 0xF) * scale + bias;
            let lo = f32(nibbles & 0xF) * scale + bias;
            let idx = src1_block + word * 4 + b;
            acc += lo * f32(src1[idx]);
            acc += hi * f32(src1[idx + 16]);
        }
    }
    return acc;
}
#endif
#ifdef Q5_0
// Q5_0: dot product of one 22-byte src0 block with the 32 src1 values it covers.
// Layout read here: bytes 0-1 f16 scale d, bytes 2-5 qh (one extra high bit per
// element), bytes 6-21 packed 4-bit values. Each element is
// ((nibble | high_bit << 4) - 16) * d.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes
let d = load_f16_as_f32_at_src0(block_byte_base);
var sum: f32 = 0.0;
let qh_packed = load_u32_at_src0(block_byte_base + 2);
for (var j: u32 = 0; j < 4; j++) {
let q_byte_offset = block_byte_base + 6 + j * 4;
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
// With i = j * 4 + k: qh bit (i + 16) is the 5th bit of the high-nibble element
// (shifting by i + 12 lands it on bit 4, selected by the 0x10 mask).
let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
// qh bit i is the 5th bit of the low-nibble element.
let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;
let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
// Low-nibble element pairs with src1[i], high-nibble element with src1[i + 16].
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
sum += q_lo * f32(src1[src1_offset]);
sum += q_hi * f32(src1[src1_offset + 16]);
}
}
return sum;
}
#endif
#ifdef Q5_1
// Q5_1: dot product of one src0 block with the 32 src1 values it covers.
// Each element is a 5-bit value (4 bits from qs, 1 bit from qh) dequantized
// as q * d + m.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q5_1 = src0[src0_idx_base + offset];
let d = f32(block_q5_1.d);
let m = f32(block_q5_1.m);
var sum: f32 = 0.0;
for (var j: u32 = 0; j < 4; j++) {
let q_packed = block_q5_1.qs[j];
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
// With i = j * 4 + k: qh bit (i + 16) extends the high-nibble element,
// qh bit i extends the low-nibble element (both moved to bit 4 via the 0x10 mask).
let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10;
let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m;
let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10;
let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m;
// Low-nibble element pairs with src1[i], high-nibble element with src1[i + 16].
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
sum += q_lo * f32(src1[src1_offset]);
sum += q_hi * f32(src1[src1_offset + 16]);
}
}
return sum;
}
#endif
#ifdef Q8_0
// Q8_0: dot product of one 34-byte src0 block with the 32 src1 values it covers.
// Bytes 0-1 hold the f16 scale; bytes 2-33 hold 32 bytes read through get_byte_i32,
// each multiplied by the scale.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_start = (src0_idx_base + offset) * 34; // Block stride: 34 bytes
    let scale = load_f16_as_f32_at_src0(block_start);
    let src1_block = src1_idx_base + offset * 32;
    var acc: f32 = 0.0;
    for (var word: u32 = 0; word < 8; word++) {
        let packed = load_u32_at_src0(block_start + 2 + word * 4);
        for (var b: u32 = 0u; b < 4u; b++) {
            let weight = f32(get_byte_i32(packed, b)) * scale;
            acc += weight * f32(src1[src1_block + word * 4 + b]);
        }
    }
    return acc;
}
#endif
#ifdef Q8_1
// Q8_1: dot product of one src0 block with the 32 src1 values it covers.
// Each weight dequantizes to d * q. The block's second f16 field stores
// d * sum(qs[i]) (a precomputed term for integer dot products), not a per-element
// offset, so it is deliberately NOT added to the dequantized weight here.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_q8_1 = src0[src0_idx_base + offset];
    let d = f32(block_q8_1.d);
    var sum: f32 = 0.0;
    for (var j: u32 = 0; j < 8; j++) {
        let q_packed = block_q8_1.qs[j];
        for (var k: u32 = 0; k < 4; k++) {
            // Byte k of the packed word, read as a signed value via get_byte_i32.
            let q_byte = get_byte_i32(q_packed, k);
            let q_val = f32(q_byte) * d;
            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
            sum += q_val * f32(src1[src1_offset]);
        }
    }
    return sum;
}
#endif
#ifdef Q2_K
// Q2_K: dot product of one src0 super-block with the 256 src1 values it covers.
// 16 blocks of 16 elements each
// Every 16-element sub-block has one scale byte: its low nibble scales d, its high
// nibble scales dmin. Elements are 2-bit values dequantized as q * dl - ml.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
let m = f32(block.dmin);
var sum = 0.0;
var src1_i = src1_idx_base + offset * 256;
// Running index of the current sub-block's scale byte.
var is: u32 = 0;
// 2 halves of the block (128 elements each)
for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
// 4 groups (each group has 2 blocks of 16 elements)
// `shift` selects which 2-bit field of each qs byte is used for this group.
for (var shift: u32 = 0; shift < 8; shift += 2) {
// 2 blocks
for (var k: u32 = 0; k < 32; k += 16) {
let sc = get_byte(block.scales[is / 4], is % 4);
is++;
let dl = d * f32(sc & 0xF);
let ml = m * f32(sc >> 4);
for (var l: u32 = 0u; l < 16; l++) {
let q_idx = q_b_idx + k + l;
let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
let qs_val = (q_byte >> shift) & 3;
// src1 is consumed sequentially, one value per dequantized element.
sum += (f32(qs_val) * dl - ml) * src1[src1_i];
src1_i++;
}
}
}
}
return sum;
}
#endif
#ifdef Q3_K
// Q3_K: dot product of one 110-byte src0 super-block with the 256 src1 values it covers.
// 16 blocks of 16 elements each
// Elements are 2 low bits from qs combined with one bit from hmask: when the hmask bit
// is clear, 4 is subtracted from the 2-bit value. Each 16-element sub-block has a
// 6-bit scale that is applied as (scale - 32) * d.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
// Bytes 108-109: f16 scale 'd'
let d = load_f16_as_f32_at_src0(block_byte_base + 108);
// extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
// and 2-bits from the last 4 bytes
// Bytes 96-107: 12 bytes of scales (3 u32s)
let kmask1: u32 = 0x03030303;
let kmask2: u32 = 0x0f0f0f0f;
var scale_vals: array<u32, 4>;
scale_vals[0] = load_u32_at_src0(block_byte_base + 96);
scale_vals[1] = load_u32_at_src0(block_byte_base + 100);
scale_vals[2] = load_u32_at_src0(block_byte_base + 104);
// Keep the raw third word: the statements below overwrite scale_vals[2] before all of
// its 2-bit fields have been consumed. Statement order matters here.
var tmp: u32 = scale_vals[2];
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
// Bytes 0-31: 32 bytes of hmask (8 u32s)
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
hmask_vals[i] = load_u32_at_src0(block_byte_base + i * 4);
}
// Bytes 32-95: 64 bytes of qs (16 u32s)
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16; i++) {
qs_vals[i] = load_u32_at_src0(block_byte_base + 32 + i * 4);
}
var sum = 0.0;
var src1_i = src1_idx_base + offset * 256;
// Running index of the current sub-block's unpacked scale byte.
var is: u32 = 0;
// Single-bit mask into each hmask byte; advances by one bit per group of 32 elements,
// across both halves (8 groups, 8 bits).
var m: u32 = 1;
// 2 halves of the block (128 elements each)
for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
// 4 groups (each group has 2 blocks of 16 elements)
for (var shift: u32 = 0; shift < 8; shift += 2) {
// 2 blocks
for (var k: u32 = 0; k < 32; k += 16) {
let sc = get_byte(scale_vals[is / 4], is % 4);
is++;
let dl = d * (f32(sc) - 32.0);
for (var l: u32 = 0u; l < 16u; l++) {
let q_idx = q_b_idx + k + l;
// The hmask byte index does not depend on the half; only the bit (m) changes.
let hm_idx = k + l;
let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);
let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);
// Subtract 4 when the high bit is NOT set.
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
let qs_val = (q_byte >> shift) & 3;
sum += ((f32(qs_val) - hm) * dl) * src1[src1_i];
src1_i++;
}
}
m <<= 1;
}
}
return sum;
}
#endif
#ifdef Q4_K
// Q4_K: dot product of one src0 super-block with the 256 src1 values it covers.
// 8 blocks of 32 elements each
// Each sub-block has a (scale, min) pair obtained from get_scale_min; elements are
// 4-bit values dequantized as q * (d * scale) - (dmin * min).
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
let m = f32(block.dmin);
var sum = 0.0;
var src1_i = src1_idx_base + offset * 256;
// Running sub-block index passed to get_scale_min.
var is: u32 = 0;
// 2 blocks each iteration
for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
// shift 0 reads the low nibbles of these 32 bytes, shift 4 the high nibbles.
for (var shift: u32 = 0; shift < 8; shift += 4) {
let scale_min = get_scale_min(is, block.scales);
is++;
// .x is used as the scale factor, .y as the min factor.
let dl = d * scale_min.x;
let ml = m * scale_min.y;
for (var l: u32 = 0; l < 32; l++) {
let q_idx = q_b_idx + l;
let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
let qs_val = (q_byte >> shift) & 0xF;
sum += (f32(qs_val) * dl - ml) * src1[src1_i];
src1_i++;
}
}
}
return sum;
}
#endif
#ifdef Q5_K
// Q5_K: dot product of one src0 super-block with the 256 src1 values it covers.
// 8 blocks of 32 elements each
// Elements are 4 bits from qs plus one bit from qh (worth 16), dequantized as
// q * (d * scale) - (dmin * min) with (scale, min) obtained from get_scale_min.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
let m = f32(block.dmin);
var sum = 0.0;
var src1_i = src1_idx_base + offset * 256;
// Running sub-block index passed to get_scale_min.
var is: u32 = 0;
// Single-bit mask into each qh byte; advances by one bit per 32-element sub-block.
var u: u32 = 1;
// 2 blocks each iteration
for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
// shift 0 reads the low nibbles of these 32 bytes, shift 4 the high nibbles.
for (var shift: u32 = 0; shift < 8; shift += 4) {
let scale_min = get_scale_min(is, block.scales);
is++;
let dl = d * scale_min.x;
let ml = m * scale_min.y;
for (var l: u32 = 0; l < 32; l++) {
let q_idx = q_b_idx + l;
let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
// The same 32 qh bytes serve every sub-block; only the tested bit (u) differs.
let qh_byte = get_byte(block.qh[l / 4], l % 4);
let qs_val = (q_byte >> shift) & 0xF;
let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i];
src1_i++;
}
u <<= 1;
}
}
return sum;
}
#endif
#ifdef Q6_K
// Q6_K: dot product of one 210-byte src0 super-block with the 256 src1 values it covers.
// 16 blocks of 16 elements each
// Elements are 6-bit values (4 bits from ql, 2 bits from qh) offset by -32 and scaled
// by d times a per-sub-block scale byte read through get_byte_i32.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes
// Bytes 208-209: f16 scale 'd'
let d = load_f16_as_f32_at_src0(block_byte_base + 208);
// Bytes 0-127: 128 bytes of ql (32 u32s)
var ql_vals: array<u32, 32>;
for (var i: u32 = 0; i < 32; i++) {
ql_vals[i] = load_u32_at_src0(block_byte_base + i * 4);
}
// Bytes 128-191: 64 bytes of qh (16 u32s)
var qh_vals: array<u32, 16>;
for (var i: u32 = 0; i < 16; i++) {
qh_vals[i] = load_u32_at_src0(block_byte_base + 128 + i * 4);
}
// Bytes 192-207: 16 bytes of scales (4 u32s)
var scale_vals: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
scale_vals[i] = load_u32_at_src0(block_byte_base + 192 + i * 4);
}
var sum = 0.0;
var src1_i = src1_idx_base + offset * 256;
var qh_b_idx: u32 = 0;
var sc_b_idx: u32 = 0;
// Two halves of 128 elements; each half consumes 64 ql bytes, 32 qh bytes, 8 scales.
for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) {
for (var l: u32 = 0; l < 32; l++) {
// ql13_b supplies the low bits of elements l and l+64, ql24_b of l+32 and l+96.
let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4);
let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4);
// One qh byte carries the upper 2 bits of all four elements produced in this iteration.
let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4);
let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0;
let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0;
let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0;
let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0;
// 16 elements share a scale, so l selects scale 0 or 1 within each 32-element run.
let is = l/16;
let is1 = sc_b_idx + is;
let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4);
let is2 = sc_b_idx + is + 2;
let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4);
let is3 = sc_b_idx + is + 4;
let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4);
let is4 = sc_b_idx + is + 6;
let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4);
sum += d * f32(sc1) * q1 * src1[src1_i + l];
sum += d * f32(sc2) * q2 * src1[src1_i + l + 32];
sum += d * f32(sc3) * q3 * src1[src1_i + l + 64];
sum += d * f32(sc4) * q4 * src1[src1_i + l + 96];
}
src1_i += 128;
qh_b_idx += 32;
sc_b_idx += 8;
}
return sum;
}
#endif
#ifdef IQ2_XXS
// IQ2_XXS: dot product of one 66-byte src0 block with the 256 src1 values it covers.
// After the f16 scale at bytes 0-1, the block is read as 8 groups of two u32 words:
// aux0 holds four grid indices (one byte each); aux1 holds four 7-bit sign indices
// plus a 4-bit scale in its top bits. Each grid index expands to 8 values.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var sum = 0.0;
// ib counts 16-bit words; each group of 32 elements uses 4 of them (8 bytes).
for (var ib: u32 = 0; ib < 32; ib += 4) {
let aux0_offset = block_byte_base + 2 + ib * 2;
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
let aux0 = load_u32_at_src0(aux0_offset);
let aux1 = load_u32_at_src0(aux1_offset);
// Group scale from the top 4 bits of aux1.
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
for (var l: u32 = 0; l < 4; l++) {
// Byte offset of this entry's 8 values inside iq2xxs_grid.
let ig = get_byte(aux0, l) * 8;
let is = (aux1 >> (7 * l)) & 127;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
for (var j: u32 = 0; j < 8; j++) {
let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4);
// Negate the value when the sign bit selected by kmask_iq2xs[j] is set.
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
sum += db * f32(g) * m * src1[src1_i];
src1_i++;
}
}
}
return sum;
}
#endif
#ifdef IQ2_XS
// IQ2_XS: dot product of one 74-byte src0 block with the 256 src1 values it covers.
// Layout read here: bytes 0-1 f16 scale, bytes 2-65 thirty-two 16-bit entries
// (9-bit grid index + 7-bit sign index), bytes 66-73 eight scale bytes (two 4-bit
// scales each). Each entry expands to 8 values.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var scale_vals = array<u32, 2>(
load_u32_at_src0(block_byte_base + 66),
load_u32_at_src0(block_byte_base + 70)
);
var sum = 0.0;
// ib counts 16-bit entries; 4 entries (32 elements) share one scale byte.
for (var ib: u32 = 0; ib < 32; ib += 4) {
// Scale byte index is ib / 4, split here into (word, byte-in-word).
let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
// Low nibble scales the first two entries of the group, high nibble the last two.
let db = array<f32, 2>(
d * (0.5 + f32(s & 0xF)) * 0.25,
d * (0.5 + f32(s >> 4)) * 0.25
);
for (var l: u32 = 0; l < 4; l++) {
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
// Only the low 16 bits of the loaded word belong to this entry.
let qs_val = load_u32_at_src0(qs_offset) & 0xFFFF;
let ig = (qs_val & 511) * 8;
let is = qs_val >> 9;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
let dl = db[l/2];
for (var j: u32 = 0; j < 8; j++) {
let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4);
// Negate the value when the sign bit selected by kmask_iq2xs[j] is set.
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
sum += dl * f32(g) * m * src1[src1_i];
src1_i++;
}
}
}
return sum;
}
#endif
#ifdef IQ2_S
// IQ2_S: dot product of one 82-byte src0 block with the 256 src1 values it covers.
// Layout read here: bytes 0-1 f16 scale, bytes 2-65 qs (first 32 bytes: low 8 bits of
// the grid indices, last 32 bytes: sign bytes), bytes 66-73 qh (2 extra index bits
// per entry), bytes 74-81 scale bytes (two 4-bit scales each).
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var qs_vals : array<u32, 16>;
for (var i: u32 = 0; i < 16; i++) {
qs_vals[i] = load_u32_at_src0(block_byte_base + 2 + i * 4);
}
var qh_vals: array<u32, 2>;
qh_vals[0] = load_u32_at_src0(block_byte_base + 66);
qh_vals[1] = load_u32_at_src0(block_byte_base + 70);
var scale_vals: array<u32, 2>;
scale_vals[0] = load_u32_at_src0(block_byte_base + 74);
scale_vals[1] = load_u32_at_src0(block_byte_base + 78);
var sum = 0.0;
// One iteration per group of 32 elements (4 entries of 8 values).
for (var ib: u32 = 0; ib < 8; ib ++) {
let s = get_byte(scale_vals[ib / 4], ib % 4);
// Low nibble scales the first two entries of the group, high nibble the last two.
let db = array<f32, 2>(
d * (0.5 + f32(s & 0xF)) * 0.25,
d * (0.5 + f32(s >> 4)) * 0.25
);
let qs_w = qs_vals[ib];
for (var l: u32 = 0; l < 4; l++) {
// Move bits (2l, 2l+1) of this group's qh byte to bits 8-9 of the grid index.
let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300;
let ig = (get_byte(qs_w, l) | qh_b) * 8;
// Sign bytes live in the second half of qs (words 8..15).
let signs = get_byte(qs_vals[ib + 8], l);
let dl = db[l/2];
for (var j: u32 = 0; j < 8; j++) {
let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4);
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
sum += dl * f32(g) * m * src1[src1_i];
src1_i++;
}
}
}
return sum;
}
#endif
#ifdef IQ3_XXS
// IQ3_XXS: dot product of one 98-byte src0 block with the 256 src1 values it covers.
// Layout read here: bytes 0-1 f16 scale, bytes 2-65 grid-index bytes (each selects a
// 4-value entry of iq3xxs_grid), bytes 66-97 eight u32 words holding four 7-bit sign
// indices plus a 4-bit scale in the top bits.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var sum = 0.0;
// ib advances by 2 per group of 32 elements; (ib + 32) * 2 addresses the scale/sign
// word of that group, which starts 64 bytes after the index bytes.
for (var ib: u32 = 0; ib < 16; ib += 2) {
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
let sc_sign = load_u32_at_src0(sc_sign_offset);
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
for (var l: u32 = 0; l < 4; l++) {
let is = (sc_sign >> (7 * l)) & 127;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
// Two consecutive index bytes: ig1 yields values 0-3, ig2 values 4-7 of this 8-value run.
let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0);
let ig2 = get_byte(ig_val, 1);
for (var j: u32 = 0; j < 4; j++) {
let g1 = get_byte(iq3xxs_grid[ig1], j);
let g2 = get_byte(iq3xxs_grid[ig2], j);
// kmask_iq2xs[0] covers sign bits for values 0-3, kmask_iq2xs[1] for values 4-7.
let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
sum += db * f32(g1) * m1 * src1[src1_i];
sum += db * f32(g2) * m2 * src1[src1_i + 4];
src1_i++;
}
// Skip the 4 values already consumed through the "+ 4" accesses above.
src1_i += 4;
}
}
return sum;
}
#endif
#ifdef IQ3_S
// IQ3_S: dot product of one 110-byte src0 block with the 256 src1 values it covers.
// Layout read here: bytes 0-1 f16 scale, bytes 2-65 low 8 bits of the grid indices,
// bytes 66-73 qh (9th index bit), bytes 74-105 sign bytes, bytes 106-109 scale bytes
// (two 4-bit scales each, applied as d * (1 + 2 * scale)).
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var qh_vals = array<u32, 2>(
load_u32_at_src0(block_byte_base + 66),
load_u32_at_src0(block_byte_base + 70)
);
var sign_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
sign_vals[i] = load_u32_at_src0(block_byte_base + 74 + i * 4);
}
var scale_vals = load_u32_at_src0(block_byte_base + 106);
var sum = 0.0;
// One iteration per scale byte, i.e. per 64 elements (two 32-element halves).
for (var ib: u32 = 0; ib < 4; ib++) {
let s = get_byte(scale_vals, ib);
// Low nibble scales the first 32 elements of the pair, high nibble the second 32.
let db = array<f32, 2>(
d * (1.0 + 2.0 * f32(s & 0xF)),
d * (1.0 + 2.0 * f32(s >> 4))
);
for (var k: u32 = 0; k < 2; k++) {
let dl = db[k];
// One qh byte and one sign word per 32-element half.
let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k);
let sign_w = sign_vals[ib * 2 + k];
for (var l: u32 = 0; l < 4; l++) {
let signs = get_byte(sign_w, l);
let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
// qh bits 2l and 2l+1 become bit 8 of the first and second grid index respectively.
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
for (var j: u32 = 0; j < 4; j++) {
let g1 = get_byte(iq3s_grid[ig1], j);
let g2 = get_byte(iq3s_grid[ig2], j);
let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
sum += dl * f32(g1) * m1 * src1[src1_i];
sum += dl * f32(g2) * m2 * src1[src1_i + 4];
src1_i++;
}
// Skip the 4 values already consumed through the "+ 4" accesses above.
src1_i += 4;
}
}
}
return sum;
}
#endif
#ifdef IQ1_S
// IQ1_S: dot product of one 50-byte src0 block with the 256 src1 values it covers.
// Layout read here: bytes 0-1 f16 scale, bytes 2-33 low 8 bits of the grid indices,
// bytes 34-49 eight 16-bit qh words (3 extra index bits per entry, a 3-bit scale and
// a sign bit for the delta). Grid values are signed 2-bit numbers.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var sum = 0.0;
for (var ib: u32 = 0; ib < 8; ib++) {
// Only the low 16 bits of the loaded word belong to this group's qh.
let qh = load_u32_at_src0(block_byte_base + 34 + ib * 2) & 0xFFFF;
// Bits 12-14: group scale, applied as d * (2 * scale + 1).
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
// Bit 15 selects the sign of the constant offset added to every grid value.
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
let qs_w = load_u32_at_src0(block_byte_base + 2 + ib * 4);
for (var l: u32 = 0; l < 4; l++) {
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
for (var j: u32 = 0; j < 8; j++) {
// iq1_grid packs sixteen 2-bit values per u32.
let gw = iq1_grid[(ig + j) / 16];
let g = (gw >> (((ig + j) % 16) * 2)) & 3;
// Sign-extend the 2-bit field via shift-left then arithmetic shift-right.
let gs = bitcast<i32>(g << 30) >> 30;
sum += dl * (f32(gs) + delta) * src1[src1_i];
src1_i++;
}
}
}
return sum;
}
#endif
#ifdef IQ1_M
// IQ1_M: dot product of one src0 block with the 256 src1 values it covers.
// The block has no separate f16 scale field: the 16-bit scale pattern is assembled from
// the top 4 bits of each 16-bit half of the two scale words and reinterpreted as f16.
// The remaining scale bits provide 3-bit sub-scales; qh supplies 3 extra grid-index
// bits and a delta sign bit per 8-value entry.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
// Gather bits 12-15 and 28-31 of scales[0] and scales[1] into one 16-bit value.
let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000);
let d = f32(bitcast<vec2<f16>>(scale).x);
var src1_i = src1_idx_base + offset * 256;
var sum = 0.0;
// One iteration per group of 32 elements (4 entries of 8 values).
for (var ib: u32 = 0; ib < 8; ib++) {
// 16-bit scale word shared by two consecutive groups.
let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF;
// Two 3-bit sub-scales per group: s1 for entries 0-1, s2 for entries 2-3.
let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7;
let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7;
var dl = array<f32, 2>(
d * f32(2 * s1 + 1),
d * f32(2 * s2 + 1)
);
// 16 bits of qh for this group: one nibble per entry.
let qh = block.qh[ib / 2] >> (16 * (ib % 2));
// Grid index = qs byte | (low 3 bits of the entry's qh nibble) << 8.
var idx = array<u32, 4>(
get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700),
get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700),
get_byte(block.qs[ib], 2) | ((qh) & 0x700),
get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700)
);
// The top bit of each qh nibble selects the sign of the constant offset.
var delta = array<f32, 4>(
select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0),
select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0),
select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0),
select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0)
);
for (var l: u32 = 0; l < 4; l++) {
let ig = idx[l] * 8;
for (var j: u32 = 0; j < 8; j++) {
// iq1_grid packs sixteen 2-bit values per u32.
let gw = iq1_grid[(ig + j) / 16];
let g = (gw >> (((ig + j) % 16) * 2)) & 3;
// Sign-extend the 2-bit field via shift-left then arithmetic shift-right.
let gs = bitcast<i32>(g << 30) >> 30;
sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i];
src1_i++;
}
}
}
return sum;
}
#endif
#ifdef IQ4_NL
// IQ4_NL: dot product of one 32-weight block of src0 with the matching 32
// consecutive values of src1.
//   src0_idx_base : index of the first block of the src0 row (in blocks)
//   src1_idx_base : index of the first element of the src1 row (in elements)
//   offset        : block index along K
// Each of the 16 payload bytes carries two 4-bit codes: the low nibble belongs
// to weight j, the high nibble to weight j + 16. Codes are mapped through the
// kvalues_iq4nl lookup table and scaled by the block scale.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_start = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
    let scale = load_f16_as_f32_at_src0(block_start);
    let rhs_base = src1_idx_base + offset * 32;
    var acc = 0.0;
    // Walk the payload one 32-bit word (4 bytes, 8 codes) at a time.
    for (var word: u32 = 0; word < 4; word++) {
        let packed = load_u32_at_src0(block_start + 2 + word * 4);
        for (var lane: u32 = 0; lane < 4; lane++) {
            let code_pair = get_byte(packed, lane);
            let rhs_i = rhs_base + word * 4 + lane;
            acc += scale * f32(kvalues_iq4nl[code_pair & 0xF]) * src1[rhs_i];
            acc += scale * f32(kvalues_iq4nl[code_pair >> 4]) * src1[rhs_i + 16];
        }
    }
    return acc;
}
#endif
#ifdef IQ4_XS
// IQ4_XS: dot product of one 256-weight block of src0 with the matching 256
// consecutive values of src1.
//   src0_idx_base : index of the first block of the src0 row (in blocks)
//   src1_idx_base : index of the first element of the src1 row (in elements)
//   offset        : block index along K
// The block is split into 8 sub-blocks of 32 weights. Each sub-block has a
// 6-bit scale (4 low bits in scales_l, 2 high bits in the upper half of
// d_scales_h) that is biased by 32. Within a sub-block, byte k carries the
// code of weight k in its low nibble and of weight k + 16 in its high nibble.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let super_scale = unpack2x16float(blk.d_scales_h).x;
    let scale_hi_bits = blk.d_scales_h >> 16;
    let rhs_base = src1_idx_base + offset * 256;
    var acc = 0.0;
    for (var sub: u32 = 0; sub < 8; sub++) {
        let lo4 = (get_byte(blk.scales_l, sub / 2) >> (4 * (sub % 2))) & 0xF;
        let hi2 = (scale_hi_bits >> (2 * sub)) & 3;
        let sub_scale = super_scale * (f32(lo4 | (hi2 << 4)) - 32.0);
        let rhs_sub = rhs_base + sub * 32;
        for (var k: u32 = 0; k < 16; k++) {
            let byte_i = sub * 16 + k;
            let code_pair = get_byte(blk.qs[byte_i / 4], byte_i % 4);
            acc += sub_scale * f32(kvalues_iq4nl[code_pair & 0xF]) * src1[rhs_sub + k];
            acc += sub_scale * f32(kvalues_iq4nl[code_pair >> 4]) * src1[rhs_sub + k + 16];
        }
    }
    return acc;
}
#endif
// Uniform parameters of the matrix-multiplication shader. All offsets and
// strides are expressed in elements of the respective buffer (blocks for
// block-quantized src0).
struct MulMatParams {
offset_src0: u32, // in elements/blocks
offset_src1: u32, // in elements/blocks
offset_dst: u32, // in elements/blocks
// Problem size: the shader writes m * n outputs per batch and reduces over k.
m: u32,
n: u32,
k: u32,
// all strides are in elements/blocks
// stride_XY = stride of srcX along dimension Y
stride_01: u32,
stride_11: u32,
stride_02: u32,
stride_12: u32,
stride_03: u32,
stride_13: u32,
// Batch sizes of src0 in dimensions 2 and 3.
bs02: u32,
bs03: u32,
// How many dst batches share one src0 batch in dimensions 2 and 3
// (the dst batch index is divided by these to obtain the src0 batch index).
broadcast2: u32,
broadcast3: u32
};
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
@group(0) @binding(3) var<uniform> params: MulMatParams;
// Entry point: one invocation computes one element of dst.
// Workgroups are linearized over the x and y dispatch dimensions, so the
// global output index is wg_linear * 256 + local_id.x.
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>) {
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
let global_idx = wg_linear * 256u + local_id.x;
// Total number of output elements across all batches.
let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
// Invocations past the end of dst (the dispatch may be rounded up) do nothing.
if (global_idx >= total) {
return;
}
// Decompose the flat index into (batch3, batch2, row, col).
let dst2_stride = params.m * params.n;
let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
let dst3_idx = global_idx / dst3_stride;
let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
let src13_idx = dst3_idx; // src1 is not broadcast
let dst3_rem = global_idx % dst3_stride;
let dst2_idx = dst3_rem / dst2_stride;
let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
let src12_idx = dst2_idx; // src1 is not broadcast
let dst2_rem = dst3_rem % dst2_stride;
// NOTE(review): `row` selects the src1 row (range 0..n) and `col` the src0 row
// (range 0..m), i.e. dst is stored with m as the fastest-varying dimension.
let row = dst2_rem / params.m; // output row
let col = dst2_rem % params.m; // output column
let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01;
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
// Accumulate the dot product block by block along K.
var sum = 0.0;
for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) {
sum += multiply_add(src0_idx_base, src1_idx_base, i);
}
dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
}

View File

@@ -3,10 +3,18 @@ enable subgroups;
#endif
enable f16;
#ifdef MMVQ
requires packed_4x8_integer_dot_product;
#endif
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"
#ifdef MMVQ
#include "mul_mat_vec_q_acc.tmpl"
#else
#include "mul_mat_vec_acc.tmpl"
#endif
struct MulMatParams {
offset_src0: u32,
@@ -28,9 +36,14 @@ struct MulMatParams {
};
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>;
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>;
@group(0) @binding(2) var<storage, read_write> dst: array<f32>;
#ifdef MMVQ
@group(0) @binding(1) var<storage, read_write> src1q: array<q8_1>;
#else
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>;
#endif
@group(0) @binding(2) var<storage, read_write> dst: array<f32>;
// "mul_mat_vec_acc.tmpl" requires params.k, params.m, params.stride_01
@group(0) @binding(3) var<uniform> params: MulMatParams;
@@ -75,10 +88,15 @@ fn main(
let src12_idx = dst2_idx;
let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base;
#ifdef MMVQ
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u);
let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base);
#else
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base);
#endif
#ifdef USE_SUBGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {

View File

@@ -436,7 +436,6 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src
}
#endif
#ifdef MUL_ACC_Q3_K
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 110

View File

@@ -0,0 +1,303 @@
#ifdef U32_DEQUANT_HELPERS
#define SRC0_TYPE u32
// Returns byte `b` of `v` as an unsigned value in 0..255
// (b = 0 is the least significant byte).
fn byte_of(v: u32, b: u32) -> u32 {
    let shift = 8u * b;
    return (v >> shift) & 0xFFu;
}
// Returns byte `b` of `v` interpreted as a signed two's-complement value
// in -128..127 (b = 0 is the least significant byte).
fn sbyte_of(v: u32, b: u32) -> i32 {
    let unsigned_val = i32((v >> (8u * b)) & 0xFFu);
    // Values 128..255 represent -128..-1.
    if (unsigned_val >= 128) {
        return unsigned_val - 256;
    }
    return unsigned_val;
}
#endif
#define SRC0_TYPE SRC0_INNER_TYPE
#define SRC1_TYPE SRC1_INNER_TYPE
#ifdef LEGACY_QUANTS
#define BLOCK_SIZE 32
#define THREADS_PER_BLOCK 4
#elif K_QUANTS
#define BLOCK_SIZE 256
#define THREADS_PER_BLOCK 16
#endif
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
#define Q8_BLOCK_SIZE 32
// Q4_0 weights (src0) against q8_1-quantized activations (src1q).
// A block is 18 bytes: a 2-byte f16 scale followed by 16 bytes of packed
// 4-bit quants. With THREADS_PER_BLOCK = 4 each thread handles one 32-bit word
// of quants, i.e. 8 of the 32 weights.
#ifdef MUL_ACC_Q4_0
#define BLOCK_SIZE_BYTES 18
#define B_DS_TYPE vec2<f32>
// Loads the `inner_id`-th 32-bit word of quants and splits it into
// .x = four low nibbles, .y = four high nibbles (one nibble per byte lane).
fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2<u32> {
let qs_packed = load_u32_at_src0(block_byte_base + 2u + 4u * inner_id);
return vec2<u32>(
qs_packed & 0x0F0F0F0Fu,
(qs_packed >> 4u) & 0x0F0F0F0Fu
);
}
// Fetches the two packed activation words that pair with repack_a's output:
// word `inner_id` for the low nibbles and word `inner_id + 4` for the high nibbles.
fn repack_b_qs(block:u32, inner_id: u32) -> vec2<u32> {
return vec2<u32>(
src1q[block].qs[inner_id],
src1q[block].qs[inner_id + 4u],
);
}
// Returns the `d` and `s` fields of the activation block as (d, s), converted to f32.
fn repack_b_dm(block: u32) -> B_DS_TYPE {
return B_DS_TYPE(
f32(src1q[block].d),
f32(src1q[block].s)
);
}
// Reads the f16 block scale stored at the start of the weight block.
fn get_dm(block_byte_base: u32) -> f32 {
return f32(load_f16_at_src0(block_byte_base));
}
// Turns this thread's integer dot product into its share of the block result.
// The -8 bias of Q4_0 quants is applied through b_ds.y; that term belongs to the
// whole block, so each of the THREADS_PER_BLOCK threads adds an equal fraction.
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
}
#endif
// Q4_1 weights (src0) against q8_1-quantized activations (src1q).
// A block is 20 bytes: two f16 values (read below as dma.x and dma.y) followed
// by 16 bytes of packed 4-bit quants. Each of the THREADS_PER_BLOCK threads
// handles one 32-bit word of quants.
#ifdef MUL_ACC_Q4_1
#define BLOCK_SIZE_BYTES 20
#define B_DS_TYPE vec2<f32>
// Loads the `inner_id`-th 32-bit word of quants (payload starts at byte 4) and
// splits it into .x = four low nibbles, .y = four high nibbles.
fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2<u32> {
let qs_packed = load_u32_at_src0(block_byte_base + 4u + 4u * inner_id);
return vec2<u32>(
qs_packed & 0x0F0F0F0Fu,
(qs_packed >> 4u) & 0x0F0F0F0Fu
);
}
// Fetches the two packed activation words that pair with repack_a's output:
// word `inner_id` for the low nibbles and word `inner_id + 4` for the high nibbles.
fn repack_b_qs(block:u32, inner_id: u32) -> vec2<u32> {
return vec2<u32>(
src1q[block].qs[inner_id],
src1q[block].qs[inner_id + 4u],
);
}
// Returns the `d` and `s` fields of the activation block as (d, s), converted to f32.
fn repack_b_dm(block: u32) -> B_DS_TYPE {
return B_DS_TYPE(
f32(src1q[block].d),
f32(src1q[block].s)
);
}
// Reads the two f16 values at bytes 0 and 2 of the weight block.
fn get_dm(block_byte_base: u32) -> vec2<f32> {
return vec2<f32>(
f32(load_f16_at_src0(block_byte_base)),
f32(load_f16_at_src0(block_byte_base + 2u))
);
}
// Turns this thread's integer dot product into its share of the block result.
// The dma.y * b_ds.y term belongs to the whole block, so each of the
// THREADS_PER_BLOCK threads adds an equal fraction of it.
fn mul_q8_1(row_sum: i32, dma: vec2<f32>, b_ds: B_DS_TYPE) -> f32 {
return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK;
}
#endif
// Q8_0 weights (src0) against q8_1-quantized activations (src1q).
// A block is 34 bytes: a 2-byte f16 scale followed by 32 one-byte quants.
// Each of the THREADS_PER_BLOCK threads handles two consecutive 32-bit words,
// i.e. 8 of the 32 weights.
#ifdef MUL_ACC_Q8_0
#define BLOCK_SIZE_BYTES 34
#define B_DS_TYPE f32
// Loads this thread's two consecutive 32-bit words of quants.
fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2<u32> {
return vec2<u32>(
load_u32_at_src0(block_byte_base + 2u + 4u * (inner_id * 2u)),
load_u32_at_src0(block_byte_base + 2u + 4u * (inner_id * 2u + 1))
);
}
// Fetches the two packed activation words at the same positions as repack_a.
fn repack_b_qs(block:u32, inner_id: u32) -> vec2<u32> {
return vec2<u32>(
src1q[block].qs[inner_id * 2u],
src1q[block].qs[inner_id * 2u + 1],
);
}
// Only the `d` field of the activation block is needed for Q8_0.
fn repack_b_dm(block: u32) -> B_DS_TYPE {
return B_DS_TYPE(src1q[block].d);
}
// Reads the f16 block scale stored at the start of the weight block.
fn get_dm(block_byte_base: u32) -> f32 {
return f32(load_f16_at_src0(block_byte_base));
}
// Scales the integer dot product by the product of both block scales.
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
return f32(row_sum) * (da * b_ds);
}
#endif
#ifdef LEGACY_QUANTS
// Legacy quants: computes one thread's contribution to the dot product of a
// weight block with an activation block.
//   a_byte_base : byte offset of the weight block in src0
//   b_inner_id  : index of this thread inside the block
//   b_repacked  : the two packed activation words for this thread
//   b_ds        : activation block scale data (format-specific)
fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2<u32>, b_ds: B_DS_TYPE) -> f32 {
    let a_packed = repack_a(a_byte_base, b_inner_id);
    // Two 4-way signed 8-bit dot products cover this thread's 8 weights.
    let int_dot = dot4I8Packed(a_packed.x, b_repacked.x) + dot4I8Packed(a_packed.y, b_repacked.y);
    return mul_q8_1(int_dot, get_dm(a_byte_base), b_ds);
}
// Legacy quants: accumulates this thread's partial dot products for up to
// OUTPUTS_PER_WG output rows starting at `row_base`.
//   thread_id         : invocation index inside the workgroup
//   row_base          : first output row handled by the workgroup
//   src0_batch_offset : block offset of the current src0 batch
//   src1q_idx_base    : index of the first activation block of the current vector
// Groups of THREADS_PER_BLOCK threads share one block; the groups stride over
// the K dimension. The activation data is loaded once per block and reused for
// every output row.
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var partial: array<f32, OUTPUTS_PER_WG>;
    let lane = thread_id % THREADS_PER_BLOCK;
    let first_block = thread_id / THREADS_PER_BLOCK;
    let block_step: u32 = WG_SIZE / THREADS_PER_BLOCK;
    let block_count = params.k / BLOCK_SIZE;
    for (var blk = first_block; blk < block_count; blk += block_step) {
        let q8_idx = src1q_idx_base + blk;
        let b_qs = repack_b_qs(q8_idx, lane);
        let b_scale = repack_b_dm(q8_idx);
        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            // Rows past the end of the matrix contribute nothing.
            if (out_row >= params.m) {
                continue;
            }
            let a_base = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            partial[r] += mmvq_dot_product(a_base, lane, b_qs, b_scale);
        }
    }
    return partial;
}
#endif
// Q2_K weights (src0) against q8_1-quantized activations (src1q).
// As read below, a block is 84 bytes: 16 scale/min bytes at offset 0, 64 bytes
// of 2-bit quants at offset 16, and two f16 values at offsets 80 and 82.
// With THREADS_PER_BLOCK = 16 each thread `tid` handles 16 consecutive weights.
#ifdef MUL_ACC_Q2_K
#define BLOCK_SIZE_BYTES 84
#define B_DS_TYPE f32
// Extracts the 16 two-bit quants of thread `tid`, one quant per byte lane of
// the four returned words.
fn repack_a(block_byte_base: u32, tid: u32) -> vec4<u32> {
// Which 128-weight half of the block this thread belongs to.
let ih2 = tid / 8u;
// Which 16-byte half of the 32-byte quant group.
let phase = tid % 2u;
let iq4_idx = 2u * ih2 + phase;
let qs_byte_base = block_byte_base + 16u + 16u * iq4_idx;
// Bit position of the 2-bit field inside each byte: 0, 2, 4 or 6.
let qs_shift = tid & 6u;
return vec4<u32>(
(load_u32_at_src0_aligned(qs_byte_base) >> qs_shift) & 0x03030303u,
(load_u32_at_src0_aligned(qs_byte_base + 4u) >> qs_shift) & 0x03030303u,
(load_u32_at_src0_aligned(qs_byte_base + 8u) >> qs_shift) & 0x03030303u,
(load_u32_at_src0_aligned(qs_byte_base + 12u) >> qs_shift) & 0x03030303u,
);
}
// Fetches the 16 activation quants for thread `tid`: the first or second half
// (4 words each) of the q8_1 block, chosen by the parity of tid.
fn repack_b_qs(q8_block_idx: u32, tid: u32) -> vec4<u32> {
let phase = tid % 2u;
return vec4<u32>(
src1q[q8_block_idx].qs[4u * phase],
src1q[q8_block_idx].qs[4u * phase + 1u],
src1q[q8_block_idx].qs[4u * phase + 2u],
src1q[q8_block_idx].qs[4u * phase + 3u],
);
}
// Only the `d` field of the activation block is needed for Q2_K.
fn repack_b_dm(q8_block_idx: u32) -> B_DS_TYPE {
return B_DS_TYPE(src1q[q8_block_idx].d);
}
// Reads the two f16 values at the end of the block: .x scales the quants,
// .y scales the minimums (see mmvq_dot_product).
fn get_dm(block_byte_base: u32) -> vec2<f32> {
return vec2<f32>(
f32(load_f16_at_src0(block_byte_base + 80u)),
f32(load_f16_at_src0(block_byte_base + 82u)),
);
}
// Reads scale byte `tid` and splits it: .x = low nibble (scale), .y = high nibble (min).
fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
let scale_byte = block_byte_base + tid;
// The low two address bits select the byte inside the loaded word.
let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u);
return vec2<f32>(f32(scale & 0xFu), f32(scale >> 4u));
}
// One thread's contribution: b_d * (d * scale * sum(q_a * q_b) - dmin * min * sum(q_b)).
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
let a_repacked = repack_a(a_byte_base, tid);
let dm = get_dm(a_byte_base);
let scale_min = get_scale_min(a_byte_base, tid);
let scale_q = i32(scale_min.x);
// Replicate the 4-bit min into every byte lane so that a packed dot product
// yields min * sum(q_b).
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
}
#endif
// Q4_K weights (src0) against q8_1-quantized activations (src1q).
// As read below, a block is 144 bytes: two f16 values at offsets 0 and 2,
// 12 bytes of packed 6-bit scales/mins at offset 4, and 128 bytes of 4-bit
// quants at offset 16. With THREADS_PER_BLOCK = 16 each thread `tid` handles
// 16 consecutive weights.
#ifdef MUL_ACC_Q4_K
#define BLOCK_SIZE_BYTES 144
#define B_DS_TYPE vec2<f32>
// Extracts the 16 four-bit quants of thread `tid`, one quant per byte lane of
// the four returned words.
fn repack_a(block_byte_base: u32, tid: u32) -> vec4<u32> {
// 64-weight group (32 quant bytes) this thread belongs to.
let iq4 = tid / 4u;
// First or second 16-byte half of the group's quant bytes.
let phase = tid % 2u;
// 0 = low nibbles (weights 0..31 of the group), 1 = high nibbles (weights 32..63).
let nibble = (tid >> 1u) % 2u;
let q_qs_byte_base = block_byte_base + 16u + 32u * iq4 + 16u * phase;
let qs_shift = 4u * nibble;
return vec4<u32>(
(load_u32_at_src0_aligned(q_qs_byte_base) >> qs_shift) & 0x0F0F0F0Fu,
(load_u32_at_src0_aligned(q_qs_byte_base + 4u) >> qs_shift) & 0x0F0F0F0Fu,
(load_u32_at_src0_aligned(q_qs_byte_base + 8u) >> qs_shift) & 0x0F0F0F0Fu,
(load_u32_at_src0_aligned(q_qs_byte_base + 12u) >> qs_shift) & 0x0F0F0F0Fu,
);
}
// Fetches the 16 activation quants for thread `tid`: the first or second half
// (4 words each) of the q8_1 block, chosen by the parity of tid.
fn repack_b_qs(q8_block_idx: u32, tid: u32) -> vec4<u32> {
let phase = tid % 2u;
return vec4<u32>(
src1q[q8_block_idx].qs[4u * phase],
src1q[q8_block_idx].qs[4u * phase + 1u],
src1q[q8_block_idx].qs[4u * phase + 2u],
src1q[q8_block_idx].qs[4u * phase + 3u],
);
}
// Returns the `d` and `s` fields of the activation block as (d, s), converted to f32.
fn repack_b_dm(q8_block_idx: u32) -> B_DS_TYPE {
return B_DS_TYPE(
f32(src1q[q8_block_idx].d),
f32(src1q[q8_block_idx].s),
);
}
// Reads the two f16 values at the start of the block: .x scales the quants,
// .y scales the minimums (see mmvq_dot_product).
fn get_dm(block_byte_base: u32) -> vec2<f32> {
return vec2<f32>(
f32(load_f16_at_src0(block_byte_base + 0u)),
f32(load_f16_at_src0(block_byte_base + 2u)),
);
}
// Decodes the 6-bit (scale, min) pair of the 32-weight sub-block containing
// thread `tid` from the 12 packed scale bytes.
fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
// Two threads share one 32-weight sub-block, giving sub-block indices 0..7.
let sc_m_idx = tid / 2u;
let scales_byte_base = block_byte_base + 4u;
let scales0_3 = load_u32_at_src0_aligned(scales_byte_base);
let scales4_7 = load_u32_at_src0_aligned(scales_byte_base + 4u);
let scales8_11 = load_u32_at_src0_aligned(scales_byte_base + 8u);
let byte_idx = sc_m_idx & 3u;
let is_high = sc_m_idx >= 4u;
// Sub-blocks 0..3: scale is the low 6 bits of bytes 0..3.
let sc_low = byte_of(scales0_3, byte_idx) & 0x3Fu;
// Sub-blocks 4..7: low 4 bits from bytes 8..11, top 2 bits from the spare
// bits 6..7 of bytes 0..3.
let sc_high = (byte_of(scales8_11, byte_idx) & 0x0Fu) | ((byte_of(scales0_3, byte_idx) & 0xC0u) >> 2u);
let scale = f32(select(sc_low, sc_high, is_high));
// Sub-blocks 0..3: min is the low 6 bits of bytes 4..7.
let mn_low = byte_of(scales4_7, byte_idx) & 0x3Fu;
// Sub-blocks 4..7: low 4 bits from the high nibble of bytes 8..11, top 2 bits
// from the spare bits 6..7 of bytes 4..7.
let mn_high = (byte_of(scales8_11, byte_idx) >> 4u) | ((byte_of(scales4_7, byte_idx) & 0xC0u) >> 2u);
let min_val = f32(select(mn_low, mn_high, is_high));
return vec2<f32>(scale, min_val);
}
// One thread's contribution: b_d * d * scale * sum(q_a * q_b) minus this
// thread's share of dmin * min * b_s.
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
let a_repacked = repack_a(a_byte_base, tid);
let dm = get_dm(a_byte_base);
let scale_min = get_scale_min(a_byte_base, tid);
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
}
#endif
#ifdef K_QUANTS
// K-quants: accumulates this thread's partial dot products for up to
// OUTPUTS_PER_WG output rows starting at `row_base`.
//   thread_id         : invocation index inside the workgroup
//   row_base          : first output row handled by the workgroup
//   src0_batch_offset : block offset of the current src0 batch
//   src1q_idx_base    : index of the first activation block of the current vector
// Groups of THREADS_PER_BLOCK threads share one weight block; every thread
// covers ELEMS_PER_THREAD weights of it and reads the q8_1 activation block
// that contains those elements. The activation data is loaded once per block
// and reused for every output row.
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var partial: array<f32, OUTPUTS_PER_WG>;
    let lane = thread_id % THREADS_PER_BLOCK;
    let first_block = thread_id / THREADS_PER_BLOCK;
    let block_step: u32 = WG_SIZE / THREADS_PER_BLOCK;
    let block_count = params.k / BLOCK_SIZE;
    for (var blk = first_block; blk < block_count; blk += block_step) {
        // Activation block holding the first element this thread works on.
        let q8_idx = src1q_idx_base + (blk * BLOCK_SIZE + ELEMS_PER_THREAD * lane) / Q8_BLOCK_SIZE;
        let b_qs = repack_b_qs(q8_idx, lane);
        let b_scale = repack_b_dm(q8_idx);
        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            // Rows past the end of the matrix contribute nothing.
            if (out_row >= params.m) {
                continue;
            }
            let a_base = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            partial[r] += mmvq_dot_product(a_base, lane, b_qs, b_scale);
        }
    }
    return partial;
}
#endif

View File

@@ -0,0 +1,173 @@
#ifdef USE_SUBGROUP_REDUCTION
enable subgroups;
#endif
enable f16;
requires packed_4x8_integer_dot_product;
#include "common_decls.tmpl"
// Uniform parameters of the q8_1 quantization shader.
struct Params {
// Offset of the first src1 element and strides of dimensions 2 and 3, in
// scalar elements (the shader divides the resulting index by 4 to address vec4s).
offset_src1: u32,
stride_12: u32,
stride_13: u32,
// Row length (ne0) and batch sizes (ne2, ne3) of src1.
ne0: u32,
ne2: u32,
ne3: u32,
};
#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
@group(0) @binding(0) var<storage, read_write> src1: array<SRC1_TYPE>;
@group(0) @binding(1) var<storage, read_write> src1q: array<q8_1>;
@group(0) @binding(2) var<uniform> params: Params;
#ifdef USE_SUBGROUP_REDUCTION
// Butterfly reduction: returns the maximum of `v` over the aligned group of 8
// subgroup lanes that contains the calling invocation. Every lane of the group
// receives the same result.
fn cluster_max_8(v: f32) -> f32 {
    let max_of_2 = max(v, subgroupShuffleXor(v, 1u));
    let max_of_4 = max(max_of_2, subgroupShuffleXor(max_of_2, 2u));
    return max(max_of_4, subgroupShuffleXor(max_of_4, 4u));
}
#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K)
// Butterfly reduction: returns the sum of `v` over the aligned group of 8
// subgroup lanes that contains the calling invocation. Every lane of the group
// receives the same result.
fn cluster_add_i4x8(v: i32) -> i32 {
    let sum_of_2 = v + subgroupShuffleXor(v, 1u);
    let sum_of_4 = sum_of_2 + subgroupShuffleXor(sum_of_2, 2u);
    return sum_of_4 + subgroupShuffleXor(sum_of_4, 4u);
}
#endif
#endif
#ifdef USE_WORKGROUP_REDUCTION
#define CLUSTER_SIZE 8
var<workgroup> partial_amaxs: array<array<f32, CLUSTER_SIZE>, WG_SIZE / CLUSTER_SIZE>;
var<workgroup> partial_sums: array<array<i32, CLUSTER_SIZE>, WG_SIZE / CLUSTER_SIZE>;
#endif
// Entry point: quantizes src1 into q8_1 blocks stored in src1q.
// Every invocation reads one vec4 (4 values). Eight consecutive invocations
// form a cluster that produces one 32-value block: the cluster shares one
// scale d = amax / 127, each invocation writes one packed word of four 8-bit
// quants, and lane 0 of the cluster writes d (and, for the Q4 formats, s).
@compute @workgroup_size(WG_SIZE)
fn main(
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>
) {
let thread_id = local_id.x;
let num_vec4 = params.ne0 / 4u;
// Number of workgroups needed to cover one row of src1.
let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
let total_batches = wg_per_vec * params.ne2 * params.ne3;
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
// Whole surplus workgroups exit here; the condition is the same for every
// invocation of a workgroup, so the barriers below stay in uniform control flow.
if (wg_linear >= total_batches) {
return;
}
// Decompose the linear workgroup index into (dim 3, dim 2, chunk within the row).
let src13_idx = wg_linear / (params.ne2 * wg_per_vec);
let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec;
let src11_wg_idx = wg_linear % wg_per_vec;
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
// src1 is bound as an array of vec4, so convert the scalar index.
// NOTE(review): this assumes the offset and strides are multiples of 4 - confirm with the host code.
let src1_idx_vec4_base = src1_idx_base / 4u;
let blocks_per_row = params.ne0 / 32u;
let blocks_per_wg = (WG_SIZE * 4u) / 32u;
// Output blocks are stored densely: row after row, batch after batch.
let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row;
let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u;
// Position of this invocation's packed word inside its block (0..7).
let qs_idx = thread_id % 8u;
// reduction
var q4 = vec4<f32>(0.0);
var q4_quants = 0u;
var thread_amax = 0.0;
let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id;
// Invocations past the end of the row keep their zero defaults and write nothing.
let is_valid = src11_vec4_idx < num_vec4;
#ifdef USE_SUBGROUP_REDUCTION
var d = 0.0;
if (is_valid) {
q4 = src1[src1_idx_vec4_base + src11_vec4_idx];
let abs_q4 = abs(q4);
thread_amax = max(max(abs_q4[0u], abs_q4[1u]), max(abs_q4[2], abs_q4[3]));
}
// The shuffle-based reduction runs outside the is_valid branch so that all
// invocations take part in it.
d = cluster_max_8(thread_amax) / 127.0;
if (is_valid) {
// Inverse scale; 0 when the whole block is zero to avoid dividing by zero.
let id = select(0.0, 1.0 / d, d > 0.0);
q4_quants = pack4xI8(vec4<i32>(round(q4 * id)));
if (qs_idx == 0u) {
src1q[src1q_idx].d = f16(d);
}
src1q[src1q_idx].qs[qs_idx] = q4_quants;
}
#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K)
// s = d * (sum of all quants of the block); the dot product with 0x01010101
// sums the four signed bytes of this invocation's word.
let q4_quants_sum = dot4I8Packed(q4_quants, 0x01010101u);
let s = f16(d * f32(cluster_add_i4x8(q4_quants_sum)));
if (is_valid) {
if (qs_idx == 0u) {
src1q[src1q_idx].s = s;
}
}
#endif
#endif
#ifdef USE_WORKGROUP_REDUCTION
var d = 0.0;
let cluster_id = thread_id / 8u;
if (is_valid) {
q4 = src1[src1_idx_vec4_base + src11_vec4_idx];
let abs_q4 = abs(q4);
thread_amax = max(max(abs_q4[0], abs_q4[1]), max(abs_q4[2], abs_q4[3]));
// Publish this invocation's maximum for the other lanes of the cluster.
partial_amaxs[cluster_id][qs_idx] = thread_amax;
}
workgroupBarrier();
if (is_valid) {
// Every lane recomputes the cluster maximum from the 8 published values.
let amax = max(
max(
max(partial_amaxs[cluster_id][0], partial_amaxs[cluster_id][1]), max(partial_amaxs[cluster_id][2], partial_amaxs[cluster_id][3])),
max(
max(partial_amaxs[cluster_id][4], partial_amaxs[cluster_id][5]), max(partial_amaxs[cluster_id][6], partial_amaxs[cluster_id][7]))
);
d = amax / 127.0;
// Inverse scale; 0 when the whole block is zero to avoid dividing by zero.
let id = select(0.0f, 1.0f / d, d > 0.0f);
q4_quants = pack4xI8(vec4<i32>(round(q4 * id)));
src1q[src1q_idx].qs[qs_idx] = q4_quants;
if (qs_idx == 0u) {
src1q[src1q_idx].d = f16(d);
}
}
#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K)
// Publish the sum of this invocation's four quants (0 for invalid invocations,
// whose q4_quants is still 0), then let lane 0 add up the cluster.
partial_sums[cluster_id][qs_idx] = dot4I8Packed(q4_quants, 0x01010101u);
workgroupBarrier();
if (is_valid) {
if (qs_idx == 0u) {
let s = d * f32(partial_sums[cluster_id][0] + partial_sums[cluster_id][1] + partial_sums[cluster_id][2] + partial_sums[cluster_id][3]
+ partial_sums[cluster_id][4] + partial_sums[cluster_id][5] + partial_sums[cluster_id][6] + partial_sums[cluster_id][7]);
src1q[src1q_idx].s = f16(s);
}
}
#endif
#endif
}

View File

@@ -505,6 +505,7 @@ class MODEL_ARCH(IntEnum):
LLAMA_EMBED = auto()
MAINCODER = auto()
KIMI_LINEAR = auto()
TALKIE = auto()
class VISION_PROJECTOR_TYPE(IntEnum):
@@ -1021,6 +1022,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLAMA_EMBED: "llama-embed",
MODEL_ARCH.MAINCODER: "maincoder",
MODEL_ARCH.KIMI_LINEAR: "kimi-linear",
MODEL_ARCH.TALKIE: "talkie",
}
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -4013,6 +4015,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
MODEL_ARCH.TALKIE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.LAYER_OUT_SCALE,
],
# TODO
}

View File

@@ -34,6 +34,7 @@ class TensorNameMap:
"encoder", # neobert
"model.transformer.wte", # llada
"embed_tokens", # qwen3-embedding
"model.embed", # talkie
),
# Token type embeddings
@@ -259,6 +260,7 @@ class TensorNameMap:
"model.transformer.blocks.{bid}.q_proj", # llada
"layers.{bid}.self_attn.q_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.q_proj", # nemotron-h
"model.blocks.{bid}.attn.attn_query", # talkie
),
# Attention key
@@ -279,6 +281,7 @@ class TensorNameMap:
"model.transformer.blocks.{bid}.k_proj", # llada
"layers.{bid}.self_attn.k_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.k_proj", # nemotron-h
"model.blocks.{bid}.attn.attn_key", # talkie
),
# Attention value
@@ -298,6 +301,7 @@ class TensorNameMap:
"model.transformer.blocks.{bid}.v_proj", # llada
"layers.{bid}.self_attn.v_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.v_proj", # nemotron-h
"model.blocks.{bid}.attn.attn_value", # talkie
),
# Attention output
@@ -336,6 +340,7 @@ class TensorNameMap:
"layers.{bid}.self_attn.o_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.o_proj", # nemotron-h
"model.layers.{bid}.self_attn.language_expert_dense", # cogvlm
"model.blocks.{bid}.attn.attn_resid", # talkie
),
# Attention output norm
@@ -508,6 +513,7 @@ class TensorNameMap:
"layers.{bid}.mlp.up_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.up_proj", # nemotron-h
"model.layers.{bid}.mlp.language_mlp.up_proj", # cogvlm
"model.blocks.{bid}.mlp.mlp_linear", # talkie
),
MODEL_TENSOR.FFN_UP_EXP: (
@@ -561,6 +567,7 @@ class TensorNameMap:
"model.transformer.blocks.{bid}.ff_proj", # llada
"layers.{bid}.mlp.gate_proj", # qwen3-embedding
"model.layers.{bid}.mlp.language_mlp.gate_proj", # cogvlm
"model.blocks.{bid}.mlp.mlp_gate", # talkie
),
MODEL_TENSOR.FFN_GATE_EXP: (
@@ -636,6 +643,7 @@ class TensorNameMap:
"layers.{bid}.mlp.down_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.down_proj", # nemotron-h
"model.layers.{bid}.mlp.language_mlp.down_proj", # cogvlm
"model.blocks.{bid}.mlp.mlp_resid", # talkie
),
MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -682,6 +690,7 @@ class TensorNameMap:
"model.layers.layers.{bid}.mixer.q_norm", # plamo3
"layers.{bid}.self_attn.q_norm", # qwen3-embedding
"model.layers.{bid}.attention.query_layernorm", # apertus
"model.blocks.{bid}.attn.head_gain.head_g", # talkie
),
MODEL_TENSOR.ATTN_K_NORM: (
@@ -716,6 +725,7 @@ class TensorNameMap:
MODEL_TENSOR.LAYER_OUT_SCALE: (
"model.layers.{bid}.layer_scalar", # gemma4
"model.blocks.{bid}.embed_skip.a_g", # talkie
),
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (

View File

@@ -133,6 +133,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_LLAMA_EMBED, "llama-embed" },
{ LLM_ARCH_MAINCODER, "maincoder" },
{ LLM_ARCH_KIMI_LINEAR, "kimi-linear" },
{ LLM_ARCH_TALKIE, "talkie" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

View File

@@ -137,6 +137,7 @@ enum llm_arch {
LLM_ARCH_LLAMA_EMBED,
LLM_ARCH_MAINCODER,
LLM_ARCH_KIMI_LINEAR,
LLM_ARCH_TALKIE,
LLM_ARCH_UNKNOWN,
};

View File

@@ -44,6 +44,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
return new llama_model_llama_embed(params);
case LLM_ARCH_MAINCODER:
return new llama_model_maincoder(params);
case LLM_ARCH_TALKIE:
return new llama_model_talkie(params);
case LLM_ARCH_DECI:
return new llama_model_deci(params);
case LLM_ARCH_BAICHUAN:
@@ -2353,6 +2355,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_QWEN3NEXT:
case LLM_ARCH_MIMO2:
case LLM_ARCH_STEP35:
case LLM_ARCH_TALKIE:
return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2VL:

View File

@@ -488,7 +488,7 @@ struct llama_layer {
struct ggml_tensor * indexer_attn_k = nullptr;
struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias
// gemma4 layer output scale
// gemma4 layer output scale, reused for talkie embedding skip scale
struct ggml_tensor * out_scale = nullptr;
struct llama_layer_posnet posnet;

View File

@@ -2196,7 +2196,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
} else if (
tokenizer_pre == "gpt-4o" ||
tokenizer_pre == "llama4" ||
tokenizer_pre == "kanana2") {
tokenizer_pre == "kanana2" ||
tokenizer_pre == "talkie") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
clean_spaces = false;
} else if (

View File

@@ -177,9 +177,9 @@ llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_pa
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up_s,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, model.layers[il].ffn_gate_s,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down_s,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
@@ -200,7 +200,11 @@ llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_pa
LLM_FFN_SILU, true,
hparams.expert_weights_scale,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
il);
il,
nullptr, nullptr,
model.layers[il].ffn_up_exps_s,
model.layers[il].ffn_gate_exps_s,
model.layers[il].ffn_down_exps_s);
cb(cur, "ffn_moe_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);

View File

@@ -186,6 +186,19 @@ struct llama_model_maincoder : public llama_model_base {
};
// Model implementation for the "talkie" architecture.
struct llama_model_talkie : public llama_model_base {
// Forwards the model parameters to the common base class.
llama_model_talkie(const struct llama_model_params & params) : llama_model_base(params) {}
// Reads the architecture-specific hyperparameters from the model file.
void load_arch_hparams(llama_model_loader & ml) override;
// Creates the tensors used by this architecture.
void load_arch_tensors(llama_model_loader & ml) override;
// Compute-graph builder; the constructor assembles the whole forward pass.
struct graph : public llm_graph_context {
graph(const llama_model & model, const llm_graph_params & params);
};
// Returns a freshly built graph context for the given parameters.
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
};
struct llama_model_deci : public llama_model_base {
llama_model_deci(const struct llama_model_params & params) : llama_model_base(params) {}
void load_arch_hparams(llama_model_loader & ml) override;

149
src/models/talkie.cpp Normal file
View File

@@ -0,0 +1,149 @@
#include "models.h"
// Reads the talkie-specific hyperparameters (RMS-norm epsilon and logit
// scale) and derives the model size class from the layer count.
void llama_model_talkie::load_arch_hparams(llama_model_loader & ml) {
    ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
    ml.get_key(LLM_KV_LOGIT_SCALE,                 hparams.f_logit_scale);

    // Only the 40-layer configuration maps to a known size class.
    if (hparams.n_layer == 40) {
        type = LLM_TYPE_13B;
    } else {
        type = LLM_TYPE_UNKNOWN;
    }
}
// Creates all tensors of the talkie architecture: token embedding, output
// projection and, per layer, the attention projections, the per-head query
// gain, the gated FFN weights and the scalar embedding-skip scale.
void llama_model_talkie::load_arch_tensors(llama_model_loader &) {
    LLAMA_LOAD_LOCALS;

    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
    output   = create_tensor(tn(LLM_TENSOR_OUTPUT,     "weight"), {n_embd, n_vocab}, 0);

    // width of the concatenated query heads
    const auto n_embd_q = n_embd_head_k * n_head;

    for (int il = 0; il < n_layer; ++il) {
        auto & l = layers[il];

        create_tensor_qkv(l, il, n_embd, n_embd_q, n_embd_gqa, n_embd_gqa, 0);
        l.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), {n_embd_q, n_embd}, 0);

        // one gain value per query head; no k gain
        l.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), {1, n_head}, 0);

        l.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, 0);
        l.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", il), {n_embd, n_ff}, 0);
        l.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), {n_ff, n_embd}, 0);

        // scalar weight of the embedding skip connection
        l.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", il), {1}, 0);
    }
}
// Builds the talkie compute graph for the given parameters and hands
// ownership to the caller.
std::unique_ptr<llm_graph_context> llama_model_talkie::build_arch_graph(const llm_graph_params & params) const {
    std::unique_ptr<llm_graph_context> ctx = std::make_unique<graph>(*this, params);
    return ctx;
}
// Builds the forward graph of the talkie architecture.
//
// Structure, as assembled below:
//   - the token embeddings are RMS-normalized (without weights) and kept as
//     `embd_skip`, which is re-injected after every layer scaled by the
//     per-layer `out_scale` tensor
//   - every layer applies weightless RMS norm -> attention -> residual ->
//     weightless RMS norm -> gated SiLU FFN -> residual -> embedding skip
//   - Q and K are normalized after RoPE; only Q has a learned gain
//   - the final logits are multiplied by hparams.f_logit_scale
//
// The optional per-weight scale tensors (`*_s`) are forwarded for every
// projection so that the up and gate weights are treated like the down and
// attention-output weights.
llama_model_talkie::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
    const int64_t n_embd_head = hparams.n_embd_head_k();

    GGML_ASSERT(n_embd_head == hparams.n_embd_head_v());
    GGML_ASSERT(n_embd_head == n_rot);

    ggml_tensor * cur;
    ggml_tensor * inpL;

    inpL = build_inp_embd(model.tok_embd);

    // the embeddings are normalized before entering the first layer
    inpL = build_norm(inpL, nullptr, nullptr, LLM_NORM_RMS, -1);
    cb(inpL, "inp_norm", -1);

    // normalized embeddings, added back (scaled) after every layer
    ggml_tensor * embd_skip = inpL;

    // inp_pos - contains the positions
    ggml_tensor * inp_pos = build_inp_pos();

    auto * inp_attn = build_attn_inp_kv();

    ggml_tensor * inp_out_ids = build_inp_out_ids();

    const float kq_scale = 1.0f / sqrtf(float(n_embd_head));

    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * inpSA    = inpL;
        ggml_tensor * inp_skip = embd_skip;

        cur = build_norm(inpL, nullptr, nullptr, LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        // self-attention
        {
            auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur,
                    n_embd_head, n_head, n_head_kv, il);

            Qcur = ggml_rope_ext(
                    ctx0, Qcur, inp_pos, nullptr,
                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                    ext_factor, attn_factor, beta_fast, beta_slow);

            Kcur = ggml_rope_ext(
                    ctx0, Kcur, inp_pos, nullptr,
                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                    ext_factor, attn_factor, beta_fast, beta_slow);

            // reference applies qknorm after rope
            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
            cb(Qcur, "Qcur_norm", il);

            // K is normalized without a learned gain
            Kcur = build_norm(Kcur, nullptr, nullptr, LLM_NORM_RMS, il);
            cb(Kcur, "Kcur_norm", il);

            cb(Vcur, "Vcur", il);

            cur = build_attn(inp_attn,
                    model.layers[il].wo, nullptr, model.layers[il].wo_s,
                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
            cb(cur, "attn_out", il);
        }

        // on the last layer keep only the rows whose outputs were requested
        if (il == n_layer - 1 && inp_out_ids) {
            cur      = ggml_get_rows(ctx0, cur,      inp_out_ids);
            inpSA    = ggml_get_rows(ctx0, inpSA,    inp_out_ids);
            inp_skip = ggml_get_rows(ctx0, inp_skip, inp_out_ids);
        }

        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
        cb(ffn_inp, "ffn_inp", il);

        cur = build_norm(ffn_inp, nullptr, nullptr, LLM_NORM_RMS, il);
        cb(cur, "ffn_norm", il);

        // forward the weight scales of all three projections, not only of ffn_down
        cur = build_ffn(cur,
                model.layers[il].ffn_up,   nullptr, model.layers[il].ffn_up_s,
                model.layers[il].ffn_gate, nullptr, model.layers[il].ffn_gate_s,
                model.layers[il].ffn_down, nullptr, model.layers[il].ffn_down_s,
                nullptr,
                LLM_FFN_SILU, LLM_FFN_PAR, il);
        cb(cur, "ffn_out", il);

        cur = ggml_add(ctx0, cur, ffn_inp);

        // embedding skip connection, weighted by the per-layer scalar
        ggml_tensor * skip = ggml_mul(ctx0, inp_skip, model.layers[il].out_scale);
        cb(skip, "embd_skip", il);

        cur = ggml_add(ctx0, cur, skip);

        cur = build_cvec(cur, il);
        cb(cur, "l_out", il);

        // input for next layer
        inpL = cur;
    }

    cur = inpL;

    cur = build_norm(cur, nullptr, nullptr, LLM_NORM_RMS, -1);
    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    // output projection followed by the model's logit scale
    cur = build_lora_mm(model.output, cur);
    cur = ggml_scale(ctx0, cur, hparams.f_logit_scale);
    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}

View File

@@ -21,6 +21,7 @@
#include <ggml-cpp.h>
#include <algorithm>
#include <atomic>
#include <array>
#include <cfloat>
#include <cinttypes>
@@ -33,6 +34,7 @@
#include <future>
#include <fstream>
#include <memory>
#include <mutex>
#include <random>
#include <regex>
#include <set>
@@ -55,33 +57,24 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
{
// parallel initialization
static const size_t n_threads = N_THREADS;
// static RNG initialization (revisit if n_threads stops being constant)
static std::vector<std::default_random_engine> generators = []() {
std::random_device rd;
std::vector<std::default_random_engine> vec;
vec.reserve(n_threads);
//for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed
for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); }
return vec;
}();
auto init_thread = [&](size_t ith, size_t start, size_t end) {
auto init_thread = [&](size_t start, size_t end) {
thread_local std::default_random_engine gen(std::random_device{}());
std::uniform_real_distribution<float> distribution(min, max);
auto & gen = generators[ith];
for (size_t i = start; i < end; i++) {
data[i] = distribution(gen);
}
};
if (n_threads == 1) {
init_thread(0, 0, nels);
init_thread(0, nels);
} else {
std::vector<std::future<void>> tasks;
tasks.reserve(n_threads);
for (size_t i = 0; i < n_threads; i++) {
size_t start = i*nels/n_threads;
size_t end = (i+1)*nels/n_threads;
tasks.push_back(std::async(std::launch::async, init_thread, i, start, end));
tasks.push_back(std::async(std::launch::async, init_thread, start, end));
}
for (auto & t : tasks) {
t.get();
@@ -516,6 +509,25 @@ static bool output_format_from_str(const std::string & s, output_formats & forma
return true;
}
// Returns the current UTC time formatted as "YYYY-MM-DDTHH:MM:SSZ".
// The conversion goes through gmtime_s / gmtime_r, which fill a caller-provided
// struct tm. An empty string is returned if the conversion or the formatting fails.
static std::string test_time_now() {
    const time_t now = time(NULL);
    struct tm utc;
#ifdef _WIN32
    // gmtime_s reports success with a zero return value
    const bool converted = gmtime_s(&utc, &now) == 0;
#else
    // gmtime_r reports failure with a null pointer
    const bool converted = gmtime_r(&now, &utc) != nullptr;
#endif
    if (!converted) {
        return "";
    }
    char stamp[32];
    const size_t written = std::strftime(stamp, sizeof(stamp), "%FT%TZ", &utc);
    // strftime returns 0 when nothing usable was written to the buffer
    return written != 0 ? std::string(stamp) : std::string();
}
// Test result structure for SQL output
struct test_result {
std::string test_time;
@@ -545,11 +557,7 @@ struct test_result {
supported = false;
passed = false;
// Set test time
time_t t = time(NULL);
char buf[32];
std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t));
test_time = buf;
test_time = test_time_now();
// Set build info
build_commit = ggml_commit();
@@ -573,11 +581,7 @@ struct test_result {
n_runs(n_runs),
device_description(device_description),
backend_reg_name(backend_reg_name) {
// Set test time
time_t t = time(NULL);
char buf[32];
std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t));
test_time = buf;
test_time = test_time_now();
// Set build info
build_commit = ggml_commit();
@@ -1110,6 +1114,17 @@ static std::unique_ptr<printer> create_printer(output_formats format) {
GGML_ABORT("invalid output format");
}
static std::mutex g_test_output_mutex;
// Forwards `result` to `output_printer->print_test_result()` while holding
// g_test_output_mutex. Does nothing (and does not take the lock) when
// `output_printer` is null.
static void print_test_result_locked(printer * output_printer, const test_result & result) {
    if (output_printer != nullptr) {
        std::lock_guard<std::mutex> lock(g_test_output_mutex);
        output_printer->print_test_result(result);
    }
}
struct test_case {
virtual ~test_case() {}
@@ -1338,9 +1353,7 @@ struct test_case {
test_result result(ggml_backend_name(backend1), current_op_name, vars(), "test",
false, false, "not supported");
if (output_printer) {
output_printer->print_test_result(result);
}
print_test_result_locked(output_printer, result);
ggml_free(ctx);
return test_status_t::NOT_SUPPORTED;
@@ -1462,9 +1475,7 @@ struct test_case {
test_result result(ggml_backend_name(backend1), current_op_name, vars(), "test", supported, test_passed,
error_msg);
if (output_printer) {
output_printer->print_test_result(result);
}
print_test_result_locked(output_printer, result);
return test_passed ? test_status_t::OK : test_status_t::FAIL;
}
@@ -9493,8 +9504,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_from_file(const c
return test_cases;
}
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_names_filter, const char * params_filter,
printer * output_printer, const char * test_file_path) {
static bool test_backend(ggml_backend_t backend, ggml_backend_dev_t dev, test_mode mode, const char * op_names_filter, const char * params_filter,
printer * output_printer, const char * test_file_path, int parallel_workers) {
auto filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const char * params_filter) {
if (params_filter == nullptr) {
return;
@@ -9547,21 +9558,90 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
set_use_ref(backend_cpu, true);
}
size_t n_ok = 0;
size_t tests_run = 0;
std::atomic<size_t> n_ok = 0;
std::atomic<size_t> tests_run = 0;
std::vector<std::string> failed_tests;
for (auto & test : test_cases) {
test_status_t status = test->eval(backend, backend_cpu, op_names_filter, output_printer);
if (status == test_status_t::SKIPPED || status == test_status_t::NOT_SUPPORTED) {
continue;
std::mutex failed_tests_mutex;
// Each worker grabs a chunk of cases at a time. The chunk shrinks as we
// run out of work so that a few slow tests at the tail get spread across
// workers instead of landing on one unlucky thread.
constexpr size_t MAX_TESTS_PER_ITER = 100;
std::atomic<size_t> test_idx = 0;
// Claims the next range of test cases for the calling worker.
// On return, [my_begin, my_end) is the claimed index range; my_begin >= test_cases.size()
// means there is no work left.
const auto & next_chunk = [&](size_t & my_begin, size_t & my_end) {
    // parallel_workers <= 1 selects the serial path below, so values < 1 must not reach the
    // division: 0 would divide by zero and a negative int would convert to a huge size_t.
    const size_t n_workers = parallel_workers > 1 ? (size_t) parallel_workers : 1;
    // The relaxed load is only used to size the chunk; the fetch_add below is what
    // actually reserves the range.
    const size_t cur = test_idx.load(std::memory_order_relaxed);
    const size_t remaining = cur < test_cases.size() ? test_cases.size() - cur : 0;
    // At least 1 so the index always advances, at most MAX_TESTS_PER_ITER.
    const size_t chunk = std::max<size_t>(1, std::min<size_t>(MAX_TESTS_PER_ITER, remaining / n_workers));
    my_begin = test_idx.fetch_add(chunk);
    my_end = std::min(my_begin + chunk, test_cases.size());
};
const auto & run_tests = [&](ggml_backend_t b, ggml_backend_t b_cpu) {
size_t my_begin, my_end;
next_chunk(my_begin, my_end);
while (my_begin < test_cases.size()) {
for (size_t i = my_begin; i < my_end; ++i) {
auto & test = test_cases[i];
test_status_t status = test->eval(b, b_cpu, op_names_filter, output_printer);
if (status == test_status_t::SKIPPED || status == test_status_t::NOT_SUPPORTED) {
continue;
}
tests_run++;
if (status == test_status_t::OK) {
n_ok++;
} else if (status == test_status_t::FAIL) {
std::lock_guard<std::mutex> guard(failed_tests_mutex);
failed_tests.push_back(test->current_op_name + "(" + test->vars() + ")");
}
}
next_chunk(my_begin, my_end);
}
tests_run++;
if (status == test_status_t::OK) {
n_ok++;
} else if (status == test_status_t::FAIL) {
failed_tests.push_back(test->current_op_name + "(" + test->vars() + ")");
};
if (parallel_workers <= 1) {
// Reuse the outer backend / backend_cpu so we don't pay an
// extra CPU backend init.
run_tests(backend, backend_cpu);
} else {
std::atomic<size_t> workers_started = 0;
// Body of one parallel worker thread: initializes its own device backend and its own
// CPU backend, runs tests through run_tests, then frees both backends.
// A worker that fails to initialize either backend returns without incrementing
// workers_started.
const auto & eval_worker = [&]() {
ggml_backend_t b = ggml_backend_dev_init(dev, NULL);
if (b == NULL) {
return;
}
ggml_backend_t b_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
if (b_cpu == NULL) {
// release the device backend that was already created
ggml_backend_free(b);
return;
}
// NOTE(review): set_use_ref looks like an optional function pointer resolved earlier in
// test_backend (it is also applied to the outer backend_cpu) - confirm.
if (set_use_ref) {
set_use_ref(b_cpu, true);
}
// counted only once both backends are ready
workers_started++;
run_tests(b, b_cpu);
ggml_backend_free(b_cpu);
ggml_backend_free(b);
};
std::vector<std::thread> threads;
threads.reserve(parallel_workers);
for (int i = 0; i < parallel_workers; ++i) {
threads.emplace_back(eval_worker);
}
for (auto & t : threads) {
t.join();
}
if (workers_started == 0 && !test_cases.empty()) {
ggml_backend_free(backend_cpu);
return false;
}
}
output_printer->print_summary(test_summary_info(n_ok, tests_run, false));
output_printer->print_failed_tests(failed_tests);
@@ -9709,7 +9789,7 @@ static void show_test_coverage() {
static void usage(char ** argv) {
printf("Usage: %s [mode] [-o <op,..>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>] [--list-ops]", argv[0]);
printf(" [--show-coverage] [--test-file <path>]\n");
printf(" [--show-coverage] [--test-file <path>] [-j <n>]\n");
printf(" valid modes:\n");
printf(" - test (default, compare with CPU backend for correctness)\n");
printf(" - grad (compare gradients from backpropagation with method of finite differences)\n");
@@ -9721,6 +9801,7 @@ static void usage(char ** argv) {
printf(" --list-ops lists all available GGML operations\n");
printf(" --show-coverage shows test coverage\n");
printf(" --test-file reads test operators from a test file generated by llama-export-graph-ops\n");
printf(" -j <n> runs tests using <n> parallel worker threads (default: 1, test mode only)\n");
}
int main(int argc, char ** argv) {
@@ -9730,6 +9811,7 @@ int main(int argc, char ** argv) {
const char * backend_filter = nullptr;
const char * params_filter = nullptr;
const char * test_file_path = nullptr;
int parallel_workers = 1;
for (int i = 1; i < argc; i++) {
if (strcmp(argv[i], "test") == 0) {
@@ -9784,6 +9866,17 @@ int main(int argc, char ** argv) {
usage(argv);
return 1;
}
} else if (strcmp(argv[i], "-j") == 0) {
if (i + 1 < argc) {
parallel_workers = atoi(argv[++i]);
if (parallel_workers < 1) {
usage(argv);
return 1;
}
} else {
usage(argv);
return 1;
}
} else {
usage(argv);
return 1;
@@ -9836,7 +9929,7 @@ int main(int argc, char ** argv) {
false, "", ggml_backend_dev_description(dev),
total / 1024 / 1024, free / 1024 / 1024, true));
bool ok = test_backend(backend, mode, op_names_filter, params_filter, output_printer.get(), test_file_path);
bool ok = test_backend(backend, dev, mode, op_names_filter, params_filter, output_printer.get(), test_file_path, parallel_workers);
if (ok) {
n_ok++;