mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-05-28 17:27:26 +03:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dbe9c0c8ce | ||
|
|
6fe90deffa | ||
|
|
581d020b12 | ||
|
|
7623de11d9 | ||
|
|
c9d98295a3 | ||
|
|
1506d39e76 | ||
|
|
54121f7325 | ||
|
|
192d8ae8b8 |
@@ -74,6 +74,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"Gemma3nForCausalLM": "gemma",
|
||||
"Gemma3nForConditionalGeneration": "gemma",
|
||||
"Gemma4ForConditionalGeneration": "gemma",
|
||||
"Gemma4ForCausalLM": "gemma",
|
||||
"GemmaForCausalLM": "gemma",
|
||||
"Glm4ForCausalLM": "glm",
|
||||
"Glm4MoeForCausalLM": "glm",
|
||||
@@ -215,6 +216,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"T5EncoderModel": "t5",
|
||||
"T5ForConditionalGeneration": "t5",
|
||||
"T5WithLMHeadModel": "t5",
|
||||
"TalkieForCausalLM": "talkie",
|
||||
"UMT5ForConditionalGeneration": "t5",
|
||||
"UMT5Model": "t5",
|
||||
"UltravoxModel": "ultravox",
|
||||
|
||||
@@ -1622,6 +1622,9 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "62f6fb0a6fd5098caeabb19b07a5c1099cafc8b9c40eab6ea89ece4ec02fbc57":
|
||||
# ref: https://huggingface.co/sarvamai/sarvam-30b
|
||||
res = "sarvam-moe"
|
||||
if chkhsh == "f728162c1315c26e40249849799b4ba3fe584c32084b4795b03eb295e63cb5af":
|
||||
# ref: https://huggingface.co/lewtun/talkie-1930-13b-it-hf
|
||||
res = "talkie"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
|
||||
@@ -614,7 +614,7 @@ class Gemma3NModel(Gemma3Model):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Gemma4ForConditionalGeneration")
|
||||
@ModelBase.register("Gemma4ForConditionalGeneration", "Gemma4ForCausalLM")
|
||||
class Gemma4Model(Gemma3Model):
|
||||
model_arch = gguf.MODEL_ARCH.GEMMA4
|
||||
|
||||
|
||||
53
conversion/talkie.py
Normal file
53
conversion/talkie.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch import Tensor
|
||||
|
||||
from .base import LazyTorchTensor, ModelBase, TextModel, gguf
|
||||
|
||||
|
||||
@ModelBase.register("TalkieForCausalLM")
|
||||
class TalkieModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.TALKIE
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
# Talkie used F.rms_norm without an explicit eps
|
||||
self.gguf_writer.add_layer_norm_rms_eps(torch.finfo(torch.float32).eps)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
prefix = f"model.blocks.{bid}." if bid is not None else ""
|
||||
suffix = name.removeprefix(prefix)
|
||||
|
||||
if suffix == "attn_gain.a_g":
|
||||
yield self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid, ".scale"), data_torch
|
||||
return
|
||||
elif suffix == "mlp_gain.a_g":
|
||||
yield self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid, ".scale"), data_torch
|
||||
return
|
||||
elif suffix == "lm_head_gain.w_g":
|
||||
self.gguf_writer.add_logit_scale(LazyTorchTensor.to_eager(data_torch).item())
|
||||
return
|
||||
elif suffix in ("attn.attn_query.weight", "attn.attn_key.weight"):
|
||||
# absorb inverse rope
|
||||
head_dim = self.hparams["head_dim"]
|
||||
shape = data_torch.shape
|
||||
data_torch = torch.reshape(data_torch, (-1, head_dim, shape[-1]))
|
||||
signs = torch.ones((1, head_dim, 1), dtype=data_torch.dtype)
|
||||
signs[:, head_dim // 2 :, :] = -1
|
||||
if self.lazy:
|
||||
signs = LazyTorchTensor.from_eager(signs)
|
||||
# (n_head, head_dim, n_in) -> (n_out, n_in)
|
||||
data_torch = torch.reshape(data_torch * signs, shape)
|
||||
elif suffix == "attn.head_gain.head_g":
|
||||
# allow head gain to broadcast
|
||||
data_torch = data_torch.unsqueeze(-1)
|
||||
|
||||
if not name.endswith(".weight"):
|
||||
name += ".weight"
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
@@ -156,6 +156,7 @@ models = [
|
||||
{"name": "kanana2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/kakaocorp/kanana-2-30b-a3b-instruct-2601", },
|
||||
{"name": "f2llmv2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/codefuse-ai/F2LLM-v2-4B", },
|
||||
{"name": "sarvam-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sarvamai/sarvam-30b", },
|
||||
{"name": "talkie", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/lewtun/talkie-1930-13b-it-hf", },
|
||||
]
|
||||
|
||||
# some models are known to be broken upstream, so we will skip them as exceptions
|
||||
|
||||
@@ -208,6 +208,16 @@ class LoraTorchTensor:
|
||||
def to(self, *args, **kwargs):
|
||||
return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))
|
||||
|
||||
def __mul__(self, other) -> LoraTorchTensor:
|
||||
# Only output-side multiplication for now
|
||||
# W = B @ A, so M_out * W == (M_out * B) @ A
|
||||
if not isinstance(other, (int, float)) and other.shape and other.shape[-1] != 1:
|
||||
raise NotImplementedError
|
||||
return LoraTorchTensor(self._lora_A, self._lora_B * other)
|
||||
|
||||
def __rmul__(self, other) -> LoraTorchTensor:
|
||||
return self * other
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
|
||||
del types # unused
|
||||
|
||||
@@ -743,6 +743,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
||||
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because SYCL Graph is still on development, no better performance. |
|
||||
| GGML_SYCL_ENABLE_LEVEL_ZERO | 1 (default) or 0 | Use Level Zero API for device memory allocation instead of SYCL. Reduces system RAM usage on Intel dGPUs by avoiding DMA-buf/TTM host memory staging. Requires GGML_SYCL_SUPPORT_LEVEL_ZERO=ON at build time. |
|
||||
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
|
||||
| GGML_SYCL_ENABLE_VMM | 0 or 1 (default) | Enable the virtual-memory device pool. |
|
||||
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
|
||||
| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Allow SYCL/Unified Runtime Level Zero device allocations larger than 4 GiB. llama.cpp's direct Level Zero allocation path requests the relaxed maximum-size limit itself when GGML_SYCL_ENABLE_LEVEL_ZERO=1. |
|
||||
|
||||
@@ -753,6 +754,7 @@ Pass these via `CXXFLAGS` or add a one-off `#define` to enable a flag on the spo
|
||||
| Name | Function |
|
||||
|-----------------|----------------------------------------------------------------------------------|
|
||||
| DEBUG_SYCL_POOL | Enable device memory pool logging on teardown. Useful for profiling allocations. |
|
||||
| DEBUG_SYCL_MALLOC | Enable verbose per-call logging of device pool alloc/free operations. |
|
||||
|
||||
## Design Rule
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ __global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows,
|
||||
float reg[el_w];
|
||||
const int lane = threadIdx.x;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < el_w; ++i) {
|
||||
reg[i] = src[i * warp_size + lane] * scale;
|
||||
@@ -57,10 +58,11 @@ __global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
bool ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
GGML_ASSERT(ggml_are_same_shape(src, dst));
|
||||
GGML_ASSERT(ggml_is_contiguous(src));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
if (!ggml_is_contiguous(src) || !ggml_is_contiguous(dst)) {
|
||||
return false;
|
||||
}
|
||||
const int n = src->ne[0];
|
||||
const int64_t rows = ggml_nrows(src);
|
||||
|
||||
@@ -68,7 +70,6 @@ void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src,
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
|
||||
GGML_ASSERT(n % warp_size == 0);
|
||||
const int rows_per_block = 4;
|
||||
|
||||
const int64_t num_blocks = (rows + rows_per_block - 1) / rows_per_block;
|
||||
@@ -83,26 +84,18 @@ void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src,
|
||||
|
||||
switch (n) {
|
||||
case 64:
|
||||
{
|
||||
ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale);
|
||||
break;
|
||||
}
|
||||
ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale);
|
||||
return true;
|
||||
case 128:
|
||||
{
|
||||
ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale);
|
||||
break;
|
||||
}
|
||||
ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale);
|
||||
return true;
|
||||
case 256:
|
||||
{
|
||||
ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale);
|
||||
break;
|
||||
}
|
||||
ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale);
|
||||
return true;
|
||||
case 512:
|
||||
{
|
||||
ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale);
|
||||
break;
|
||||
}
|
||||
ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale);
|
||||
return true;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst);
|
||||
// Returns whether the Fast Walsh-Hadamard transform could be used.
|
||||
bool ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst);
|
||||
|
||||
@@ -2596,9 +2596,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||
bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
|
||||
|
||||
const int32_t hint = ggml_get_op_params_i32(dst, 1);
|
||||
if (hint == GGML_HINT_SRC0_IS_HADAMARD) {
|
||||
GGML_ASSERT(!split);
|
||||
ggml_cuda_op_fwht(ctx, src1, dst);
|
||||
if (hint == GGML_HINT_SRC0_IS_HADAMARD && !split && ggml_cuda_op_fwht(ctx, src1, dst)) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -224,6 +224,7 @@ struct sycl_device_info {
|
||||
int max_wg_per_cu; // max work groups per compute unit - refer to
|
||||
// cudaOccupancyMaxActiveBlocksPerMultiprocessor
|
||||
bool vmm; // virtual memory support
|
||||
size_t vmm_granularity; // granularity of virtual memory
|
||||
size_t total_vram;
|
||||
sycl_hw_info hw_info;
|
||||
optimize_feature opt_feature;
|
||||
@@ -244,6 +245,8 @@ struct ggml_sycl_device_info {
|
||||
|
||||
const ggml_sycl_device_info & ggml_sycl_info();
|
||||
|
||||
static constexpr size_t SYCL_BUFFER_ALIGNMENT = 128;
|
||||
|
||||
struct ggml_sycl_pool {
|
||||
virtual ~ggml_sycl_pool() = default;
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include <cstdlib>
|
||||
#include <float.h>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <vector>
|
||||
@@ -37,6 +38,11 @@
|
||||
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
|
||||
# include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp>
|
||||
#endif
|
||||
#if SYCL_EXT_ONEAPI_VIRTUAL_MEM
|
||||
# include <sycl/ext/oneapi/virtual_mem/physical_mem.hpp>
|
||||
# include <sycl/ext/oneapi/virtual_mem/virtual_mem.hpp>
|
||||
# define GGML_SYCL_USE_VMM
|
||||
#endif
|
||||
#include <sycl/half_type.hpp>
|
||||
|
||||
#include "ggml.h"
|
||||
@@ -70,6 +76,7 @@ int g_ggml_sycl_debug = 0;
|
||||
int g_ggml_sycl_disable_optimize = 0;
|
||||
int g_ggml_sycl_disable_graph = 0;
|
||||
int g_ggml_sycl_disable_dnn = 0;
|
||||
int g_ggml_sycl_enable_vmm = 1;
|
||||
int g_ggml_sycl_prioritize_dmmv = 0;
|
||||
int g_ggml_sycl_use_async_mem_op = 0;
|
||||
int g_ggml_sycl_use_async_mem_op_requested = 1;
|
||||
@@ -96,13 +103,30 @@ static ggml_sycl_device_info ggml_sycl_init() {
|
||||
// GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__);
|
||||
// #endif
|
||||
for (int i = 0; i < info.device_count; ++i) {
|
||||
info.devices[i].vmm = 0;
|
||||
dpct::device_info prop;
|
||||
auto & device = dpct::dev_mgr::instance().get_device(i);
|
||||
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
|
||||
prop, device)));
|
||||
|
||||
#if !defined(GGML_SYCL_USE_VMM)
|
||||
info.devices[i].vmm = 0;
|
||||
#else
|
||||
info.devices[i].vmm = device.has(sycl::aspect::ext_oneapi_virtual_mem);
|
||||
if (info.devices[i].vmm) {
|
||||
// NB: SYCL's get_mem_granularity always returns the _minimum_ granularity,
|
||||
// but the L0 API requires a larger page size for allocs above 2 MiB and
|
||||
// rejects non-multiples with UR_RESULT_ERROR_INVALID_VALUE [sic].
|
||||
// Here we clamp it to 2 MiB for simplicity, but other devices may require
|
||||
// calling zeVirtualMemQueryPageSize or yet unexposed public API.
|
||||
const size_t physical_page = 2ull << 20; // 2 MiB
|
||||
info.devices[i].vmm_granularity = std::max<size_t>(
|
||||
sycl::ext::oneapi::experimental::get_mem_granularity(
|
||||
device, sycl::context(device)),
|
||||
physical_page);
|
||||
}
|
||||
#endif
|
||||
|
||||
info.default_tensor_split[i] = total_vram;
|
||||
total_vram += prop.get_global_mem_size();
|
||||
|
||||
@@ -234,6 +258,7 @@ static void ggml_check_sycl() try {
|
||||
g_ggml_sycl_disable_optimize = get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
|
||||
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
|
||||
g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
|
||||
g_ggml_sycl_enable_vmm = get_sycl_env("GGML_SYCL_ENABLE_VMM", 1);
|
||||
g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
|
||||
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO
|
||||
g_ggml_sycl_enable_level_zero = get_sycl_env("GGML_SYCL_ENABLE_LEVEL_ZERO", ggml_sycl_info().ext_oneapi_level_zero);
|
||||
@@ -275,6 +300,11 @@ static void ggml_check_sycl() try {
|
||||
#else
|
||||
GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO: no\n");
|
||||
#endif
|
||||
#if defined(GGML_SYCL_USE_VMM)
|
||||
GGML_LOG_INFO(" GGML_SYCL_USE_VMM: yes\n");
|
||||
#else
|
||||
GGML_LOG_INFO(" GGML_SYCL_USE_VMM: no\n");
|
||||
#endif
|
||||
|
||||
GGML_LOG_INFO("Running with Environment Variables:\n");
|
||||
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
|
||||
@@ -293,6 +323,11 @@ static void ggml_check_sycl() try {
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
|
||||
#else
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
|
||||
#endif
|
||||
#if defined(GGML_SYCL_USE_VMM)
|
||||
GGML_LOG_INFO(" GGML_SYCL_ENABLE_VMM: %d\n", g_ggml_sycl_enable_vmm);
|
||||
#else
|
||||
GGML_LOG_INFO(" GGML_SYCL_ENABLE_VMM: virtual memory extension is not available\n");
|
||||
#endif
|
||||
GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
|
||||
g_ggml_sycl_use_async_mem_op_requested = get_sycl_env("GGML_SYCL_USE_ASYNC_MEM_OP", 1);
|
||||
@@ -754,7 +789,7 @@ catch (sycl::exception const &exc) {
|
||||
}
|
||||
|
||||
static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||
return 128;
|
||||
return SYCL_BUFFER_ALIGNMENT;
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
@@ -1177,7 +1212,7 @@ static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(gg
|
||||
}
|
||||
|
||||
static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||
return 128;
|
||||
return SYCL_BUFFER_ALIGNMENT;
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
@@ -1462,6 +1497,121 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
|
||||
}
|
||||
};
|
||||
|
||||
// pool with virtual memory management
|
||||
#if defined(GGML_SYCL_USE_VMM)
|
||||
struct ggml_sycl_pool_vmm : public ggml_sycl_pool {
|
||||
static const size_t SYCL_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
|
||||
|
||||
int device;
|
||||
sycl::context ctx;
|
||||
sycl::device dev;
|
||||
|
||||
uintptr_t pool_addr = 0;
|
||||
size_t pool_used = 0;
|
||||
size_t pool_size = 0;
|
||||
size_t granularity;
|
||||
|
||||
// physical_mem owns the commits (unlike cuMemMap)
|
||||
struct mapping {
|
||||
sycl::ext::oneapi::experimental::physical_mem phys;
|
||||
void * map_ptr;
|
||||
};
|
||||
std::vector<mapping> mappings;
|
||||
|
||||
explicit ggml_sycl_pool_vmm(queue_ptr qptr_, int device_) :
|
||||
device(device_),
|
||||
ctx(qptr_->get_context()),
|
||||
dev(qptr_->get_device()),
|
||||
granularity(ggml_sycl_info().devices[device_].vmm_granularity) {
|
||||
}
|
||||
|
||||
~ggml_sycl_pool_vmm() {
|
||||
if (pool_addr == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Per spec, unmap must (a) match the exact (ptr, size) of an earlier
|
||||
// physical_mem::map() call and (b) precede destruction of the
|
||||
// physical_mem objects (their dtors won't unmap).
|
||||
for (auto & m : mappings) {
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(sycl::ext::oneapi::experimental::unmap(
|
||||
m.map_ptr, m.phys.size(), ctx)));
|
||||
}
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(sycl::ext::oneapi::experimental::free_virtual_mem(
|
||||
pool_addr, SYCL_POOL_VMM_MAX_SIZE, ctx)));
|
||||
}
|
||||
|
||||
void * alloc(size_t size, size_t * actual_size) override {
|
||||
// round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
|
||||
size = GGML_PAD(size, SYCL_BUFFER_ALIGNMENT);
|
||||
|
||||
size_t avail = pool_size - pool_used;
|
||||
|
||||
if (size > avail) {
|
||||
// round up to the next multiple of the granularity
|
||||
size_t reserve_size = GGML_PAD(size - avail, granularity);
|
||||
|
||||
GGML_ASSERT(pool_size + reserve_size <= SYCL_POOL_VMM_MAX_SIZE);
|
||||
|
||||
// allocate more physical memory
|
||||
std::optional<sycl::ext::oneapi::experimental::physical_mem> phys;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(phys.emplace(dev, ctx, reserve_size)));
|
||||
|
||||
// reserve virtual address space (if not already reserved)
|
||||
if (pool_addr == 0) {
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(
|
||||
pool_addr = sycl::ext::oneapi::experimental::reserve_virtual_mem(
|
||||
SYCL_POOL_VMM_MAX_SIZE, ctx)));
|
||||
}
|
||||
|
||||
// map at the end of the pool
|
||||
void * map_ptr = nullptr;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(
|
||||
map_ptr = phys->map(pool_addr + pool_size, reserve_size,
|
||||
sycl::ext::oneapi::experimental::address_access_mode::read_write)));
|
||||
|
||||
// stash these so we could unmap this exact range in dtor
|
||||
mappings.push_back({
|
||||
std::move(*phys),
|
||||
map_ptr,
|
||||
});
|
||||
|
||||
// add to the pool
|
||||
pool_size += reserve_size;
|
||||
|
||||
#ifdef DEBUG_SYCL_MALLOC
|
||||
GGML_LOG_INFO("sycl pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
|
||||
device, (unsigned long long) (pool_size/1024/1024),
|
||||
(unsigned long long) (reserve_size/1024/1024));
|
||||
#endif
|
||||
}
|
||||
|
||||
GGML_ASSERT(pool_addr != 0);
|
||||
|
||||
void * ptr = reinterpret_cast<void *>(pool_addr + pool_used);
|
||||
*actual_size = size;
|
||||
pool_used += size;
|
||||
|
||||
#ifdef DEBUG_SYCL_MALLOC
|
||||
GGML_LOG_INFO("sycl pool[%d]: allocated %llu bytes at %p\n", device, (unsigned long long) size, ptr);
|
||||
#endif
|
||||
|
||||
return ptr;
|
||||
}
|
||||
|
||||
void free(void * ptr, size_t size) override {
|
||||
#ifdef DEBUG_SYCL_MALLOC
|
||||
GGML_LOG_INFO("sycl pool[%d]: freed %llu bytes at %p\n", device, (unsigned long long) size, ptr);
|
||||
#endif
|
||||
|
||||
pool_used -= size;
|
||||
|
||||
// all deallocations must be in reverse order of the allocations
|
||||
GGML_ASSERT(ptr == reinterpret_cast<void *>(pool_addr + pool_used));
|
||||
}
|
||||
};
|
||||
#endif // defined(GGML_SYCL_USE_VMM)
|
||||
|
||||
struct ggml_sycl_pool_host : public ggml_sycl_pool {
|
||||
queue_ptr qptr;
|
||||
int device;
|
||||
@@ -1542,20 +1692,19 @@ std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(que
|
||||
}
|
||||
|
||||
std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
|
||||
// TBD: NO VMM support
|
||||
// if (ggml_sycl_info().devices[device].vmm) {
|
||||
// return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_vmm(device));
|
||||
// }
|
||||
return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device));
|
||||
#if defined(GGML_SYCL_USE_VMM)
|
||||
if (g_ggml_sycl_enable_vmm && ggml_sycl_info().devices[device].vmm) {
|
||||
return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_vmm(qptr, device));
|
||||
}
|
||||
#endif // defined(GGML_SYCL_USE_VMM)
|
||||
return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device));
|
||||
}
|
||||
|
||||
|
||||
std::unique_ptr<ggml_sycl_fattn_kv_buffers> ggml_backend_sycl_context::new_fattn_kv_buffers(queue_ptr qptr, int device) {
|
||||
return std::unique_ptr<ggml_sycl_fattn_kv_buffers>(new ggml_sycl_fattn_kv_buffers(qptr, device));
|
||||
}
|
||||
|
||||
// TBD pool with virtual memory management
|
||||
// struct ggml_sycl_pool_vmm : public ggml_sycl_pool
|
||||
|
||||
/// kernels
|
||||
typedef void (*ggml_sycl_op_mul_mat_t)(
|
||||
ggml_backend_sycl_context & ctx,
|
||||
|
||||
@@ -52,7 +52,7 @@
|
||||
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4
|
||||
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4
|
||||
|
||||
// default size for legacy matrix multiplication
|
||||
// default size for reg-tile matrix multiplication
|
||||
#define WEBGPU_MUL_MAT_WG_SIZE 256
|
||||
|
||||
// Same hash combine function as in boost
|
||||
@@ -93,6 +93,8 @@ struct ggml_webgpu_shader_lib_context {
|
||||
uint32_t sg_mat_k = 0;
|
||||
uint32_t min_subgroup_size = 0;
|
||||
uint32_t max_subgroup_size = 0;
|
||||
bool supports_dot_product = false;
|
||||
std::string vendor;
|
||||
};
|
||||
|
||||
struct webgpu_pipeline {
|
||||
@@ -850,31 +852,15 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
|
||||
|
||||
/** Matrix Multiplication **/
|
||||
|
||||
struct ggml_webgpu_legacy_mul_mat_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
|
||||
bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const {
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_legacy_mul_mat_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
ggml_webgpu_hash_combine(seed, key.src1_type);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_mul_mat_vec_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
int vectorized;
|
||||
bool use_mmvq;
|
||||
|
||||
bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized;
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
|
||||
use_mmvq == other.use_mmvq;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -884,6 +870,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
|
||||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
ggml_webgpu_hash_combine(seed, key.src1_type);
|
||||
ggml_webgpu_hash_combine(seed, key.vectorized);
|
||||
ggml_webgpu_hash_combine(seed, key.use_mmvq);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
@@ -894,6 +881,20 @@ struct ggml_webgpu_mul_mat_vec_shader_decisions {
|
||||
uint32_t vec_size;
|
||||
};
|
||||
|
||||
struct ggml_webgpu_quantize_q8_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
|
||||
bool operator==(const ggml_webgpu_quantize_q8_pipeline_key & other) const { return src0_type == other.src0_type; }
|
||||
};
|
||||
|
||||
struct ggml_webgpu_quantize_q8_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_quantize_q8_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_mul_mat_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
@@ -1051,6 +1052,36 @@ struct ggml_webgpu_soft_max_pipeline_key_hash {
|
||||
}
|
||||
};
|
||||
|
||||
/** MMVQ **/
|
||||
|
||||
inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
|
||||
const ggml_tensor * src1,
|
||||
bool supports_dot_product,
|
||||
const std::string & vendor) {
|
||||
if (src1->ne[1] == 1) {
|
||||
bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
|
||||
if (supports_dp4a && supports_dot_product) {
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F32:
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
return src0->ne[0] % 4 == 0;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
class ggml_webgpu_shader_lib {
|
||||
wgpu::Device device;
|
||||
pre_wgsl::Preprocessor preprocessor;
|
||||
@@ -1099,14 +1130,12 @@ class ggml_webgpu_shader_lib {
|
||||
webgpu_pipeline,
|
||||
ggml_webgpu_flash_attn_blk_pipeline_key_hash>
|
||||
flash_attn_blk_pipelines;
|
||||
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
|
||||
webgpu_pipeline,
|
||||
ggml_webgpu_legacy_mul_mat_pipeline_key_hash>
|
||||
mul_mat_legacy_pipelines; // legacy mul_mat (non-subgroup/non-regtile/non-vec)
|
||||
std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash>
|
||||
mul_mat_vec_pipelines; // fast mat-vec (n==1)
|
||||
std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash>
|
||||
mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
|
||||
std::unordered_map<ggml_webgpu_quantize_q8_pipeline_key, webgpu_pipeline, ggml_webgpu_quantize_q8_pipeline_key_hash>
|
||||
quantize_q8_pipelines;
|
||||
std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed
|
||||
std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
|
||||
mul_mat_id_pipelines; // src0_type/src1_type
|
||||
@@ -1631,7 +1660,7 @@ class ggml_webgpu_shader_lib {
|
||||
key.type = context.dst->type;
|
||||
key.d_state = (int) context.src0->ne[0];
|
||||
key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
|
||||
ggml_webgpu_tensor_overlap(context.src1, context.src5);
|
||||
ggml_webgpu_tensor_overlap(context.src1, context.src5);
|
||||
|
||||
auto it = ssm_scan_pipelines.find(key);
|
||||
if (it != ssm_scan_pipelines.end()) {
|
||||
@@ -1744,6 +1773,44 @@ class ggml_webgpu_shader_lib {
|
||||
return pad_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_quantize_q8_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_quantize_q8_pipeline_key key = {};
|
||||
key.src0_type = context.src0->type;
|
||||
|
||||
auto it = quantize_q8_pipelines.find(key);
|
||||
if (it != quantize_q8_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
const char * shader_src = wgsl_quantize_q8;
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "quantize_q8";
|
||||
|
||||
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
|
||||
|
||||
defines.push_back("SRC1_INNER_TYPE=f32");
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
|
||||
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
|
||||
std::string src0_name = src0_traits->type_name;
|
||||
std::string type_upper = src0_name;
|
||||
variant += "_" + src0_name;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
defines.push_back("MUL_ACC_" + type_upper);
|
||||
defines.push_back("Q8_1_T");
|
||||
|
||||
defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
|
||||
variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";
|
||||
|
||||
auto processed = preprocessor.preprocess(shader_src, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = wg_size;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
quantize_q8_pipelines[key] = pipeline;
|
||||
return quantize_q8_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_mul_mat_vec_pipeline_key key = {};
|
||||
key.src0_type = context.src0->type;
|
||||
@@ -1752,6 +1819,8 @@ class ggml_webgpu_shader_lib {
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
key.use_mmvq =
|
||||
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
|
||||
|
||||
auto it = mul_mat_vec_pipelines.find(key);
|
||||
if (it != mul_mat_vec_pipelines.end()) {
|
||||
@@ -1788,6 +1857,19 @@ class ggml_webgpu_shader_lib {
|
||||
defines.push_back("U32_DEQUANT_HELPERS");
|
||||
defines.push_back("SRC0_INNER_TYPE=u32");
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
if (key.use_mmvq) {
|
||||
defines.push_back("LEGACY_QUANTS");
|
||||
}
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
if (key.use_mmvq) {
|
||||
defines.push_back("K_QUANTS");
|
||||
}
|
||||
break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
@@ -1840,6 +1922,11 @@ class ggml_webgpu_shader_lib {
|
||||
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
|
||||
}
|
||||
|
||||
if (key.use_mmvq) {
|
||||
defines.push_back("MMVQ");
|
||||
defines.push_back("Q8_1_T");
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
|
||||
defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
|
||||
@@ -2018,100 +2105,6 @@ class ggml_webgpu_shader_lib {
|
||||
return mul_mat_fast_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_legacy_mul_mat_pipeline_key key = {};
|
||||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
|
||||
auto it = mul_mat_legacy_pipelines.find(key);
|
||||
if (it != mul_mat_legacy_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "mul_mat";
|
||||
|
||||
switch (context.src1->type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("SRC1_TYPE=f32");
|
||||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("SRC1_TYPE=f16");
|
||||
variant += "_f16";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported src1 type for mul_mat legacy shader");
|
||||
}
|
||||
|
||||
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
|
||||
const char * src0_name = src0_traits->type_name;
|
||||
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("SRC0_TYPE=f32");
|
||||
defines.push_back("FLOAT");
|
||||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("SRC0_TYPE=f16");
|
||||
defines.push_back("FLOAT");
|
||||
variant += "_f16";
|
||||
break;
|
||||
default:
|
||||
{
|
||||
std::string type_upper = src0_name;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
{
|
||||
// Quantized types using u32 buffers for portability.
|
||||
defines.push_back("SRC0_TYPE=u32");
|
||||
defines.push_back("U32_DEQUANT_HELPERS");
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
|
||||
}
|
||||
}
|
||||
|
||||
defines.push_back("BYTE_HELPERS");
|
||||
defines.push_back(type_upper + "_T");
|
||||
defines.push_back(type_upper);
|
||||
defines.push_back(type_upper + "_SCALE_MIN");
|
||||
defines.push_back(type_upper + "_TABLES");
|
||||
defines.push_back(type_upper + "_GRID");
|
||||
|
||||
variant += std::string("_") + src0_name;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_mul_mat, defines);
|
||||
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE;
|
||||
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
mul_mat_legacy_pipelines[key] = pipeline;
|
||||
return mul_mat_legacy_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_mul_mat_id_gather_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
auto it = mul_mat_id_gather_pipelines.find(1);
|
||||
if (it != mul_mat_id_gather_pipelines.end()) {
|
||||
|
||||
@@ -181,6 +181,7 @@ struct webgpu_capabilities {
|
||||
wgpu::Limits limits;
|
||||
bool supports_subgroups = false;
|
||||
bool supports_subgroup_matrix = false;
|
||||
bool supports_dot_product = false;
|
||||
|
||||
uint32_t sg_mat_m = 0;
|
||||
uint32_t sg_mat_n = 0;
|
||||
@@ -210,6 +211,8 @@ struct webgpu_global_context_struct {
|
||||
wgpu::Buffer memset_params_buf;
|
||||
webgpu_pipeline memset_pipeline;
|
||||
|
||||
std::string vendor;
|
||||
|
||||
// TODO: We should rework the CPU profiling time handling to make it more useful. ref: https://github.com/ggml-org/llama.cpp/pull/22050
|
||||
#ifdef GGML_WEBGPU_CPU_PROFILE
|
||||
// Profiling: labeled CPU time in ms (total)
|
||||
@@ -259,6 +262,7 @@ struct webgpu_context_struct {
|
||||
wgpu::Buffer set_rows_host_error_buf;
|
||||
wgpu::CommandEncoder active_command_encoder;
|
||||
wgpu::ComputePassEncoder active_compute_pass;
|
||||
bool batch_compute_passes = true;
|
||||
|
||||
size_t memset_bytes_per_thread;
|
||||
|
||||
@@ -590,9 +594,18 @@ static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context &
|
||||
}
|
||||
#else
|
||||
for (size_t i = 0; i < dispatches.size(); i++) {
|
||||
ctx->active_compute_pass.SetPipeline(dispatches[i].pipeline.pipeline);
|
||||
ctx->active_compute_pass.SetBindGroup(0, bind_groups[i]);
|
||||
ctx->active_compute_pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1);
|
||||
if (ctx->batch_compute_passes) {
|
||||
ctx->active_compute_pass.SetPipeline(dispatches[i].pipeline.pipeline);
|
||||
ctx->active_compute_pass.SetBindGroup(0, bind_groups[i]);
|
||||
ctx->active_compute_pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second,
|
||||
1);
|
||||
} else {
|
||||
wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass();
|
||||
pass.SetPipeline(dispatches[i].pipeline.pipeline);
|
||||
pass.SetBindGroup(0, bind_groups[i]);
|
||||
pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1);
|
||||
pass.End();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -1384,6 +1397,58 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx,
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_quantize_q8_dispatch(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst,
|
||||
std::vector<webgpu_dispatch_desc> & dispatches) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
|
||||
shader_lib_ctx.src0 = src0;
|
||||
shader_lib_ctx.src1 = src1;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
|
||||
|
||||
webgpu_pipeline qq8_pipeline = ctx->shader_lib->get_quantize_q8_pipeline(shader_lib_ctx);
|
||||
|
||||
// quantize_q8 pipeline
|
||||
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
|
||||
const size_t q8_src1_align_offset = ROUNDUP_POW2(
|
||||
dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
const size_t q8_src1_binding_size =
|
||||
ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
|
||||
std::vector<uint32_t> q8_params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) src1->ne[2],
|
||||
(uint32_t) src1->ne[3],
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> q8_entries = {
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1),
|
||||
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), q8_src1_align_offset, q8_src1_binding_size)
|
||||
};
|
||||
|
||||
auto q8_decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(qq8_pipeline.context.get());
|
||||
|
||||
uint32_t q8_wg_size = q8_decisions->wg_size;
|
||||
uint32_t q8_wg_x = 1;
|
||||
uint32_t q8_wg_y = 1;
|
||||
const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size;
|
||||
const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y);
|
||||
|
||||
dispatches.push_back({
|
||||
qq8_pipeline, std::move(q8_params), std::move(q8_entries), { q8_wg_x, q8_wg_y }
|
||||
});
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
@@ -1391,47 +1456,9 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
// Determine if this is a mat-vec operation
|
||||
bool is_vec = (dst->ne[1] == 1);
|
||||
|
||||
// Determine if we should use fast path
|
||||
bool use_fast = false;
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16:
|
||||
use_fast = (src0->type == GGML_TYPE_F16);
|
||||
break;
|
||||
case GGML_TYPE_F32:
|
||||
// TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_MXFP4:
|
||||
use_fast = true;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
// use MMVQ path for mat-vec
|
||||
bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product,
|
||||
ctx->global_ctx->vendor);
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
|
||||
@@ -1446,16 +1473,20 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
|
||||
shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size;
|
||||
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
|
||||
shader_lib_ctx.supports_dot_product = ctx->global_ctx->capabilities.supports_dot_product;
|
||||
shader_lib_ctx.vendor = ctx->global_ctx->vendor;
|
||||
|
||||
// Get or create pipeline
|
||||
webgpu_pipeline pipeline;
|
||||
webgpu_pipeline pipeline;
|
||||
std::vector<webgpu_dispatch_desc> dispatches;
|
||||
|
||||
if (use_fast && is_vec) {
|
||||
if (is_vec) {
|
||||
if (use_mmvq) {
|
||||
ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches);
|
||||
}
|
||||
pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx);
|
||||
} else if (use_fast) {
|
||||
pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx);
|
||||
} else {
|
||||
pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx);
|
||||
pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx);
|
||||
}
|
||||
|
||||
// Build params
|
||||
@@ -1479,25 +1510,31 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
};
|
||||
|
||||
// Build bind group entries
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
|
||||
};
|
||||
std::vector<wgpu::BindGroupEntry> entries = {};
|
||||
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0));
|
||||
if (use_mmvq) {
|
||||
auto & mmvq_qq8_entry = dispatches[0].bind_group_entries[1];
|
||||
entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), mmvq_qq8_entry.offset,
|
||||
mmvq_qq8_entry.size));
|
||||
} else {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1));
|
||||
}
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
|
||||
|
||||
// Calculate workgroup dimensions
|
||||
uint32_t wg_x = 1;
|
||||
uint32_t wg_y = 1;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
if (use_fast && is_vec) {
|
||||
if (is_vec) {
|
||||
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t batches = dst->ne[2] * dst->ne[3];
|
||||
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
|
||||
uint32_t total_wg = output_groups * batches;
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
} else if (use_fast) {
|
||||
} else {
|
||||
auto * decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
// Fast-path tiled/subgroup calculations
|
||||
@@ -1518,15 +1555,13 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
}
|
||||
uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3];
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
|
||||
} else { // legacy
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
uint32_t wg_size = decisions->wg_size;
|
||||
uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
}
|
||||
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||
dispatches.push_back({
|
||||
pipeline, std::move(params), std::move(entries), { wg_x, wg_y }
|
||||
});
|
||||
|
||||
return ggml_backend_webgpu_build_multi(ctx, dispatches);
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_mul_mat_id_vec(webgpu_context & ctx,
|
||||
@@ -1956,10 +1991,10 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
std::vector<wgpu::BindGroupEntry> reduce_entries;
|
||||
if (use_vec_reduce) {
|
||||
const uint32_t reduce_sg_size = ctx->global_ctx->capabilities.max_subgroup_size;
|
||||
const uint32_t reduce_wg_size =
|
||||
std::max(reduce_sg_size, (uint32_t) std::min<uint64_t>(
|
||||
(uint64_t) nwg * reduce_sg_size,
|
||||
ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
|
||||
const uint32_t reduce_wg_size = std::max(
|
||||
reduce_sg_size,
|
||||
(uint32_t) std::min<uint64_t>((uint64_t) nwg * reduce_sg_size,
|
||||
ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
|
||||
ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx;
|
||||
reduce_shader_ctx.max_wg_size = reduce_wg_size;
|
||||
reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx);
|
||||
@@ -3110,18 +3145,16 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
|
||||
uint32_t num_batched_kernels = 0;
|
||||
uint32_t num_inflight_batches = 0;
|
||||
bool contains_set_rows = false;
|
||||
bool batch_compute_passes = true;
|
||||
int num_encoded_ops = 1;
|
||||
int node_idx = 0;
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
ctx->profile_timestamp_query_count = 0;
|
||||
batch_compute_passes = false;
|
||||
std::vector<std::string> profile_pipeline_names;
|
||||
#endif
|
||||
|
||||
ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder();
|
||||
if (batch_compute_passes) {
|
||||
if (ctx->batch_compute_passes) {
|
||||
ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass();
|
||||
}
|
||||
|
||||
@@ -3148,7 +3181,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
|
||||
|
||||
// reset state for next batch
|
||||
ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder();
|
||||
if (batch_compute_passes) {
|
||||
if (ctx->batch_compute_passes) {
|
||||
ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass();
|
||||
}
|
||||
ctx->param_arena.reset();
|
||||
@@ -3548,8 +3581,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
||||
const uint32_t kv_tile = decisions.kv_tile;
|
||||
|
||||
const uint32_t vec_nwg_cap = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
|
||||
uint32_t nwg = 1u;
|
||||
const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile);
|
||||
uint32_t nwg = 1u;
|
||||
const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile);
|
||||
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
|
||||
nwg <<= 1;
|
||||
}
|
||||
@@ -3582,6 +3615,22 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
||||
}
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
const ggml_tensor * src0 = tensor->src[0];
|
||||
const ggml_tensor * src1 = tensor->src[1];
|
||||
bool use_mmvq =
|
||||
ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product,
|
||||
ctx->webgpu_global_ctx->vendor);
|
||||
if (use_mmvq) {
|
||||
const size_t q8_src1_size =
|
||||
src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
|
||||
res = ROUNDUP_POW2(res + q8_src1_size +
|
||||
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
const ggml_tensor * src0 = tensor->src[0];
|
||||
@@ -3707,12 +3756,16 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
ctx->webgpu_global_ctx->adapter.GetInfo(&info);
|
||||
ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size();
|
||||
ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches();
|
||||
ctx->webgpu_global_ctx->vendor = info.vendor;
|
||||
wgpu::SupportedFeatures features;
|
||||
ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
|
||||
// we require f16 support
|
||||
GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
|
||||
ctx->webgpu_global_ctx->capabilities.supports_subgroups =
|
||||
ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);
|
||||
// for dot4I8packed
|
||||
ctx->webgpu_global_ctx->capabilities.supports_dot_product = ctx->webgpu_global_ctx->instance.HasWGSLLanguageFeature(
|
||||
wgpu::WGSLLanguageFeatureName::Packed4x8IntegerDotProduct);
|
||||
|
||||
bool valid_subgroup_matrix_config = false;
|
||||
#ifndef __EMSCRIPTEN__
|
||||
@@ -3839,6 +3892,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
webgpu_ctx->batch_compute_passes = false;
|
||||
ggml_webgpu_create_buffer(
|
||||
webgpu_ctx->global_ctx->device, webgpu_ctx->profile_timestamp_dev_buf, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, "profile_timestamp_dev_buf");
|
||||
|
||||
@@ -95,11 +95,10 @@ struct q5_1 {
|
||||
};
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef Q8_1_T
|
||||
struct q8_1 {
|
||||
d: f16,
|
||||
m: f16,
|
||||
s: f16, // d * sum(qs[i])
|
||||
qs: array<u32, 8>
|
||||
};
|
||||
#endif
|
||||
|
||||
@@ -1,747 +0,0 @@
|
||||
enable f16;
|
||||
|
||||
#define DECLARE_BYTE_LOADERS_SRC0
|
||||
#include "common_decls.tmpl"
|
||||
|
||||
|
||||
#ifdef FLOAT
|
||||
const BLOCK_SIZE = 1u;
|
||||
|
||||
#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL)
|
||||
const BLOCK_SIZE = 32u;
|
||||
|
||||
#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS)
|
||||
const BLOCK_SIZE = 256u;
|
||||
#endif
|
||||
|
||||
#ifdef FLOAT
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q4_0
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var sum: f32 = 0.0;
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_byte_offset = block_byte_base + 2 + j * 4;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
|
||||
let q_lo = (f32(q_byte & 0xF) - 8.0f) * d;
|
||||
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
|
||||
sum += q_lo * f32(src1[src1_offset]);
|
||||
sum += q_hi * f32(src1[src1_offset + 16]);
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q4_1
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q4_1 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q4_1.d);
|
||||
let m = f32(block_q4_1.m);
|
||||
var sum: f32 = 0.0;
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = block_q4_1.qs[j];
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
|
||||
let q_lo = f32(q_byte & 0xF) * d + m;
|
||||
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
|
||||
sum += q_lo * f32(src1[src1_offset]);
|
||||
sum += q_hi * f32(src1[src1_offset + 16]);
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q5_0
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var sum: f32 = 0.0;
|
||||
let qh_packed = load_u32_at_src0(block_byte_base + 2);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_byte_offset = block_byte_base + 6 + j * 4;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
|
||||
let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
|
||||
let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
|
||||
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
|
||||
sum += q_lo * f32(src1[src1_offset]);
|
||||
sum += q_hi * f32(src1[src1_offset + 16]);
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q5_1
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q5_1 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q5_1.d);
|
||||
let m = f32(block_q5_1.m);
|
||||
var sum: f32 = 0.0;
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = block_q5_1.qs[j];
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10;
|
||||
let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m;
|
||||
let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m;
|
||||
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
|
||||
sum += q_lo * f32(src1[src1_offset]);
|
||||
sum += q_hi * f32(src1[src1_offset + 16]);
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q8_0
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var sum: f32 = 0.0;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let q_byte_offset = block_byte_base + 2 + j * 4;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
for (var k: u32 = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d;
|
||||
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
|
||||
sum += q_val * f32(src1[src1_offset]);
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q8_1
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q8_1 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q8_1.d);
|
||||
let m = f32(block_q8_1.m);
|
||||
var sum: f32 = 0.0;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let q_packed = block_q8_1.qs[j];
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d + m;
|
||||
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
|
||||
sum += q_val * f32(src1[src1_offset]);
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q2_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let m = f32(block.dmin);
|
||||
var sum = 0.0;
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var is: u32 = 0;
|
||||
// 2 halves of the block (128 elements each)
|
||||
for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
|
||||
// 4 groups (each group has 2 blocks of 16 elements)
|
||||
for (var shift: u32 = 0; shift < 8; shift += 2) {
|
||||
// 2 blocks
|
||||
for (var k: u32 = 0; k < 32; k += 16) {
|
||||
let sc = get_byte(block.scales[is / 4], is % 4);
|
||||
is++;
|
||||
let dl = d * f32(sc & 0xF);
|
||||
let ml = m * f32(sc >> 4);
|
||||
for (var l: u32 = 0u; l < 16; l++) {
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
|
||||
let qs_val = (q_byte >> shift) & 3;
|
||||
sum += (f32(qs_val) * dl - ml) * src1[src1_i];
|
||||
src1_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q3_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
|
||||
|
||||
// Bytes 108-109: f16 scale 'd'
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base + 108);
|
||||
|
||||
// extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
|
||||
// and 2-bits from the last 4 bytes
|
||||
// Bytes 96-107: 12 bytes of scales (3 u32s)
|
||||
let kmask1: u32 = 0x03030303;
|
||||
let kmask2: u32 = 0x0f0f0f0f;
|
||||
var scale_vals: array<u32, 4>;
|
||||
scale_vals[0] = load_u32_at_src0(block_byte_base + 96);
|
||||
scale_vals[1] = load_u32_at_src0(block_byte_base + 100);
|
||||
scale_vals[2] = load_u32_at_src0(block_byte_base + 104);
|
||||
|
||||
var tmp: u32 = scale_vals[2];
|
||||
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
|
||||
scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
|
||||
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
|
||||
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
|
||||
|
||||
// Bytes 0-31: 32 bytes of hmask (8 u32s)
|
||||
var hmask_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0; i < 8; i++) {
|
||||
hmask_vals[i] = load_u32_at_src0(block_byte_base + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 32-95: 64 bytes of qs (16 u32s)
|
||||
var qs_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0u; i < 16; i++) {
|
||||
qs_vals[i] = load_u32_at_src0(block_byte_base + 32 + i * 4);
|
||||
}
|
||||
|
||||
var sum = 0.0;
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var is: u32 = 0;
|
||||
var m: u32 = 1;
|
||||
// 2 halves of the block (128 elements each)
|
||||
for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
|
||||
// 4 groups (each group has 2 blocks of 16 elements)
|
||||
for (var shift: u32 = 0; shift < 8; shift += 2) {
|
||||
// 2 blocks
|
||||
for (var k: u32 = 0; k < 32; k += 16) {
|
||||
let sc = get_byte(scale_vals[is / 4], is % 4);
|
||||
is++;
|
||||
let dl = d * (f32(sc) - 32.0);
|
||||
for (var l: u32 = 0u; l < 16u; l++) {
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let hm_idx = k + l;
|
||||
let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);
|
||||
let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);
|
||||
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
|
||||
let qs_val = (q_byte >> shift) & 3;
|
||||
sum += ((f32(qs_val) - hm) * dl) * src1[src1_i];
|
||||
src1_i++;
|
||||
}
|
||||
}
|
||||
m <<= 1;
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q4_K
|
||||
// 8 blocks of 32 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let m = f32(block.dmin);
|
||||
var sum = 0.0;
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var is: u32 = 0;
|
||||
// 2 blocks each iteration
|
||||
for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
|
||||
for (var shift: u32 = 0; shift < 8; shift += 4) {
|
||||
let scale_min = get_scale_min(is, block.scales);
|
||||
is++;
|
||||
let dl = d * scale_min.x;
|
||||
let ml = m * scale_min.y;
|
||||
for (var l: u32 = 0; l < 32; l++) {
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
|
||||
let qs_val = (q_byte >> shift) & 0xF;
|
||||
sum += (f32(qs_val) * dl - ml) * src1[src1_i];
|
||||
src1_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q5_K
|
||||
// 8 blocks of 32 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let m = f32(block.dmin);
|
||||
var sum = 0.0;
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var is: u32 = 0;
|
||||
var u: u32 = 1;
|
||||
// 2 blocks each iteration
|
||||
for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
|
||||
for (var shift: u32 = 0; shift < 8; shift += 4) {
|
||||
let scale_min = get_scale_min(is, block.scales);
|
||||
is++;
|
||||
let dl = d * scale_min.x;
|
||||
let ml = m * scale_min.y;
|
||||
for (var l: u32 = 0; l < 32; l++) {
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
|
||||
let qh_byte = get_byte(block.qh[l / 4], l % 4);
|
||||
let qs_val = (q_byte >> shift) & 0xF;
|
||||
let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
|
||||
sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i];
|
||||
src1_i++;
|
||||
}
|
||||
u <<= 1;
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q6_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes
|
||||
|
||||
// Bytes 208-209: f16 scale 'd'
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base + 208);
|
||||
|
||||
// Bytes 0-127: 128 bytes of ql (32 u32s)
|
||||
var ql_vals: array<u32, 32>;
|
||||
for (var i: u32 = 0; i < 32; i++) {
|
||||
ql_vals[i] = load_u32_at_src0(block_byte_base + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 128-191: 64 bytes of qh (16 u32s)
|
||||
var qh_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qh_vals[i] = load_u32_at_src0(block_byte_base + 128 + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 192-207: 16 bytes of scales (4 u32s)
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
scale_vals[i] = load_u32_at_src0(block_byte_base + 192 + i * 4);
|
||||
}
|
||||
|
||||
var sum = 0.0;
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var qh_b_idx: u32 = 0;
|
||||
var sc_b_idx: u32 = 0;
|
||||
for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) {
|
||||
for (var l: u32 = 0; l < 32; l++) {
|
||||
let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4);
|
||||
let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4);
|
||||
let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4);
|
||||
|
||||
let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0;
|
||||
let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0;
|
||||
let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0;
|
||||
let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0;
|
||||
|
||||
let is = l/16;
|
||||
let is1 = sc_b_idx + is;
|
||||
let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4);
|
||||
let is2 = sc_b_idx + is + 2;
|
||||
let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4);
|
||||
let is3 = sc_b_idx + is + 4;
|
||||
let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4);
|
||||
let is4 = sc_b_idx + is + 6;
|
||||
let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4);
|
||||
|
||||
sum += d * f32(sc1) * q1 * src1[src1_i + l];
|
||||
sum += d * f32(sc2) * q2 * src1[src1_i + l + 32];
|
||||
sum += d * f32(sc3) * q3 * src1[src1_i + l + 64];
|
||||
sum += d * f32(sc4) * q4 * src1[src1_i + l + 96];
|
||||
}
|
||||
src1_i += 128;
|
||||
qh_b_idx += 32;
|
||||
sc_b_idx += 8;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef IQ2_XXS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 32; ib += 4) {
|
||||
let aux0_offset = block_byte_base + 2 + ib * 2;
|
||||
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
|
||||
let aux0 = load_u32_at_src0(aux0_offset);
|
||||
let aux1 = load_u32_at_src0(aux1_offset);
|
||||
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = get_byte(aux0, l) * 8;
|
||||
let is = (aux1 >> (7 * l)) & 127;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
|
||||
sum += db * f32(g) * m * src1[src1_i];
|
||||
src1_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef IQ2_XS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
|
||||
var scale_vals = array<u32, 2>(
|
||||
load_u32_at_src0(block_byte_base + 66),
|
||||
load_u32_at_src0(block_byte_base + 70)
|
||||
);
|
||||
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 32; ib += 4) {
|
||||
let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
|
||||
let db = array<f32, 2>(
|
||||
d * (0.5 + f32(s & 0xF)) * 0.25,
|
||||
d * (0.5 + f32(s >> 4)) * 0.25
|
||||
);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
|
||||
let qs_val = load_u32_at_src0(qs_offset) & 0xFFFF;
|
||||
let ig = (qs_val & 511) * 8;
|
||||
let is = qs_val >> 9;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
let dl = db[l/2];
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
|
||||
sum += dl * f32(g) * m * src1[src1_i];
|
||||
src1_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef IQ2_S
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
|
||||
var qs_vals : array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qs_vals[i] = load_u32_at_src0(block_byte_base + 2 + i * 4);
|
||||
}
|
||||
|
||||
var qh_vals: array<u32, 2>;
|
||||
qh_vals[0] = load_u32_at_src0(block_byte_base + 66);
|
||||
qh_vals[1] = load_u32_at_src0(block_byte_base + 70);
|
||||
|
||||
var scale_vals: array<u32, 2>;
|
||||
scale_vals[0] = load_u32_at_src0(block_byte_base + 74);
|
||||
scale_vals[1] = load_u32_at_src0(block_byte_base + 78);
|
||||
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 8; ib ++) {
|
||||
let s = get_byte(scale_vals[ib / 4], ib % 4);
|
||||
let db = array<f32, 2>(
|
||||
d * (0.5 + f32(s & 0xF)) * 0.25,
|
||||
d * (0.5 + f32(s >> 4)) * 0.25
|
||||
);
|
||||
let qs_w = qs_vals[ib];
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300;
|
||||
let ig = (get_byte(qs_w, l) | qh_b) * 8;
|
||||
let signs = get_byte(qs_vals[ib + 8], l);
|
||||
let dl = db[l/2];
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
|
||||
sum += dl * f32(g) * m * src1[src1_i];
|
||||
src1_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef IQ3_XXS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 16; ib += 2) {
|
||||
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
|
||||
let sc_sign = load_u32_at_src0(sc_sign_offset);
|
||||
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let is = (sc_sign >> (7 * l)) & 127;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
|
||||
let ig1 = get_byte(ig_val, 0);
|
||||
let ig2 = get_byte(ig_val, 1);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let g1 = get_byte(iq3xxs_grid[ig1], j);
|
||||
let g2 = get_byte(iq3xxs_grid[ig2], j);
|
||||
let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
|
||||
let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
|
||||
sum += db * f32(g1) * m1 * src1[src1_i];
|
||||
sum += db * f32(g2) * m2 * src1[src1_i + 4];
|
||||
src1_i++;
|
||||
}
|
||||
src1_i += 4;
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef IQ3_S
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
|
||||
var qh_vals = array<u32, 2>(
|
||||
load_u32_at_src0(block_byte_base + 66),
|
||||
load_u32_at_src0(block_byte_base + 70)
|
||||
);
|
||||
|
||||
var sign_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0; i < 8; i++) {
|
||||
sign_vals[i] = load_u32_at_src0(block_byte_base + 74 + i * 4);
|
||||
}
|
||||
|
||||
var scale_vals = load_u32_at_src0(block_byte_base + 106);
|
||||
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 4; ib++) {
|
||||
let s = get_byte(scale_vals, ib);
|
||||
let db = array<f32, 2>(
|
||||
d * (1.0 + 2.0 * f32(s & 0xF)),
|
||||
d * (1.0 + 2.0 * f32(s >> 4))
|
||||
);
|
||||
for (var k: u32 = 0; k < 2; k++) {
|
||||
let dl = db[k];
|
||||
let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k);
|
||||
let sign_w = sign_vals[ib * 2 + k];
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let signs = get_byte(sign_w, l);
|
||||
let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
|
||||
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
|
||||
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let g1 = get_byte(iq3s_grid[ig1], j);
|
||||
let g2 = get_byte(iq3s_grid[ig2], j);
|
||||
let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
|
||||
let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
|
||||
sum += dl * f32(g1) * m1 * src1[src1_i];
|
||||
sum += dl * f32(g2) * m2 * src1[src1_i + 4];
|
||||
src1_i++;
|
||||
}
|
||||
src1_i += 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef IQ1_S
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
let qh = load_u32_at_src0(block_byte_base + 34 + ib * 2) & 0xFFFF;
|
||||
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
|
||||
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
|
||||
let qs_w = load_u32_at_src0(block_byte_base + 2 + ib * 4);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let gw = iq1_grid[(ig + j) / 16];
|
||||
let g = (gw >> (((ig + j) % 16) * 2)) & 3;
|
||||
let gs = bitcast<i32>(g << 30) >> 30;
|
||||
sum += dl * (f32(gs) + delta) * src1[src1_i];
|
||||
src1_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef IQ1_M
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
|
||||
let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000);
|
||||
let d = f32(bitcast<vec2<f16>>(scale).x);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF;
|
||||
let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7;
|
||||
let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7;
|
||||
var dl = array<f32, 2>(
|
||||
d * f32(2 * s1 + 1),
|
||||
d * f32(2 * s2 + 1)
|
||||
);
|
||||
|
||||
let qh = block.qh[ib / 2] >> (16 * (ib % 2));
|
||||
var idx = array<u32, 4>(
|
||||
get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700),
|
||||
get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700),
|
||||
get_byte(block.qs[ib], 2) | ((qh) & 0x700),
|
||||
get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700)
|
||||
);
|
||||
var delta = array<f32, 4>(
|
||||
select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0),
|
||||
select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0),
|
||||
select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0),
|
||||
select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0)
|
||||
);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = idx[l] * 8;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let gw = iq1_grid[(ig + j) / 16];
|
||||
let g = (gw >> (((ig + j) % 16) * 2)) & 3;
|
||||
let gs = bitcast<i32>(g << 30) >> 30;
|
||||
sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i];
|
||||
src1_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef IQ4_NL
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 32;
|
||||
var sum = 0.0;
|
||||
var qs: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
qs[i] = load_u32_at_src0(block_byte_base + 2 + i * 4);
|
||||
}
|
||||
for (var j: u32 = 0; j < 16; j++) {
|
||||
let qsb = get_byte(qs[j / 4], j % 4);
|
||||
sum += d * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i];
|
||||
sum += d * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16];
|
||||
src1_i++;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef IQ4_XS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = unpack2x16float(block.d_scales_h)[0];
|
||||
let scales_h = block.d_scales_h >> 16;
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4);
|
||||
let dl = d * (f32(ls) - 32.0);
|
||||
for (var j: u32 = 0; j < 16; j++) {
|
||||
let iqs = ib * 16 + j;
|
||||
let qsb = get_byte(block.qs[iqs / 4], iqs % 4);
|
||||
sum += dl * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i];
|
||||
sum += dl * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16];
|
||||
src1_i++;
|
||||
}
|
||||
src1_i += 16;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
struct MulMatParams {
|
||||
offset_src0: u32, // in elements/blocks
|
||||
offset_src1: u32, // in elements/blocks
|
||||
offset_dst: u32, // in elements/blocks
|
||||
m: u32,
|
||||
n: u32,
|
||||
k: u32,
|
||||
// all strides are in elements/blocks
|
||||
stride_01: u32,
|
||||
stride_11: u32,
|
||||
stride_02: u32,
|
||||
stride_12: u32,
|
||||
stride_03: u32,
|
||||
stride_13: u32,
|
||||
|
||||
bs02: u32,
|
||||
bs03: u32,
|
||||
broadcast2: u32,
|
||||
broadcast3: u32
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
|
||||
|
||||
@group(0) @binding(3) var<uniform> params: MulMatParams;
|
||||
|
||||
@compute @workgroup_size(256)
|
||||
fn main(@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>) {
|
||||
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
|
||||
let global_idx = wg_linear * 256u + local_id.x;
|
||||
|
||||
let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
|
||||
if (global_idx >= total) {
|
||||
return;
|
||||
}
|
||||
|
||||
let dst2_stride = params.m * params.n;
|
||||
let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
|
||||
|
||||
let dst3_idx = global_idx / dst3_stride;
|
||||
let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
|
||||
let src13_idx = dst3_idx; // src1 is not broadcast
|
||||
let dst3_rem = global_idx % dst3_stride;
|
||||
|
||||
let dst2_idx = dst3_rem / dst2_stride;
|
||||
let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
|
||||
let src12_idx = dst2_idx; // src1 is not broadcast
|
||||
|
||||
let dst2_rem = dst3_rem % dst2_stride;
|
||||
|
||||
let row = dst2_rem / params.m; // output row
|
||||
let col = dst2_rem % params.m; // output column
|
||||
|
||||
let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01;
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
|
||||
|
||||
var sum = 0.0;
|
||||
for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) {
|
||||
sum += multiply_add(src0_idx_base, src1_idx_base, i);
|
||||
}
|
||||
dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
|
||||
}
|
||||
@@ -3,10 +3,18 @@ enable subgroups;
|
||||
#endif
|
||||
enable f16;
|
||||
|
||||
#ifdef MMVQ
|
||||
requires packed_4x8_integer_dot_product;
|
||||
#endif
|
||||
|
||||
#define DECLARE_BYTE_LOADERS_SRC0
|
||||
#include "common_decls.tmpl"
|
||||
|
||||
#ifdef MMVQ
|
||||
#include "mul_mat_vec_q_acc.tmpl"
|
||||
#else
|
||||
#include "mul_mat_vec_acc.tmpl"
|
||||
#endif
|
||||
|
||||
struct MulMatParams {
|
||||
offset_src0: u32,
|
||||
@@ -28,9 +36,14 @@ struct MulMatParams {
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>;
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>;
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<f32>;
|
||||
|
||||
#ifdef MMVQ
|
||||
@group(0) @binding(1) var<storage, read_write> src1q: array<q8_1>;
|
||||
#else
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>;
|
||||
#endif
|
||||
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<f32>;
|
||||
// "mul_mat_vec_acc.tmpl" requires params.k, params.m, params.stride_01
|
||||
@group(0) @binding(3) var<uniform> params: MulMatParams;
|
||||
|
||||
@@ -75,10 +88,15 @@ fn main(
|
||||
let src12_idx = dst2_idx;
|
||||
|
||||
let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base;
|
||||
|
||||
#ifdef MMVQ
|
||||
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u);
|
||||
let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base);
|
||||
#else
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base);
|
||||
#endif
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
|
||||
@@ -436,7 +436,6 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef MUL_ACC_Q3_K
|
||||
#define BLOCK_SIZE 256
|
||||
#define BLOCK_SIZE_BYTES 110
|
||||
|
||||
303
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl
Normal file
303
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl
Normal file
@@ -0,0 +1,303 @@
|
||||
#ifdef U32_DEQUANT_HELPERS
|
||||
#define SRC0_TYPE u32
|
||||
|
||||
fn byte_of(v: u32, b: u32) -> u32 {
|
||||
return (v >> (b * 8u)) & 0xFFu;
|
||||
}
|
||||
|
||||
fn sbyte_of(v: u32, b: u32) -> i32 {
|
||||
let raw = i32((v >> (b * 8u)) & 0xFFu);
|
||||
return select(raw, raw - 256, raw >= 128);
|
||||
}
|
||||
#endif
|
||||
|
||||
#define SRC0_TYPE SRC0_INNER_TYPE
|
||||
#define SRC1_TYPE SRC1_INNER_TYPE
|
||||
|
||||
#ifdef LEGACY_QUANTS
|
||||
#define BLOCK_SIZE 32
|
||||
#define THREADS_PER_BLOCK 4
|
||||
#elif K_QUANTS
|
||||
#define BLOCK_SIZE 256
|
||||
#define THREADS_PER_BLOCK 16
|
||||
#endif
|
||||
|
||||
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
|
||||
#define Q8_BLOCK_SIZE 32
|
||||
|
||||
#ifdef MUL_ACC_Q4_0
|
||||
#define BLOCK_SIZE_BYTES 18
|
||||
#define B_DS_TYPE vec2<f32>
|
||||
fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2<u32> {
|
||||
let qs_packed = load_u32_at_src0(block_byte_base + 2u + 4u * inner_id);
|
||||
|
||||
return vec2<u32>(
|
||||
qs_packed & 0x0F0F0F0Fu,
|
||||
(qs_packed >> 4u) & 0x0F0F0F0Fu
|
||||
);
|
||||
}
|
||||
fn repack_b_qs(block:u32, inner_id: u32) -> vec2<u32> {
|
||||
return vec2<u32>(
|
||||
src1q[block].qs[inner_id],
|
||||
src1q[block].qs[inner_id + 4u],
|
||||
);
|
||||
}
|
||||
fn repack_b_dm(block: u32) -> B_DS_TYPE {
|
||||
return B_DS_TYPE(
|
||||
f32(src1q[block].d),
|
||||
f32(src1q[block].s)
|
||||
);
|
||||
}
|
||||
fn get_dm(block_byte_base: u32) -> f32 {
|
||||
return f32(load_f16_at_src0(block_byte_base));
|
||||
}
|
||||
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q4_1
|
||||
#define BLOCK_SIZE_BYTES 20
|
||||
#define B_DS_TYPE vec2<f32>
|
||||
fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2<u32> {
|
||||
let qs_packed = load_u32_at_src0(block_byte_base + 4u + 4u * inner_id);
|
||||
|
||||
return vec2<u32>(
|
||||
qs_packed & 0x0F0F0F0Fu,
|
||||
(qs_packed >> 4u) & 0x0F0F0F0Fu
|
||||
);
|
||||
}
|
||||
fn repack_b_qs(block:u32, inner_id: u32) -> vec2<u32> {
|
||||
return vec2<u32>(
|
||||
src1q[block].qs[inner_id],
|
||||
src1q[block].qs[inner_id + 4u],
|
||||
);
|
||||
}
|
||||
fn repack_b_dm(block: u32) -> B_DS_TYPE {
|
||||
return B_DS_TYPE(
|
||||
f32(src1q[block].d),
|
||||
f32(src1q[block].s)
|
||||
);
|
||||
}
|
||||
fn get_dm(block_byte_base: u32) -> vec2<f32> {
|
||||
return vec2<f32>(
|
||||
f32(load_f16_at_src0(block_byte_base)),
|
||||
f32(load_f16_at_src0(block_byte_base + 2u))
|
||||
);
|
||||
}
|
||||
fn mul_q8_1(row_sum: i32, dma: vec2<f32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q8_0
|
||||
#define BLOCK_SIZE_BYTES 34
|
||||
#define B_DS_TYPE f32
|
||||
fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2<u32> {
|
||||
return vec2<u32>(
|
||||
load_u32_at_src0(block_byte_base + 2u + 4u * (inner_id * 2u)),
|
||||
load_u32_at_src0(block_byte_base + 2u + 4u * (inner_id * 2u + 1))
|
||||
);
|
||||
}
|
||||
fn repack_b_qs(block:u32, inner_id: u32) -> vec2<u32> {
|
||||
return vec2<u32>(
|
||||
src1q[block].qs[inner_id * 2u],
|
||||
src1q[block].qs[inner_id * 2u + 1],
|
||||
);
|
||||
}
|
||||
fn repack_b_dm(block: u32) -> B_DS_TYPE {
|
||||
return B_DS_TYPE(src1q[block].d);
|
||||
}
|
||||
fn get_dm(block_byte_base: u32) -> f32 {
|
||||
return f32(load_f16_at_src0(block_byte_base));
|
||||
}
|
||||
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (da * b_ds);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef LEGACY_QUANTS
|
||||
fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
var row_sum = 0;
|
||||
let a_repacked = repack_a(a_byte_base, b_inner_id);
|
||||
|
||||
row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]);
|
||||
row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]);
|
||||
|
||||
return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds);
|
||||
}
|
||||
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
|
||||
var acc: array<f32, OUTPUTS_PER_WG>;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
|
||||
let b_inner_id = thread_id % THREADS_PER_BLOCK;
|
||||
let b_block_idx = src1q_idx_base + block;
|
||||
|
||||
let b_repacked = repack_b_qs(b_block_idx, b_inner_id);
|
||||
let b_ds = repack_b_dm(b_block_idx);
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q2_K
|
||||
#define BLOCK_SIZE_BYTES 84
|
||||
#define B_DS_TYPE f32
|
||||
fn repack_a(block_byte_base: u32, tid: u32) -> vec4<u32> {
|
||||
let ih2 = tid / 8u;
|
||||
let phase = tid % 2u;
|
||||
let iq4_idx = 2u * ih2 + phase;
|
||||
let qs_byte_base = block_byte_base + 16u + 16u * iq4_idx;
|
||||
let qs_shift = tid & 6u;
|
||||
return vec4<u32>(
|
||||
(load_u32_at_src0_aligned(qs_byte_base) >> qs_shift) & 0x03030303u,
|
||||
(load_u32_at_src0_aligned(qs_byte_base + 4u) >> qs_shift) & 0x03030303u,
|
||||
(load_u32_at_src0_aligned(qs_byte_base + 8u) >> qs_shift) & 0x03030303u,
|
||||
(load_u32_at_src0_aligned(qs_byte_base + 12u) >> qs_shift) & 0x03030303u,
|
||||
);
|
||||
}
|
||||
fn repack_b_qs(q8_block_idx: u32, tid: u32) -> vec4<u32> {
|
||||
let phase = tid % 2u;
|
||||
return vec4<u32>(
|
||||
src1q[q8_block_idx].qs[4u * phase],
|
||||
src1q[q8_block_idx].qs[4u * phase + 1u],
|
||||
src1q[q8_block_idx].qs[4u * phase + 2u],
|
||||
src1q[q8_block_idx].qs[4u * phase + 3u],
|
||||
);
|
||||
}
|
||||
fn repack_b_dm(q8_block_idx: u32) -> B_DS_TYPE {
|
||||
return B_DS_TYPE(src1q[q8_block_idx].d);
|
||||
}
|
||||
fn get_dm(block_byte_base: u32) -> vec2<f32> {
|
||||
return vec2<f32>(
|
||||
f32(load_f16_at_src0(block_byte_base + 80u)),
|
||||
f32(load_f16_at_src0(block_byte_base + 82u)),
|
||||
);
|
||||
}
|
||||
fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
|
||||
let scale_byte = block_byte_base + tid;
|
||||
let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u);
|
||||
return vec2<f32>(f32(scale & 0xFu), f32(scale >> 4u));
|
||||
}
|
||||
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
let a_repacked = repack_a(a_byte_base, tid);
|
||||
let dm = get_dm(a_byte_base);
|
||||
let scale_min = get_scale_min(a_byte_base, tid);
|
||||
|
||||
let scale_q = i32(scale_min.x);
|
||||
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
|
||||
|
||||
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
|
||||
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
|
||||
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
|
||||
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
|
||||
|
||||
return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q4_K
|
||||
#define BLOCK_SIZE_BYTES 144
|
||||
#define B_DS_TYPE vec2<f32>
|
||||
fn repack_a(block_byte_base: u32, tid: u32) -> vec4<u32> {
|
||||
let iq4 = tid / 4u;
|
||||
let phase = tid % 2u;
|
||||
let nibble = (tid >> 1u) % 2u;
|
||||
let q_qs_byte_base = block_byte_base + 16u + 32u * iq4 + 16u * phase;
|
||||
let qs_shift = 4u * nibble;
|
||||
return vec4<u32>(
|
||||
(load_u32_at_src0_aligned(q_qs_byte_base) >> qs_shift) & 0x0F0F0F0Fu,
|
||||
(load_u32_at_src0_aligned(q_qs_byte_base + 4u) >> qs_shift) & 0x0F0F0F0Fu,
|
||||
(load_u32_at_src0_aligned(q_qs_byte_base + 8u) >> qs_shift) & 0x0F0F0F0Fu,
|
||||
(load_u32_at_src0_aligned(q_qs_byte_base + 12u) >> qs_shift) & 0x0F0F0F0Fu,
|
||||
);
|
||||
}
|
||||
fn repack_b_qs(q8_block_idx: u32, tid: u32) -> vec4<u32> {
|
||||
let phase = tid % 2u;
|
||||
return vec4<u32>(
|
||||
src1q[q8_block_idx].qs[4u * phase],
|
||||
src1q[q8_block_idx].qs[4u * phase + 1u],
|
||||
src1q[q8_block_idx].qs[4u * phase + 2u],
|
||||
src1q[q8_block_idx].qs[4u * phase + 3u],
|
||||
);
|
||||
}
|
||||
fn repack_b_dm(q8_block_idx: u32) -> B_DS_TYPE {
|
||||
return B_DS_TYPE(
|
||||
f32(src1q[q8_block_idx].d),
|
||||
f32(src1q[q8_block_idx].s),
|
||||
);
|
||||
}
|
||||
fn get_dm(block_byte_base: u32) -> vec2<f32> {
|
||||
return vec2<f32>(
|
||||
f32(load_f16_at_src0(block_byte_base + 0u)),
|
||||
f32(load_f16_at_src0(block_byte_base + 2u)),
|
||||
);
|
||||
}
|
||||
fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
|
||||
let sc_m_idx = tid / 2u;
|
||||
let scales_byte_base = block_byte_base + 4u;
|
||||
let scales0_3 = load_u32_at_src0_aligned(scales_byte_base);
|
||||
let scales4_7 = load_u32_at_src0_aligned(scales_byte_base + 4u);
|
||||
let scales8_11 = load_u32_at_src0_aligned(scales_byte_base + 8u);
|
||||
|
||||
let byte_idx = sc_m_idx & 3u;
|
||||
let is_high = sc_m_idx >= 4u;
|
||||
|
||||
let sc_low = byte_of(scales0_3, byte_idx) & 0x3Fu;
|
||||
let sc_high = (byte_of(scales8_11, byte_idx) & 0x0Fu) | ((byte_of(scales0_3, byte_idx) & 0xC0u) >> 2u);
|
||||
let scale = f32(select(sc_low, sc_high, is_high));
|
||||
|
||||
let mn_low = byte_of(scales4_7, byte_idx) & 0x3Fu;
|
||||
let mn_high = (byte_of(scales8_11, byte_idx) >> 4u) | ((byte_of(scales4_7, byte_idx) & 0xC0u) >> 2u);
|
||||
let min_val = f32(select(mn_low, mn_high, is_high));
|
||||
|
||||
return vec2<f32>(scale, min_val);
|
||||
}
|
||||
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
let a_repacked = repack_a(a_byte_base, tid);
|
||||
let dm = get_dm(a_byte_base);
|
||||
let scale_min = get_scale_min(a_byte_base, tid);
|
||||
|
||||
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
|
||||
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
|
||||
|
||||
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
|
||||
return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef K_QUANTS
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
|
||||
var acc: array<f32, OUTPUTS_PER_WG>;
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
|
||||
for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) {
|
||||
let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
|
||||
let b_repacked = repack_b_qs(src1q_idx, tid);
|
||||
let b_ds = repack_b_dm(src1q_idx);
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
#endif
|
||||
173
ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl
Normal file
173
ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl
Normal file
@@ -0,0 +1,173 @@
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
enable subgroups;
|
||||
#endif
|
||||
enable f16;
|
||||
|
||||
requires packed_4x8_integer_dot_product;
|
||||
|
||||
#include "common_decls.tmpl"
|
||||
|
||||
struct Params {
|
||||
offset_src1: u32,
|
||||
stride_12: u32,
|
||||
stride_13: u32,
|
||||
ne0: u32,
|
||||
ne2: u32,
|
||||
ne3: u32,
|
||||
};
|
||||
|
||||
#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src1: array<SRC1_TYPE>;
|
||||
@group(0) @binding(1) var<storage, read_write> src1q: array<q8_1>;
|
||||
|
||||
@group(0) @binding(2) var<uniform> params: Params;
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
fn cluster_max_8(v: f32) -> f32 {
|
||||
var r = v;
|
||||
r = max(r, subgroupShuffleXor(r, 1u));
|
||||
r = max(r, subgroupShuffleXor(r, 2u));
|
||||
r = max(r, subgroupShuffleXor(r, 4u));
|
||||
return r;
|
||||
}
|
||||
|
||||
#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K)
|
||||
fn cluster_add_i4x8(v: i32) -> i32 {
|
||||
var r= v;
|
||||
r += subgroupShuffleXor(r, 1u);
|
||||
r += subgroupShuffleXor(r, 2u);
|
||||
r += subgroupShuffleXor(r, 4u);
|
||||
return r;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef USE_WORKGROUP_REDUCTION
|
||||
#define CLUSTER_SIZE 8
|
||||
|
||||
var<workgroup> partial_amaxs: array<array<f32, CLUSTER_SIZE>, WG_SIZE / CLUSTER_SIZE>;
|
||||
var<workgroup> partial_sums: array<array<i32, CLUSTER_SIZE>, WG_SIZE / CLUSTER_SIZE>;
|
||||
#endif
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>
|
||||
) {
|
||||
let thread_id = local_id.x;
|
||||
let num_vec4 = params.ne0 / 4u;
|
||||
|
||||
let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
|
||||
let total_batches = wg_per_vec * params.ne2 * params.ne3;
|
||||
|
||||
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
|
||||
if (wg_linear >= total_batches) {
|
||||
return;
|
||||
}
|
||||
|
||||
let src13_idx = wg_linear / (params.ne2 * wg_per_vec);
|
||||
let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec;
|
||||
let src11_wg_idx = wg_linear % wg_per_vec;
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
let src1_idx_vec4_base = src1_idx_base / 4u;
|
||||
|
||||
let blocks_per_row = params.ne0 / 32u;
|
||||
let blocks_per_wg = (WG_SIZE * 4u) / 32u;
|
||||
let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row;
|
||||
let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u;
|
||||
let qs_idx = thread_id % 8u;
|
||||
|
||||
// reduction
|
||||
var q4 = vec4<f32>(0.0);
|
||||
var q4_quants = 0u;
|
||||
var thread_amax = 0.0;
|
||||
|
||||
let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id;
|
||||
let is_valid = src11_vec4_idx < num_vec4;
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
|
||||
var d = 0.0;
|
||||
|
||||
if (is_valid) {
|
||||
q4 = src1[src1_idx_vec4_base + src11_vec4_idx];
|
||||
let abs_q4 = abs(q4);
|
||||
thread_amax = max(max(abs_q4[0u], abs_q4[1u]), max(abs_q4[2], abs_q4[3]));
|
||||
}
|
||||
|
||||
d = cluster_max_8(thread_amax) / 127.0;
|
||||
|
||||
if (is_valid) {
|
||||
let id = select(0.0, 1.0 / d, d > 0.0);
|
||||
q4_quants = pack4xI8(vec4<i32>(round(q4 * id)));
|
||||
if (qs_idx == 0u) {
|
||||
src1q[src1q_idx].d = f16(d);
|
||||
}
|
||||
src1q[src1q_idx].qs[qs_idx] = q4_quants;
|
||||
}
|
||||
|
||||
#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K)
|
||||
let q4_quants_sum = dot4I8Packed(q4_quants, 0x01010101u);
|
||||
let s = f16(d * f32(cluster_add_i4x8(q4_quants_sum)));
|
||||
|
||||
if (is_valid) {
|
||||
if (qs_idx == 0u) {
|
||||
src1q[src1q_idx].s = s;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef USE_WORKGROUP_REDUCTION
|
||||
|
||||
var d = 0.0;
|
||||
let cluster_id = thread_id / 8u;
|
||||
|
||||
if (is_valid) {
|
||||
q4 = src1[src1_idx_vec4_base + src11_vec4_idx];
|
||||
let abs_q4 = abs(q4);
|
||||
thread_amax = max(max(abs_q4[0], abs_q4[1]), max(abs_q4[2], abs_q4[3]));
|
||||
partial_amaxs[cluster_id][qs_idx] = thread_amax;
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
if (is_valid) {
|
||||
let amax = max(
|
||||
max(
|
||||
max(partial_amaxs[cluster_id][0], partial_amaxs[cluster_id][1]), max(partial_amaxs[cluster_id][2], partial_amaxs[cluster_id][3])),
|
||||
max(
|
||||
max(partial_amaxs[cluster_id][4], partial_amaxs[cluster_id][5]), max(partial_amaxs[cluster_id][6], partial_amaxs[cluster_id][7]))
|
||||
);
|
||||
|
||||
d = amax / 127.0;
|
||||
let id = select(0.0f, 1.0f / d, d > 0.0f);
|
||||
|
||||
q4_quants = pack4xI8(vec4<i32>(round(q4 * id)));
|
||||
src1q[src1q_idx].qs[qs_idx] = q4_quants;
|
||||
|
||||
if (qs_idx == 0u) {
|
||||
src1q[src1q_idx].d = f16(d);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K)
|
||||
|
||||
partial_sums[cluster_id][qs_idx] = dot4I8Packed(q4_quants, 0x01010101u);
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
if (is_valid) {
|
||||
if (qs_idx == 0u) {
|
||||
let s = d * f32(partial_sums[cluster_id][0] + partial_sums[cluster_id][1] + partial_sums[cluster_id][2] + partial_sums[cluster_id][3]
|
||||
+ partial_sums[cluster_id][4] + partial_sums[cluster_id][5] + partial_sums[cluster_id][6] + partial_sums[cluster_id][7]);
|
||||
src1q[src1q_idx].s = f16(s);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
}
|
||||
@@ -505,6 +505,7 @@ class MODEL_ARCH(IntEnum):
|
||||
LLAMA_EMBED = auto()
|
||||
MAINCODER = auto()
|
||||
KIMI_LINEAR = auto()
|
||||
TALKIE = auto()
|
||||
|
||||
|
||||
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
@@ -1021,6 +1022,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.LLAMA_EMBED: "llama-embed",
|
||||
MODEL_ARCH.MAINCODER: "maincoder",
|
||||
MODEL_ARCH.KIMI_LINEAR: "kimi-linear",
|
||||
MODEL_ARCH.TALKIE: "talkie",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
@@ -4013,6 +4015,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
MODEL_ARCH.TALKIE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.LAYER_OUT_SCALE,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ class TensorNameMap:
|
||||
"encoder", # neobert
|
||||
"model.transformer.wte", # llada
|
||||
"embed_tokens", # qwen3-embedding
|
||||
"model.embed", # talkie
|
||||
),
|
||||
|
||||
# Token type embeddings
|
||||
@@ -259,6 +260,7 @@ class TensorNameMap:
|
||||
"model.transformer.blocks.{bid}.q_proj", # llada
|
||||
"layers.{bid}.self_attn.q_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.q_proj", # nemotron-h
|
||||
"model.blocks.{bid}.attn.attn_query", # talkie
|
||||
),
|
||||
|
||||
# Attention key
|
||||
@@ -279,6 +281,7 @@ class TensorNameMap:
|
||||
"model.transformer.blocks.{bid}.k_proj", # llada
|
||||
"layers.{bid}.self_attn.k_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.k_proj", # nemotron-h
|
||||
"model.blocks.{bid}.attn.attn_key", # talkie
|
||||
),
|
||||
|
||||
# Attention value
|
||||
@@ -298,6 +301,7 @@ class TensorNameMap:
|
||||
"model.transformer.blocks.{bid}.v_proj", # llada
|
||||
"layers.{bid}.self_attn.v_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.v_proj", # nemotron-h
|
||||
"model.blocks.{bid}.attn.attn_value", # talkie
|
||||
),
|
||||
|
||||
# Attention output
|
||||
@@ -336,6 +340,7 @@ class TensorNameMap:
|
||||
"layers.{bid}.self_attn.o_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.o_proj", # nemotron-h
|
||||
"model.layers.{bid}.self_attn.language_expert_dense", # cogvlm
|
||||
"model.blocks.{bid}.attn.attn_resid", # talkie
|
||||
),
|
||||
|
||||
# Attention output norm
|
||||
@@ -508,6 +513,7 @@ class TensorNameMap:
|
||||
"layers.{bid}.mlp.up_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.up_proj", # nemotron-h
|
||||
"model.layers.{bid}.mlp.language_mlp.up_proj", # cogvlm
|
||||
"model.blocks.{bid}.mlp.mlp_linear", # talkie
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_EXP: (
|
||||
@@ -561,6 +567,7 @@ class TensorNameMap:
|
||||
"model.transformer.blocks.{bid}.ff_proj", # llada
|
||||
"layers.{bid}.mlp.gate_proj", # qwen3-embedding
|
||||
"model.layers.{bid}.mlp.language_mlp.gate_proj", # cogvlm
|
||||
"model.blocks.{bid}.mlp.mlp_gate", # talkie
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_EXP: (
|
||||
@@ -636,6 +643,7 @@ class TensorNameMap:
|
||||
"layers.{bid}.mlp.down_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.down_proj", # nemotron-h
|
||||
"model.layers.{bid}.mlp.language_mlp.down_proj", # cogvlm
|
||||
"model.blocks.{bid}.mlp.mlp_resid", # talkie
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||
@@ -682,6 +690,7 @@ class TensorNameMap:
|
||||
"model.layers.layers.{bid}.mixer.q_norm", # plamo3
|
||||
"layers.{bid}.self_attn.q_norm", # qwen3-embedding
|
||||
"model.layers.{bid}.attention.query_layernorm", # apertus
|
||||
"model.blocks.{bid}.attn.head_gain.head_g", # talkie
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_K_NORM: (
|
||||
@@ -716,6 +725,7 @@ class TensorNameMap:
|
||||
|
||||
MODEL_TENSOR.LAYER_OUT_SCALE: (
|
||||
"model.layers.{bid}.layer_scalar", # gemma4
|
||||
"model.blocks.{bid}.embed_skip.a_g", # talkie
|
||||
),
|
||||
|
||||
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (
|
||||
|
||||
@@ -133,6 +133,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_LLAMA_EMBED, "llama-embed" },
|
||||
{ LLM_ARCH_MAINCODER, "maincoder" },
|
||||
{ LLM_ARCH_KIMI_LINEAR, "kimi-linear" },
|
||||
{ LLM_ARCH_TALKIE, "talkie" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
|
||||
@@ -137,6 +137,7 @@ enum llm_arch {
|
||||
LLM_ARCH_LLAMA_EMBED,
|
||||
LLM_ARCH_MAINCODER,
|
||||
LLM_ARCH_KIMI_LINEAR,
|
||||
LLM_ARCH_TALKIE,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
|
||||
@@ -44,6 +44,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
|
||||
return new llama_model_llama_embed(params);
|
||||
case LLM_ARCH_MAINCODER:
|
||||
return new llama_model_maincoder(params);
|
||||
case LLM_ARCH_TALKIE:
|
||||
return new llama_model_talkie(params);
|
||||
case LLM_ARCH_DECI:
|
||||
return new llama_model_deci(params);
|
||||
case LLM_ARCH_BAICHUAN:
|
||||
@@ -2353,6 +2355,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_QWEN3NEXT:
|
||||
case LLM_ARCH_MIMO2:
|
||||
case LLM_ARCH_STEP35:
|
||||
case LLM_ARCH_TALKIE:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
|
||||
@@ -488,7 +488,7 @@ struct llama_layer {
|
||||
struct ggml_tensor * indexer_attn_k = nullptr;
|
||||
struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias
|
||||
|
||||
// gemma4 layer output scale
|
||||
// gemma4 layer output scale, reused for talkie embedding skip scale
|
||||
struct ggml_tensor * out_scale = nullptr;
|
||||
|
||||
struct llama_layer_posnet posnet;
|
||||
|
||||
@@ -2196,7 +2196,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
} else if (
|
||||
tokenizer_pre == "gpt-4o" ||
|
||||
tokenizer_pre == "llama4" ||
|
||||
tokenizer_pre == "kanana2") {
|
||||
tokenizer_pre == "kanana2" ||
|
||||
tokenizer_pre == "talkie") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
|
||||
@@ -177,9 +177,9 @@ llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_pa
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
||||
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
|
||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
||||
model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up_s,
|
||||
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, model.layers[il].ffn_gate_s,
|
||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down_s,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
@@ -200,7 +200,11 @@ llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_pa
|
||||
LLM_FFN_SILU, true,
|
||||
hparams.expert_weights_scale,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
il,
|
||||
nullptr, nullptr,
|
||||
model.layers[il].ffn_up_exps_s,
|
||||
model.layers[il].ffn_gate_exps_s,
|
||||
model.layers[il].ffn_down_exps_s);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
}
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
@@ -186,6 +186,19 @@ struct llama_model_maincoder : public llama_model_base {
|
||||
};
|
||||
|
||||
|
||||
struct llama_model_talkie : public llama_model_base {
|
||||
llama_model_talkie(const struct llama_model_params & params) : llama_model_base(params) {}
|
||||
void load_arch_hparams(llama_model_loader & ml) override;
|
||||
void load_arch_tensors(llama_model_loader & ml) override;
|
||||
|
||||
struct graph : public llm_graph_context {
|
||||
graph(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
|
||||
};
|
||||
|
||||
|
||||
struct llama_model_deci : public llama_model_base {
|
||||
llama_model_deci(const struct llama_model_params & params) : llama_model_base(params) {}
|
||||
void load_arch_hparams(llama_model_loader & ml) override;
|
||||
|
||||
149
src/models/talkie.cpp
Normal file
149
src/models/talkie.cpp
Normal file
@@ -0,0 +1,149 @@
|
||||
#include "models.h"
|
||||
|
||||
void llama_model_talkie::load_arch_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 40: type = LLM_TYPE_13B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_talkie::load_arch_tensors(llama_model_loader &) {
|
||||
LLAMA_LOAD_LOCALS;
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
||||
|
||||
// no k gain
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {1, n_head}, 0);
|
||||
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
||||
|
||||
layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1}, 0);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<llm_graph_context> llama_model_talkie::build_arch_graph(const llm_graph_params & params) const {
|
||||
return std::make_unique<graph>(*this, params);
|
||||
}
|
||||
|
||||
llama_model_talkie::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
inpL = build_norm(inpL, nullptr, nullptr, LLM_NORM_RMS, -1);
|
||||
cb(inpL, "inp_norm", -1);
|
||||
|
||||
ggml_tensor * embd_skip = inpL;
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv();
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
ggml_tensor * inp_skip = embd_skip;
|
||||
|
||||
cur = build_norm(inpL, nullptr, nullptr, LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur,
|
||||
n_embd_head, n_head, n_head_kv, il);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
|
||||
// reference applies qknorm after rope
|
||||
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "Qcur_norm", il);
|
||||
|
||||
Kcur = build_norm(Kcur, nullptr, nullptr, LLM_NORM_RMS, il);
|
||||
cb(Kcur, "Kcur_norm", il);
|
||||
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, nullptr, model.layers[il].wo_s,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
inp_skip = ggml_get_rows(ctx0, inp_skip, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
cur = build_norm(ffn_inp, nullptr, nullptr, LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, nullptr, nullptr,
|
||||
model.layers[il].ffn_gate, nullptr, nullptr,
|
||||
model.layers[il].ffn_down, nullptr, model.layers[il].ffn_down_s,
|
||||
nullptr,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
ggml_tensor * skip = ggml_mul(ctx0, inp_skip, model.layers[il].out_scale);
|
||||
cb(skip, "embd_skip", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, skip);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur, nullptr, nullptr, LLM_NORM_RMS, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
res->t_embd = cur;
|
||||
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
cur = ggml_scale(ctx0, cur, hparams.f_logit_scale);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
@@ -21,6 +21,7 @@
|
||||
#include <ggml-cpp.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <array>
|
||||
#include <cfloat>
|
||||
#include <cinttypes>
|
||||
@@ -33,6 +34,7 @@
|
||||
#include <future>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <random>
|
||||
#include <regex>
|
||||
#include <set>
|
||||
@@ -55,33 +57,24 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
|
||||
{
|
||||
// parallel initialization
|
||||
static const size_t n_threads = N_THREADS;
|
||||
// static RNG initialization (revisit if n_threads stops being constant)
|
||||
static std::vector<std::default_random_engine> generators = []() {
|
||||
std::random_device rd;
|
||||
std::vector<std::default_random_engine> vec;
|
||||
vec.reserve(n_threads);
|
||||
//for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed
|
||||
for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); }
|
||||
return vec;
|
||||
}();
|
||||
|
||||
auto init_thread = [&](size_t ith, size_t start, size_t end) {
|
||||
auto init_thread = [&](size_t start, size_t end) {
|
||||
thread_local std::default_random_engine gen(std::random_device{}());
|
||||
std::uniform_real_distribution<float> distribution(min, max);
|
||||
auto & gen = generators[ith];
|
||||
for (size_t i = start; i < end; i++) {
|
||||
data[i] = distribution(gen);
|
||||
}
|
||||
};
|
||||
|
||||
if (n_threads == 1) {
|
||||
init_thread(0, 0, nels);
|
||||
init_thread(0, nels);
|
||||
} else {
|
||||
std::vector<std::future<void>> tasks;
|
||||
tasks.reserve(n_threads);
|
||||
for (size_t i = 0; i < n_threads; i++) {
|
||||
size_t start = i*nels/n_threads;
|
||||
size_t end = (i+1)*nels/n_threads;
|
||||
tasks.push_back(std::async(std::launch::async, init_thread, i, start, end));
|
||||
tasks.push_back(std::async(std::launch::async, init_thread, start, end));
|
||||
}
|
||||
for (auto & t : tasks) {
|
||||
t.get();
|
||||
@@ -516,6 +509,25 @@ static bool output_format_from_str(const std::string & s, output_formats & forma
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::string test_time_now() {
|
||||
time_t t = time(NULL);
|
||||
struct tm tm_buf;
|
||||
#ifdef _WIN32
|
||||
if (gmtime_s(&tm_buf, &t) != 0) {
|
||||
return "";
|
||||
}
|
||||
#else
|
||||
if (gmtime_r(&t, &tm_buf) == nullptr) {
|
||||
return "";
|
||||
}
|
||||
#endif
|
||||
char buf[32];
|
||||
if (std::strftime(buf, sizeof(buf), "%FT%TZ", &tm_buf) == 0) {
|
||||
return "";
|
||||
}
|
||||
return buf;
|
||||
}
|
||||
|
||||
// Test result structure for SQL output
|
||||
struct test_result {
|
||||
std::string test_time;
|
||||
@@ -545,11 +557,7 @@ struct test_result {
|
||||
supported = false;
|
||||
passed = false;
|
||||
|
||||
// Set test time
|
||||
time_t t = time(NULL);
|
||||
char buf[32];
|
||||
std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t));
|
||||
test_time = buf;
|
||||
test_time = test_time_now();
|
||||
|
||||
// Set build info
|
||||
build_commit = ggml_commit();
|
||||
@@ -573,11 +581,7 @@ struct test_result {
|
||||
n_runs(n_runs),
|
||||
device_description(device_description),
|
||||
backend_reg_name(backend_reg_name) {
|
||||
// Set test time
|
||||
time_t t = time(NULL);
|
||||
char buf[32];
|
||||
std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t));
|
||||
test_time = buf;
|
||||
test_time = test_time_now();
|
||||
|
||||
// Set build info
|
||||
build_commit = ggml_commit();
|
||||
@@ -1110,6 +1114,17 @@ static std::unique_ptr<printer> create_printer(output_formats format) {
|
||||
GGML_ABORT("invalid output format");
|
||||
}
|
||||
|
||||
static std::mutex g_test_output_mutex;
|
||||
|
||||
static void print_test_result_locked(printer * output_printer, const test_result & result) {
|
||||
if (output_printer == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> guard(g_test_output_mutex);
|
||||
output_printer->print_test_result(result);
|
||||
}
|
||||
|
||||
struct test_case {
|
||||
virtual ~test_case() {}
|
||||
|
||||
@@ -1338,9 +1353,7 @@ struct test_case {
|
||||
test_result result(ggml_backend_name(backend1), current_op_name, vars(), "test",
|
||||
false, false, "not supported");
|
||||
|
||||
if (output_printer) {
|
||||
output_printer->print_test_result(result);
|
||||
}
|
||||
print_test_result_locked(output_printer, result);
|
||||
|
||||
ggml_free(ctx);
|
||||
return test_status_t::NOT_SUPPORTED;
|
||||
@@ -1462,9 +1475,7 @@ struct test_case {
|
||||
test_result result(ggml_backend_name(backend1), current_op_name, vars(), "test", supported, test_passed,
|
||||
error_msg);
|
||||
|
||||
if (output_printer) {
|
||||
output_printer->print_test_result(result);
|
||||
}
|
||||
print_test_result_locked(output_printer, result);
|
||||
|
||||
return test_passed ? test_status_t::OK : test_status_t::FAIL;
|
||||
}
|
||||
@@ -9493,8 +9504,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_from_file(const c
|
||||
return test_cases;
|
||||
}
|
||||
|
||||
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_names_filter, const char * params_filter,
|
||||
printer * output_printer, const char * test_file_path) {
|
||||
static bool test_backend(ggml_backend_t backend, ggml_backend_dev_t dev, test_mode mode, const char * op_names_filter, const char * params_filter,
|
||||
printer * output_printer, const char * test_file_path, int parallel_workers) {
|
||||
auto filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const char * params_filter) {
|
||||
if (params_filter == nullptr) {
|
||||
return;
|
||||
@@ -9547,21 +9558,90 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
set_use_ref(backend_cpu, true);
|
||||
}
|
||||
|
||||
size_t n_ok = 0;
|
||||
size_t tests_run = 0;
|
||||
std::atomic<size_t> n_ok = 0;
|
||||
std::atomic<size_t> tests_run = 0;
|
||||
std::vector<std::string> failed_tests;
|
||||
for (auto & test : test_cases) {
|
||||
test_status_t status = test->eval(backend, backend_cpu, op_names_filter, output_printer);
|
||||
if (status == test_status_t::SKIPPED || status == test_status_t::NOT_SUPPORTED) {
|
||||
continue;
|
||||
std::mutex failed_tests_mutex;
|
||||
|
||||
// Each worker grabs a chunk of cases at a time. The chunk shrinks as we
|
||||
// run out of work so that a few slow tests at the tail get spread across
|
||||
// workers instead of landing on one unlucky thread.
|
||||
constexpr size_t MAX_TESTS_PER_ITER = 100;
|
||||
std::atomic<size_t> test_idx = 0;
|
||||
|
||||
const auto & next_chunk = [&](size_t & my_begin, size_t & my_end) {
|
||||
const size_t cur = test_idx.load(std::memory_order_relaxed);
|
||||
const size_t remaining = cur < test_cases.size() ? test_cases.size() - cur : 0;
|
||||
const size_t chunk = std::max<size_t>(1, std::min<size_t>(MAX_TESTS_PER_ITER, remaining / parallel_workers));
|
||||
my_begin = test_idx.fetch_add(chunk);
|
||||
my_end = std::min(my_begin + chunk, test_cases.size());
|
||||
};
|
||||
|
||||
const auto & run_tests = [&](ggml_backend_t b, ggml_backend_t b_cpu) {
|
||||
size_t my_begin, my_end;
|
||||
next_chunk(my_begin, my_end);
|
||||
while (my_begin < test_cases.size()) {
|
||||
for (size_t i = my_begin; i < my_end; ++i) {
|
||||
auto & test = test_cases[i];
|
||||
test_status_t status = test->eval(b, b_cpu, op_names_filter, output_printer);
|
||||
if (status == test_status_t::SKIPPED || status == test_status_t::NOT_SUPPORTED) {
|
||||
continue;
|
||||
}
|
||||
tests_run++;
|
||||
if (status == test_status_t::OK) {
|
||||
n_ok++;
|
||||
} else if (status == test_status_t::FAIL) {
|
||||
std::lock_guard<std::mutex> guard(failed_tests_mutex);
|
||||
failed_tests.push_back(test->current_op_name + "(" + test->vars() + ")");
|
||||
}
|
||||
}
|
||||
next_chunk(my_begin, my_end);
|
||||
}
|
||||
tests_run++;
|
||||
if (status == test_status_t::OK) {
|
||||
n_ok++;
|
||||
} else if (status == test_status_t::FAIL) {
|
||||
failed_tests.push_back(test->current_op_name + "(" + test->vars() + ")");
|
||||
};
|
||||
|
||||
if (parallel_workers <= 1) {
|
||||
// Reuse the outer backend / backend_cpu so we don't pay an
|
||||
// extra CPU backend init.
|
||||
run_tests(backend, backend_cpu);
|
||||
} else {
|
||||
std::atomic<size_t> workers_started = 0;
|
||||
|
||||
const auto & eval_worker = [&]() {
|
||||
ggml_backend_t b = ggml_backend_dev_init(dev, NULL);
|
||||
if (b == NULL) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_backend_t b_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
|
||||
if (b_cpu == NULL) {
|
||||
ggml_backend_free(b);
|
||||
return;
|
||||
}
|
||||
|
||||
if (set_use_ref) {
|
||||
set_use_ref(b_cpu, true);
|
||||
}
|
||||
workers_started++;
|
||||
run_tests(b, b_cpu);
|
||||
ggml_backend_free(b_cpu);
|
||||
ggml_backend_free(b);
|
||||
};
|
||||
|
||||
std::vector<std::thread> threads;
|
||||
threads.reserve(parallel_workers);
|
||||
for (int i = 0; i < parallel_workers; ++i) {
|
||||
threads.emplace_back(eval_worker);
|
||||
}
|
||||
for (auto & t : threads) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
if (workers_started == 0 && !test_cases.empty()) {
|
||||
ggml_backend_free(backend_cpu);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
output_printer->print_summary(test_summary_info(n_ok, tests_run, false));
|
||||
output_printer->print_failed_tests(failed_tests);
|
||||
|
||||
@@ -9709,7 +9789,7 @@ static void show_test_coverage() {
|
||||
|
||||
static void usage(char ** argv) {
|
||||
printf("Usage: %s [mode] [-o <op,..>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>] [--list-ops]", argv[0]);
|
||||
printf(" [--show-coverage] [--test-file <path>]\n");
|
||||
printf(" [--show-coverage] [--test-file <path>] [-j <n>]\n");
|
||||
printf(" valid modes:\n");
|
||||
printf(" - test (default, compare with CPU backend for correctness)\n");
|
||||
printf(" - grad (compare gradients from backpropagation with method of finite differences)\n");
|
||||
@@ -9721,6 +9801,7 @@ static void usage(char ** argv) {
|
||||
printf(" --list-ops lists all available GGML operations\n");
|
||||
printf(" --show-coverage shows test coverage\n");
|
||||
printf(" --test-file reads test operators from a test file generated by llama-export-graph-ops\n");
|
||||
printf(" -j <n> runs tests using <n> parallel worker threads (default: 1, test mode only)\n");
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
@@ -9730,6 +9811,7 @@ int main(int argc, char ** argv) {
|
||||
const char * backend_filter = nullptr;
|
||||
const char * params_filter = nullptr;
|
||||
const char * test_file_path = nullptr;
|
||||
int parallel_workers = 1;
|
||||
|
||||
for (int i = 1; i < argc; i++) {
|
||||
if (strcmp(argv[i], "test") == 0) {
|
||||
@@ -9784,6 +9866,17 @@ int main(int argc, char ** argv) {
|
||||
usage(argv);
|
||||
return 1;
|
||||
}
|
||||
} else if (strcmp(argv[i], "-j") == 0) {
|
||||
if (i + 1 < argc) {
|
||||
parallel_workers = atoi(argv[++i]);
|
||||
if (parallel_workers < 1) {
|
||||
usage(argv);
|
||||
return 1;
|
||||
}
|
||||
} else {
|
||||
usage(argv);
|
||||
return 1;
|
||||
}
|
||||
} else {
|
||||
usage(argv);
|
||||
return 1;
|
||||
@@ -9836,7 +9929,7 @@ int main(int argc, char ** argv) {
|
||||
false, "", ggml_backend_dev_description(dev),
|
||||
total / 1024 / 1024, free / 1024 / 1024, true));
|
||||
|
||||
bool ok = test_backend(backend, mode, op_names_filter, params_filter, output_printer.get(), test_file_path);
|
||||
bool ok = test_backend(backend, dev, mode, op_names_filter, params_filter, output_printer.get(), test_file_path, parallel_workers);
|
||||
|
||||
if (ok) {
|
||||
n_ok++;
|
||||
|
||||
Reference in New Issue
Block a user