Compare commits

...

13 Commits
b8739 ... b8752

Author SHA1 Message Date
Adrien Gallouët
05b3caaa48 common : add callback interface for download progress (#21735)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-04-10 22:17:00 +02:00
MoonRide303
e62fa13c24 model : make Gemma 4 shared-KV tail attn_k tensors optional on load (#21739) 2026-04-10 21:45:50 +02:00
Rithik Sharma
bfd1f453cb ggml-webgpu: support non-square subgroup matrix configs for Intel GPUs (#21669) 2026-04-10 10:52:38 -07:00
Chen Yuan
e4fed9d08d ggml-webgpu: address quantization precision and backend lifecycle managment (#21521)
* ggml(webgpu): fix the busy-polls in Emscripten  in the waitAny after #20618, and remove the busy webgpu log

* Merge with upstream

* Fix GET_ROWS packed integer NaN when using f16 as memory buffer in shader quants

* Update Unary wgsl EXP and EXPM1 for f16 stability

* Fix GET_ROWS IQ4_XS strcut for NaN f16 canonicalization

* Fix numerical percision for unary sqrt when working with f16

* Fix NaN canonicalization for packed integers using f16

* Update err threshold for binary div ops when using f16

* backend: Keep one Dawn/WebGPU instance alive for the lifetime of the static backend

* clean: uncomment existing code logs

* clean: clean the unncessary debug info

* Refactor and generalize dequant helpers

* Remove deprecated quant structs

* Refactor shader defines to reduce repetition

* Remove error override for F16 type

* fix: fix the accidential removal of the proper initialization of ctx

* clean: clean legacy and format code

* fix: did not modify tests ops

---------

Co-authored-by: Jeremy J. Hartmann <jeremy@mtion.tv>
2026-04-10 10:52:01 -07:00
Adrien Gallouët
5dd102539b server : ignore --alias when using --models-preset (#21380)
I'm not sure what the purpose of keeping `--alias` was when using
`--models-preset`, but the result is really weird, as shown in the
following logs:

    $ build/bin/llama-server --models-preset preset.ini --alias "Gemma 4 E4B UD Q8_K_XL"
    ...
    init: using 31 threads for HTTP server
    srv   load_models: Loaded 2 cached model presets
    srv   load_models: Loaded 1 custom model presets from preset.ini
    main: failed to initialize router models: alias 'Gemma 4 E4B UD Q8_K_XL' for model 'angt/test-split-model-stories260K:F32' conflicts with existing model name

So I propose to simply ignore `--alias` too in this case. With this
commit, the server starts in routing mode correctly.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-04-10 17:42:56 +02:00
Adrien Gallouët
fb38d6f278 common : fix when loading a cached HF models with unavailable API (#21670)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-04-10 16:37:46 +02:00
Johannes Gäßler
0893f50f2d common: mark --split-mode tensor as experimental (#21684) 2026-04-10 12:27:27 +02:00
Aleksander Grygier
f989a6e39e webui: Static build output improvements (#21667)
* refactor: Build improvements

* chore: Formatting + package lock update
2026-04-10 11:49:47 +02:00
Berk Idem
d7ff074c87 common : enable reasoning budget sampler for gemma4 (#21697)
* fix: enable reasoning budget sampler for gemma4

Add thinking_start_tag and thinking_end_tag to
common_chat_params_init_gemma4(). Without these, the reasoning
budget sampler never activates for gemma4.

Make the newline after "thought" optional in the PEG parser to
handle budget=0 (sampler forces end tag before the newline).

Add test case for empty thinking block.

Fixes #21487

* use p.space() instead of p.optional(p.literal("\n")) in gemma4 thought parser
2026-04-10 11:49:14 +02:00
Belem Zhang
3f8752b559 docs : fix broken link to ggml-openvino in OPENVINO.md (#21709) 2026-04-10 09:50:08 +02:00
Jeff Bolz
7b69125331 vulkan: Support Q1_0 (#21539)
* vulkan: Support Q1_0

* use get_dm
2026-04-10 08:35:27 +02:00
Adrien Gallouët
e095a482a0 common : add fluidity to the progress bar (#21671)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-04-10 08:24:53 +02:00
Aman Gupta
e34f042154 CUDA: fuse muls (#21665) 2026-04-10 10:24:09 +08:00
38 changed files with 13660 additions and 743 deletions

View File

@@ -291,14 +291,16 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
hf_tag = "default";
}
const bool offline = params.offline;
std::string model_endpoint = get_model_endpoint();
auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini";
// prepare local path for caching
auto preset_fname = clean_file_name(hf_repo + "_preset.ini");
auto preset_path = fs_get_cache_file(preset_fname);
const int status = common_download_file_single(preset_url, preset_path, params.hf_token, offline);
common_download_opts opts;
opts.bearer_token = params.hf_token;
opts.offline = params.offline;
const int status = common_download_file_single(preset_url, preset_path, opts);
const bool has_preset = status >= 200 && status < 400;
// remote preset is optional, so we don't error out if not found
@@ -341,10 +343,10 @@ static handle_model_result common_params_handle_model(struct common_params_model
model.hf_file = model.path;
model.path = "";
}
common_download_model_opts opts;
opts.download_mmproj = true;
common_download_opts opts;
opts.bearer_token = bearer_token;
opts.offline = offline;
auto download_result = common_download_model(model, bearer_token, opts);
auto download_result = common_download_model(model, opts, true);
if (download_result.model_path.empty()) {
LOG_ERR("error: failed to download model from Hugging Face\n");
@@ -365,9 +367,10 @@ static handle_model_result common_params_handle_model(struct common_params_model
model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
}
common_download_model_opts opts;
common_download_opts opts;
opts.bearer_token = bearer_token;
opts.offline = offline;
auto download_result = common_download_model(model, bearer_token, opts);
auto download_result = common_download_model(model, opts);
if (download_result.model_path.empty()) {
LOG_ERR("error: failed to download model from %s\n", model.url.c_str());
exit(1);
@@ -2353,7 +2356,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"- none: use one GPU only\n"
"- layer (default): split layers and KV across GPUs (pipelined)\n"
"- row: split weight across GPUs by rows (parallelized)\n"
"- tensor: split weights and KV across GPUs (parallelized)",
"- tensor: split weights and KV across GPUs (parallelized, EXPERIMENTAL)",
[](common_params & params, const std::string & value) {
if (value == "none") {
params.split_mode = LLAMA_SPLIT_MODE_NONE;

View File

@@ -1083,7 +1083,9 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
data.supports_thinking = true;
data.supports_thinking = true;
data.thinking_start_tag = "<|channel>thought";
data.thinking_end_tag = "<channel|>";
data.preserved_tokens = {
"<|channel>",
@@ -1102,9 +1104,9 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
auto start = p.rule("start", p.prefix(inputs.generation_prompt, "<|channel>"));
if (extract_reasoning) {
p.rule("thought", p.literal("<|channel>thought\n") + p.reasoning(p.until("<channel|>")) + p.literal("<channel|>"));
p.rule("thought", p.literal("<|channel>thought") + p.space() + p.reasoning(p.until("<channel|>")) + p.literal("<channel|>"));
} else {
p.rule("thought", p.content(p.literal("<|channel>thought\n") + p.until("<channel|>") + p.literal("<channel|>")));
p.rule("thought", p.content(p.literal("<|channel>thought") + p.space() + p.until("<channel|>") + p.literal("<channel|>")));
}
auto thought = (p.peek(p.literal("<|channel>")) + p.ref("thought")) | p.negate(p.literal("<|channel>"));

View File

@@ -114,7 +114,7 @@ std::pair<std::string, std::string> common_download_split_repo_tag(const std::st
return {hf_repo, tag};
}
class ProgressBar {
class ProgressBar : public common_download_callback {
static inline std::mutex mutex;
static inline std::map<const ProgressBar *, int> lines;
static inline int max_line = 0;
@@ -138,7 +138,11 @@ class ProgressBar {
}
public:
ProgressBar(const std::string & url = "") : filename(url) {
ProgressBar() = default;
void on_start(const common_download_progress & p) override {
filename = p.url;
if (auto pos = filename.rfind('/'); pos != std::string::npos) {
filename = filename.substr(pos + 1);
}
@@ -156,13 +160,13 @@ public:
}
}
~ProgressBar() {
void on_done(const common_download_progress &, bool) override {
std::lock_guard<std::mutex> lock(mutex);
cleanup(this);
}
void update(size_t current, size_t total) {
if (!total || !is_output_a_tty()) {
void on_update(const common_download_progress & p) override {
if (!p.total || !is_output_a_tty()) {
return;
}
@@ -174,17 +178,17 @@ public:
}
int lines_up = max_line - lines[this];
size_t bar = 55 - len;
size_t pct = (100 * current) / total;
size_t pos = (bar * current) / total;
size_t bar = (55 - len) * 2;
size_t pct = (100 * p.downloaded) / p.total;
size_t pos = (bar * p.downloaded) / p.total;
if (lines_up > 0) {
std::cout << "\033[" << lines_up << "A";
}
std::cout << '\r' << "Downloading " << filename << " ";
for (size_t i = 0; i < bar; ++i) {
std::cout << (i < pos ? "" : " ");
for (size_t i = 0; i < bar; i += 2) {
std::cout << (i + 1 < pos ? "" : (i < pos ? "" : " "));
}
std::cout << std::setw(4) << pct << "%\033[K";
@@ -193,7 +197,7 @@ public:
}
std::cout << '\r' << std::flush;
if (current == total) {
if (p.downloaded == p.total) {
cleanup(this);
}
}
@@ -206,8 +210,8 @@ static bool common_pull_file(httplib::Client & cli,
const std::string & resolve_path,
const std::string & path_tmp,
bool supports_ranges,
size_t existing_size,
size_t & total_size) {
common_download_progress & p,
common_download_callback * callback) {
std::ofstream ofs(path_tmp, std::ios::binary | std::ios::app);
if (!ofs.is_open()) {
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_tmp.c_str());
@@ -215,29 +219,27 @@ static bool common_pull_file(httplib::Client & cli,
}
httplib::Headers headers;
if (supports_ranges && existing_size > 0) {
headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-");
if (supports_ranges && p.downloaded > 0) {
headers.emplace("Range", "bytes=" + std::to_string(p.downloaded) + "-");
}
const char * func = __func__; // avoid __func__ inside a lambda
size_t downloaded = existing_size;
size_t progress_step = 0;
ProgressBar bar(resolve_path);
auto res = cli.Get(resolve_path, headers,
[&](const httplib::Response &response) {
if (existing_size > 0 && response.status != 206) {
if (p.downloaded > 0 && response.status != 206) {
LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", func, response.status);
return false;
}
if (existing_size == 0 && response.status != 200) {
if (p.downloaded == 0 && response.status != 200) {
LOG_WRN("%s: download received non-successful status code: %d\n", func, response.status);
return false;
}
if (total_size == 0 && response.has_header("Content-Length")) {
if (p.total == 0 && response.has_header("Content-Length")) {
try {
size_t content_length = std::stoull(response.get_header_value("Content-Length"));
total_size = existing_size + content_length;
p.total = p.downloaded + content_length;
} catch (const std::exception &e) {
LOG_WRN("%s: invalid Content-Length header: %s\n", func, e.what());
}
@@ -250,11 +252,13 @@ static bool common_pull_file(httplib::Client & cli,
LOG_ERR("%s: error writing to file: %s\n", func, path_tmp.c_str());
return false;
}
downloaded += len;
p.downloaded += len;
progress_step += len;
if (progress_step >= total_size / 1000 || downloaded == total_size) {
bar.update(downloaded, total_size);
if (progress_step >= p.total / 1000 || p.downloaded == p.total) {
if (callback) {
callback->on_update(p);
}
progress_step = 0;
}
return true;
@@ -275,28 +279,13 @@ static bool common_pull_file(httplib::Client & cli,
// download one single file from remote URL to local path
// returns status code or -1 on error
static int common_download_file_single_online(const std::string & url,
const std::string & path,
const std::string & bearer_token,
const common_header_list & custom_headers,
bool skip_etag = false) {
static int common_download_file_single_online(const std::string & url,
const std::string & path,
const common_download_opts & opts,
bool skip_etag) {
static const int max_attempts = 3;
static const int retry_delay_seconds = 2;
auto [cli, parts] = common_http_client(url);
httplib::Headers headers;
for (const auto & h : custom_headers) {
headers.emplace(h.first, h.second);
}
if (headers.find("User-Agent") == headers.end()) {
headers.emplace("User-Agent", "llama-cpp/" + build_info);
}
if (!bearer_token.empty()) {
headers.emplace("Authorization", "Bearer " + bearer_token);
}
cli.set_default_headers(headers);
const bool file_exists = std::filesystem::exists(path);
if (file_exists && skip_etag) {
@@ -304,6 +293,20 @@ static int common_download_file_single_online(const std::string & url,
return 304; // 304 Not Modified - fake cached response
}
auto [cli, parts] = common_http_client(url);
httplib::Headers headers;
for (const auto & h : opts.headers) {
headers.emplace(h.first, h.second);
}
if (headers.find("User-Agent") == headers.end()) {
headers.emplace("User-Agent", "llama-cpp/" + build_info);
}
if (!opts.bearer_token.empty()) {
headers.emplace("Authorization", "Bearer " + opts.bearer_token);
}
cli.set_default_headers(headers);
std::string last_etag;
if (file_exists) {
last_etag = read_etag(path);
@@ -326,10 +329,11 @@ static int common_download_file_single_online(const std::string & url,
etag = head->get_header_value("ETag");
}
size_t total_size = 0;
common_download_progress p;
p.url = url;
if (head->has_header("Content-Length")) {
try {
total_size = std::stoull(head->get_header_value("Content-Length"));
p.total = std::stoull(head->get_header_value("Content-Length"));
} catch (const std::exception& e) {
LOG_WRN("%s: invalid Content-Length in HEAD response: %s\n", __func__, e.what());
}
@@ -357,13 +361,17 @@ static int common_download_file_single_online(const std::string & url,
{ // silent
std::error_code ec;
std::filesystem::path p(path);
std::filesystem::create_directories(p.parent_path(), ec);
std::filesystem::create_directories(std::filesystem::path(path).parent_path(), ec);
}
bool success = false;
const std::string path_temporary = path + ".downloadInProgress";
int delay = retry_delay_seconds;
if (opts.callback) {
opts.callback->on_start(p);
}
for (int i = 0; i < max_attempts; ++i) {
if (i) {
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
@@ -378,28 +386,38 @@ static int common_download_file_single_online(const std::string & url,
existing_size = std::filesystem::file_size(path_temporary);
} else if (remove(path_temporary.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
return -1;
break;
}
}
p.downloaded = existing_size;
LOG_DBG("%s: downloading from %s to %s (etag:%s)...\n",
__func__, common_http_show_masked_url(parts).c_str(),
path_temporary.c_str(), etag.c_str());
if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size)) {
if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, p, opts.callback)) {
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return -1;
break;
}
if (!etag.empty() && !skip_etag) {
write_etag(path, etag);
}
return head->status;
success = true;
break;
}
}
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
return -1; // max attempts reached
if (opts.callback) {
opts.callback->on_done(p, success);
}
if (!success) {
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
return -1; // max attempts reached
}
return head->status;
}
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
@@ -438,12 +456,15 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
int common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
bool offline,
const common_header_list & headers,
const common_download_opts & opts,
bool skip_etag) {
if (!offline) {
return common_download_file_single_online(url, path, bearer_token, headers, skip_etag);
if (!opts.offline) {
ProgressBar tty_cb;
common_download_opts online_opts = opts;
if (!online_opts.callback) {
online_opts.callback = &tty_cb;
}
return common_download_file_single_online(url, path, online_opts, skip_etag);
}
if (!std::filesystem::exists(path)) {
@@ -452,6 +473,16 @@ int common_download_file_single(const std::string & url,
}
LOG_DBG("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
// notify the callback that the file was cached
if (opts.callback) {
common_download_progress p;
p.url = url;
p.cached = true;
opts.callback->on_start(p);
opts.callback->on_done(p, true);
}
return 304; // Not Modified - fake cached response
}
@@ -631,16 +662,16 @@ struct hf_plan {
hf_cache::hf_file mmproj;
};
static hf_plan get_hf_plan(const common_params_model & model,
const std::string & token,
const common_download_model_opts & opts) {
static hf_plan get_hf_plan(const common_params_model & model,
const common_download_opts & opts,
bool download_mmproj) {
hf_plan plan;
hf_cache::hf_files all;
auto [repo, tag] = common_download_split_repo_tag(model.hf_repo);
if (!opts.offline) {
all = hf_cache::get_repo_files(repo, token);
all = hf_cache::get_repo_files(repo, opts.bearer_token);
}
if (all.empty()) {
all = hf_cache::get_cached_files(repo);
@@ -675,7 +706,7 @@ static hf_plan get_hf_plan(const common_params_model & model,
plan.primary = primary;
plan.model_files = get_split_files(all, primary);
if (opts.download_mmproj) {
if (download_mmproj) {
plan.mmproj = find_best_mmproj(all, primary.path);
}
@@ -710,10 +741,9 @@ static std::vector<download_task> get_url_tasks(const common_params_model & mode
return tasks;
}
common_download_model_result common_download_model(const common_params_model & model,
const std::string & bearer_token,
const common_download_model_opts & opts,
const common_header_list & headers) {
common_download_model_result common_download_model(const common_params_model & model,
const common_download_opts & opts,
bool download_mmproj) {
common_download_model_result result;
std::vector<download_task> tasks;
hf_plan hf;
@@ -721,7 +751,7 @@ common_download_model_result common_download_model(const common_params_model
bool is_hf = !model.hf_repo.empty();
if (is_hf) {
hf = get_hf_plan(model, bearer_token, opts);
hf = get_hf_plan(model, opts, download_mmproj);
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
@@ -742,8 +772,8 @@ common_download_model_result common_download_model(const common_params_model
std::vector<std::future<bool>> futures;
for (const auto & task : tasks) {
futures.push_back(std::async(std::launch::async,
[&task, &bearer_token, offline = opts.offline, &headers, is_hf]() {
int status = common_download_file_single(task.url, task.path, bearer_token, offline, headers, is_hf);
[&task, &opts, is_hf]() {
int status = common_download_file_single(task.url, task.path, opts, is_hf);
return is_http_status_ok(status);
}
));
@@ -879,7 +909,9 @@ std::string common_docker_resolve_model(const std::string & docker) {
std::string local_path = fs_get_cache_file(model_filename);
const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
const int http_status = common_download_file_single(blob_url, local_path, token, false, {});
common_download_opts opts;
opts.bearer_token = token;
const int http_status = common_download_file_single(blob_url, local_path, opts);
if (!is_http_status_ok(http_status)) {
throw std::runtime_error("Failed to download Docker Model");
}

View File

@@ -8,6 +8,21 @@ struct common_params_model;
using common_header = std::pair<std::string, std::string>;
using common_header_list = std::vector<common_header>;
struct common_download_progress {
std::string url;
size_t downloaded = 0;
size_t total = 0;
bool cached = false;
};
class common_download_callback {
public:
virtual ~common_download_callback() = default;
virtual void on_start(const common_download_progress & p) = 0;
virtual void on_update(const common_download_progress & p) = 0;
virtual void on_done(const common_download_progress & p, bool ok) = 0;
};
struct common_remote_params {
common_header_list headers;
long timeout = 0; // in seconds, 0 means no timeout
@@ -31,10 +46,12 @@ struct common_cached_model_info {
}
};
// Options for common_download_model
struct common_download_model_opts {
bool download_mmproj = false;
bool offline = false;
// Options for common_download_model and common_download_file_single
struct common_download_opts {
std::string bearer_token;
common_header_list headers;
bool offline = false;
common_download_callback * callback = nullptr;
};
// Result of common_download_model
@@ -69,9 +86,8 @@ struct common_download_model_result {
// returns result with model_path and mmproj_path (empty on failure)
common_download_model_result common_download_model(
const common_params_model & model,
const std::string & bearer_token,
const common_download_model_opts & opts = {},
const common_header_list & headers = {}
const common_download_opts & opts = {},
bool download_mmproj = false
);
// returns list of cached models
@@ -82,9 +98,7 @@ std::vector<common_cached_model_info> common_list_cached_models();
// skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash)
int common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
bool offline,
const common_header_list & headers = {},
const common_download_opts & opts = {},
bool skip_etag = false);
// resolve and download model from Docker registry

View File

@@ -3,7 +3,7 @@
> [!NOTE]
> Performance and memory optimizations, accuracy validation, broader quantization coverage, broader operator and model support are work in progress.
[OpenVINO](https://docs.openvino.ai/) is an open-source toolkit for optimizing and deploying high-performance AI inference, specifically designed for Intel hardware, including CPUs, GPUs, and NPUs, in the cloud, on-premises, and on the edge. [OpenVINO backend for llama.cpp](../../src/ggml-openvino) enables hardware-accelerated inference on **Intel® CPUs, GPUs, and NPUs** while remaining compatible with the existing **GGUF model ecosystem**. The backend translates GGML compute graphs into OpenVINO graphs and leverages graph compilation, kernel fusion, and device-specific optimizations to improve inference performance on supported Intel hardware.
[OpenVINO](https://docs.openvino.ai/) is an open-source toolkit for optimizing and deploying high-performance AI inference, specifically designed for Intel hardware, including CPUs, GPUs, and NPUs, in the cloud, on-premises, and on the edge. [OpenVINO backend for llama.cpp](../../ggml/src/ggml-openvino) enables hardware-accelerated inference on **Intel® CPUs, GPUs, and NPUs** while remaining compatible with the existing **GGUF model ecosystem**. The backend translates GGML compute graphs into OpenVINO graphs and leverages graph compilation, kernel fusion, and device-specific optimizations to improve inference performance on supported Intel hardware.
The OpenVINO backend is implemented in `ggml/src/ggml-openvino` and provides a translation layer for core GGML operations. The OpenVINO backend replaces the standard GGML graph execution path with Intel's OpenVINO inference engine. This approach allows the same GGUF model file to run on Intel CPUs, Intel GPUs (integrated and discrete), and Intel NPUs without changes to the model or the rest of the llama.cpp stack. When a `ggml_cgraph` is dispatched to OpenVINO backend, it:

View File

@@ -472,6 +472,36 @@ void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst,
}
}
void ggml_cuda_op_fused_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);
switch (n_fuse) {
case 2:
ggml_cuda_op_fused_binbcast_impl<op_mul, 2>(ctx, dst);
break;
case 3:
ggml_cuda_op_fused_binbcast_impl<op_mul, 3>(ctx, dst);
break;
case 4:
ggml_cuda_op_fused_binbcast_impl<op_mul, 4>(ctx, dst);
break;
case 5:
ggml_cuda_op_fused_binbcast_impl<op_mul, 5>(ctx, dst);
break;
case 6:
ggml_cuda_op_fused_binbcast_impl<op_mul, 6>(ctx, dst);
break;
case 7:
ggml_cuda_op_fused_binbcast_impl<op_mul, 7>(ctx, dst);
break;
case 8:
ggml_cuda_op_fused_binbcast_impl<op_mul, 8>(ctx, dst);
break;
default:
GGML_ASSERT(false && "Unsupported n_fuse value");
}
}
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];

View File

@@ -9,3 +9,4 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);
void ggml_cuda_op_fused_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);

View File

@@ -3758,10 +3758,10 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
continue;
}
if (node->op == GGML_OP_ADD) {
if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) {
int n_fuse = 0;
ggml_op ops[8];
std::fill(ops, ops + 8, GGML_OP_ADD);
std::fill(ops, ops + 8, node->op);
for (; n_fuse <= 6; ++n_fuse){
if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
@@ -3778,13 +3778,17 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
n_fuse++;
if (n_fuse > 1) {
ggml_tensor fused_add_node;
memcpy(&fused_add_node, node, sizeof(ggml_tensor));
ggml_tensor fused_node;
memcpy(&fused_node, node, sizeof(ggml_tensor));
for (int j = 0; j < n_fuse - 1; ++j) {
fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
}
fused_node.data = cgraph->nodes[i + n_fuse - 1]->data;
if (node->op == GGML_OP_ADD) {
ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse);
} else {
ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse);
}
fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data;
ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse);
i += n_fuse - 1;
continue;

View File

@@ -3512,6 +3512,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
}
#endif
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q1_0], matmul_q1_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -3541,6 +3542,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
}
#endif
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
@@ -3602,6 +3604,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif
if (device->coopmat_acc_f16_support) {
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -3624,6 +3627,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
} else {
CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -3658,6 +3662,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
#endif
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
@@ -3721,6 +3726,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
@@ -3767,6 +3773,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
@@ -3811,6 +3818,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
@@ -3884,6 +3892,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
@@ -3928,6 +3937,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0].f32acc, matmul_id_subgroup_q1_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
@@ -3954,6 +3964,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0].f32acc, matmul_id_q1_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
@@ -4051,6 +4062,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q1_0][i], "mul_mat_vec_q1_0_f32_f32", arr_dmmv_q1_0_f32_f32_len[reduc], arr_dmmv_q1_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[reduc], arr_dmmv_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
@@ -4075,6 +4087,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q1_0][i], "mul_mat_vec_q1_0_f16_f32", arr_dmmv_q1_0_f16_f32_len[reduc], arr_dmmv_q1_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[reduc], arr_dmmv_q4_1_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
@@ -4125,6 +4138,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {wg_size_subgroup, 1}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q1_0], "mul_mat_vec_id_q1_0_f32", arr_dmmv_id_q1_0_f32_f32_len[reduc], arr_dmmv_id_q1_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", arr_dmmv_id_q4_1_f32_f32_len[reduc], arr_dmmv_id_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", arr_dmmv_id_q5_0_f32_f32_len[reduc], arr_dmmv_id_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
@@ -4179,6 +4193,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
// dequant shaders
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q1_0], "dequant_q1_0", dequant_q1_0_len, dequant_q1_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 8, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -4204,6 +4219,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q1_0], "get_rows_q1_0", get_rows_q1_0_len, get_rows_q1_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -4229,6 +4245,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q1_0], "get_rows_q1_0_f32", get_rows_q1_0_f32_len, get_rows_q1_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -4310,6 +4327,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_rte_len, cpy_f32_q1_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
@@ -4317,6 +4335,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
} else {
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
@@ -4329,6 +4348,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## rte ## _len, set_rows_f32 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## rte ## _len, set_rows_f16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q1_0], "set_rows_q1_0" #itype, set_rows_q1_0 ## itype ## rte ## _len, set_rows_q1_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
@@ -4346,6 +4366,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
#undef SET_ROWS
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q1_0], "cpy_q1_0_f32", cpy_q1_0_f32_len, cpy_q1_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q1_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
@@ -6022,6 +6043,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
VK_LOG_DEBUG("ggml_vk_get_to_fp16()");
switch (type) {
case GGML_TYPE_F32:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -6093,6 +6115,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
}
switch (src0_type) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -6158,6 +6181,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -6248,6 +6272,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
switch (src0_type) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -6316,6 +6341,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -7263,6 +7289,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
}
if (src->type == GGML_TYPE_F32) {
switch (to) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -7277,6 +7304,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
if (to == GGML_TYPE_F32) {
switch (src->type) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -15269,6 +15297,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -15383,6 +15412,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -15415,6 +15445,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -15438,6 +15469,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -15452,6 +15484,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (src1_type == GGML_TYPE_F32) {
switch (src0_type) {
case GGML_TYPE_F16:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:

View File

@@ -184,6 +184,31 @@ void quantize(uint dst_idx, uint src_idx)
}
#endif
#if defined(DATA_A_Q1_0)
void quantize(uint dst_idx, uint src_idx)
{
float sum_abs = 0.0;
[[unroll]] for (int j = 0; j < QUANT_K_Q1_0; j++) {
sum_abs += abs(data_s[src_idx + j]);
}
const float d = sum_abs / QUANT_K_Q1_0;
data_q[dst_idx].d = float16_t(d);
[[unroll]] for (int j = 0; j < QUANT_K_Q1_0 / 8; ++j) {
data_q[dst_idx].qs[j] = uint8_t(0);
}
[[unroll]] for (int j = 0; j < QUANT_K_Q1_0; ++j) {
if (data_s[src_idx + j] >= 0.0) {
data_q[dst_idx].qs[j / 8] |= uint8_t(1 << (j % 8));
}
}
}
#endif
#if defined(DATA_A_IQ4_NL)
uint best_index(float x) {
if (x <= kvalues_iq4nl[0]) return 0;

View File

@@ -87,6 +87,23 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
}
#endif
#if defined(DATA_A_Q1_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint bits = uint(data_a[a_offset + ib].qs[iqs / 8u]) >> (iqs % 8u);
return vec2(
(bits & 1u) != 0u ? 1.0f : -1.0f,
(bits & 2u) != 0u ? 1.0f : -1.0f);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint bits = uint(data_a[a_offset + ib].qs[iqs / 8u]) >> (iqs % 8u);
return vec4(
(bits & 1u) != 0u ? 1.0f : -1.0f,
(bits & 2u) != 0u ? 1.0f : -1.0f,
(bits & 4u) != 0u ? 1.0f : -1.0f,
(bits & 8u) != 0u ? 1.0f : -1.0f);
}
#endif
#if defined(DATA_A_IQ1_S)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
@@ -454,6 +471,13 @@ vec2 get_dm(uint ib, uint a_offset) {
}
#endif
#if defined(DATA_A_Q1_0)
vec2 get_dm(uint ib, uint a_offset) {
const float d = float(data_a[a_offset + ib].d);
return vec2(d, 0);
}
#endif
#if defined(DATA_A_MXFP4)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0);

View File

@@ -13,6 +13,18 @@ float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2],
return vf16[idx];
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ1_0 {
block_q1_0 block;
};
float16_t dequantFuncQ1_0(const in decodeBufQ1_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint bit = (uint(bl.block.qs[(idx & 0x78) >> 3]) >> (idx & 0x7)) & 1u;
return bit != 0u ? d : -d;
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
block_q4_0_packed16 block;
};
@@ -685,7 +697,9 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
}
#endif
#if defined(DATA_A_Q4_0)
#if defined(DATA_A_Q1_0)
#define dequantFuncA dequantFuncQ1_0
#elif defined(DATA_A_Q4_0)
#define dequantFuncA dequantFuncQ4_0
#elif defined(DATA_A_Q4_1)
#define dequantFuncA dequantFuncQ4_1

View File

@@ -0,0 +1,29 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q1_0 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid / 4;
const uint ir = tid % 4;
const uint ib = 4*i + ir;
if (ib >= p.nel / 128) {
return;
}
const uint b_idx = 512*i + 128*ir + 8*il;
const float d = float(data_a[ib].d);
const uint bits = uint(data_a[ib].qs[il]);
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l] = D_TYPE((bits & (1u << l)) != 0u ? d : -d);
}
}

View File

@@ -130,6 +130,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy);
buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw);
#elif defined(DATA_A_Q1_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
const uint ib = idx / 16;
const uint iqs = idx & 0xfu;
const float d = float(data_a[ib].d);
const uint bits = uint(data_a[ib].qs[iqs]);
buf_a[buf_idx ] = FLOAT_TYPEV2((bits & 0x01u) != 0u ? d : -d, (bits & 0x02u) != 0u ? d : -d);
buf_a[buf_idx + 1] = FLOAT_TYPEV2((bits & 0x04u) != 0u ? d : -d, (bits & 0x08u) != 0u ? d : -d);
buf_a[buf_idx + 2] = FLOAT_TYPEV2((bits & 0x10u) != 0u ? d : -d, (bits & 0x20u) != 0u ? d : -d);
buf_a[buf_idx + 3] = FLOAT_TYPEV2((bits & 0x40u) != 0u ? d : -d, (bits & 0x80u) != 0u ? d : -d);
#elif defined(DATA_A_Q2_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

View File

@@ -188,6 +188,22 @@ struct block_q8_0_packed16
#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q1_0 128
#define QUANT_R_Q1_0 1
struct block_q1_0
{
float16_t d;
uint8_t qs[QUANT_K_Q1_0 / 8];
};
#if defined(DATA_A_Q1_0)
#define QUANT_K QUANT_K_Q1_0
#define QUANT_R QUANT_R_Q1_0
#define QUANT_AUXF 1
#define A_TYPE block_q1_0
#endif
#define QUANT_K_Q8_1 32
#define QUANT_R_Q8_1 1

View File

@@ -45,6 +45,7 @@ std::string target_cpp = "";
const std::vector<std::string> type_names = {
"f32",
"f16",
"q1_0",
"q4_0",
"q4_1",
"q5_0",
@@ -553,7 +554,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
for (const auto& tname : type_names) {
std::string load_vec_quant = "2";
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
if ((tname == "q1_0") || (tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
load_vec_quant = "8";
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4"))
load_vec_quant = "4";
@@ -758,13 +759,13 @@ void process_shaders() {
string_to_spv("cpy_transpose_16", "copy_transpose.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}});
string_to_spv("cpy_transpose_32", "copy_transpose.comp", {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}});
for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
for (std::string t : {"q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}
for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
for (std::string t : {"f32", "f16", "bf16", "q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

View File

@@ -1115,6 +1115,32 @@ class ggml_webgpu_shader_lib {
std::string type_upper = type_str;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
switch (key.src_type)
{
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL:
{
// Quantized types using u32 buffers for portability.
defines.push_back("SRC_TYPE=u32");
defines.push_back("U32_DEQUANT_HELPERS");
break;
}
default:
{
defines.push_back(std::string("SRC_TYPE=") + type_str);
}
}
defines.push_back("BYTE_HELPERS");
defines.push_back(type_upper + "_T");
defines.push_back(type_upper);
@@ -1125,7 +1151,6 @@ class ggml_webgpu_shader_lib {
variant += "_";
variant += type_str;
defines.push_back(std::string("SRC_TYPE=") + type_str);
defines.push_back("DST_TYPE=f32");
if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
@@ -1593,11 +1618,35 @@ class ggml_webgpu_shader_lib {
break;
default:
{
// quantized types
std::string type_upper = src0_name;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
switch (context.src0->type)
{
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL:
{
// Quantized types using u32 buffers for portability.
defines.push_back("SRC0_TYPE=u32");
defines.push_back("U32_DEQUANT_HELPERS");
break;
}
default:
{
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
}
}
defines.push_back("BYTE_HELPERS");
defines.push_back(type_upper + "_T");
defines.push_back(type_upper);

View File

@@ -97,6 +97,14 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
/* End Constants */
static inline wgpu::CallbackMode ggml_webgpu_callback_mode() {
#ifdef __EMSCRIPTEN__
return wgpu::CallbackMode::AllowProcessEvents;
#else
return wgpu::CallbackMode::AllowSpontaneous;
#endif
}
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to
// their locations.
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
@@ -474,7 +482,7 @@ static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) {
const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
ggml_webgpu_callback_mode(),
[&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
callback_status = status;
callback_message = std::string(message);
@@ -494,7 +502,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
std::string callback_message;
const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
buffer.MapAsync(mode, offset, size, ggml_webgpu_callback_mode(),
[&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) {
callback_status = status;
callback_message = std::string(message);
@@ -526,7 +534,11 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
wgpu::CommandBuffer commands = encoder.Finish();
ctx->queue.Submit(1, &commands);
ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
if (!ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0,
ctx->debug_host_buf.GetSize())) {
GGML_LOG_ERROR("ggml_webgpu: Debug buffer map failed\n");
return;
}
const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
std::cout << "debug[0]: " << debug_data[0] << "\n";
ctx->debug_host_buf.Unmap();
@@ -542,7 +554,7 @@ static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context &
auto ts_bufs = command.timestamp_query_bufs;
wgpu::Future f = ts_bufs.host_buf.MapAsync(
wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), ggml_webgpu_callback_mode(),
[ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
if (status != wgpu::MapAsyncStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
@@ -3420,7 +3432,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
ctx->webgpu_global_ctx->instance.WaitAny(
ctx->webgpu_global_ctx->instance.RequestAdapter(
&options, wgpu::CallbackMode::AllowSpontaneous,
&options, ggml_webgpu_callback_mode(),
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
if (status != wgpu::RequestAdapterStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
@@ -3449,13 +3461,15 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
#ifndef __EMSCRIPTEN__
// Only support square f16 matrices of size 8 or 16 for now
// Accept f16 subgroup matrix configurations (square or non-square).
// NVIDIA GPUs typically report square configs (e.g. 16x16x16),
// while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16).
// The shaders are already parameterized to handle any M/N/K dimensions.
bool valid_subgroup_matrix_config = false;
if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
if (config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
@@ -3491,8 +3505,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
dev_desc.requiredFeatures = required_features.data();
dev_desc.requiredFeatureCount = required_features.size();
dev_desc.SetDeviceLostCallback(
wgpu::CallbackMode::AllowSpontaneous,
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
ggml_webgpu_callback_mode(),
[ctx](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
if (reason == wgpu::DeviceLostReason::Destroyed) {
return;
}
@@ -3525,7 +3539,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
ctx->webgpu_global_ctx->instance.WaitAny(
ctx->webgpu_global_ctx->adapter.RequestDevice(
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
&dev_desc, ggml_webgpu_callback_mode(),
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
if (status != wgpu::RequestDeviceStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
@@ -3793,6 +3807,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
break;
}
// Head dimensions must be divisible by subgroup matrix dimensions
if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 ||
src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) {
break;
}
// Head dimensions must fit in workgroup memory with minimum tile sizes
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
const bool has_mask = op->src[3] != nullptr;
@@ -4046,6 +4065,13 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
ctx.name = GGML_WEBGPU_NAME;
ctx.device_count = 0;
// Keep one Dawn/WebGPU instance alive for the lifetime of the static backend
// registry. Recreating it on repeated registry lookups can invalidate
// adapter/device references that are still held by the backend/device layer.
if (ctx.webgpu_global_ctx != nullptr && ctx.webgpu_global_ctx->instance != nullptr) {
return &reg;
}
wgpu::InstanceDescriptor instance_descriptor{};
std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
instance_descriptor.requiredFeatures = instance_features.data();
@@ -4063,11 +4089,11 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct());
ctx.webgpu_global_ctx->instance = std::move(inst);
// Probe for adapter support
wgpu::Adapter adapter;
if (ctx.webgpu_global_ctx->instance != nullptr) {
wgpu::RequestAdapterOptions options = {};
// probe for adapter support
ctx.webgpu_global_ctx->instance.WaitAny(
ctx.webgpu_global_ctx->instance.RequestAdapter(
&options, wgpu::CallbackMode::AllowSpontaneous,

View File

@@ -9,35 +9,43 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
#endif
#ifdef U32_DEQUANT_HELPERS
fn load_src0_u16_at(byte_offset: u32) -> u32 {
let word = src0[byte_offset / 4u];
let shift = (byte_offset & 2u) * 8u;
return (word >> shift) & 0xFFFFu;
fn load_u16_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> u32 {
let word = buf[byte_offset / 4];
let shift = (byte_offset & 0x2) * 8;
return (word >> shift) & 0xFFFF;
}
fn load_src0_u32_at(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 3u) * 8u;
let lo = src0[word_idx];
if (shift == 0u) {
return lo;
}
let hi = src0[word_idx + 1u];
return (lo >> shift) | (hi << (32u - shift));
fn load_u32_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4;
let shift = (byte_offset & 0x3) * 8;
let lo = buf[word_idx];
let hi = buf[word_idx + 1];
let shifted = (lo >> shift) | (hi << (32 - shift));
return select(shifted, lo, shift == 0);
}
fn load_src0_f16_at(byte_offset: u32) -> f16 {
let packed = unpack2x16float(load_src0_u16_at(byte_offset));
fn load_f16_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> f16 {
let packed = unpack2x16float(load_u16_at(buf, byte_offset));
return f16(packed[0]);
}
fn load_f16_as_f32_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> f32 {
let word = buf[byte_offset / 4];
let shift = (byte_offset & 0x2) * 8;
let d_bits = (word >> shift) & 0xFFFF;
return unpack2x16float(d_bits)[0];
}
#endif
#ifdef Q4_0_T
struct q4_0 {
d: f16,
qs: array<f16, 8>
};
#endif
#ifdef Q4_1_T
struct q4_1 {
@@ -47,13 +55,6 @@ struct q4_1 {
};
#endif
#ifdef Q5_0_T
struct q5_0 {
d: f16,
qh: array<f16, 2>,
qs: array<f16, 8>
};
#endif
#ifdef Q5_1_T
struct q5_1 {
@@ -64,12 +65,6 @@ struct q5_1 {
};
#endif
#ifdef Q8_0_T
struct q8_0 {
d: f16,
qs: array<f16, 16>
};
#endif
#ifdef Q8_1_T
struct q8_1 {
@@ -88,14 +83,6 @@ struct q2_K {
};
#endif
#ifdef Q3_K_T
struct q3_K {
hmask: array<f16, 16>,
qs: array<f16, 32>,
scales: array<f16, 6>,
d: f16
};
#endif
#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN)
fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
@@ -132,64 +119,6 @@ struct q5_K {
};
#endif
#ifdef Q6_K_T
struct q6_K {
ql: array<f16, 64>,
qh: array<f16, 32>,
scales: array<f16, 8>,
d: f16
};
#endif
#ifdef IQ2_XXS_T
struct iq2_xxs {
d: f16,
qs: array<f16, 32>
};
#endif
#ifdef IQ2_XS_T
struct iq2_xs {
d: f16,
qs: array<f16, 32>,
scales: array<f16, 4>
};
#endif
#ifdef IQ2_S_T
struct iq2_s {
d: f16,
qs: array<f16, 32>,
qh: array<f16, 4>,
scales: array<f16, 4>
};
#endif
#ifdef IQ3_XXS_T
struct iq3_xxs {
d: f16,
qs: array<f16, 48>
};
#endif
#ifdef IQ3_S_T
struct iq3_s {
d: f16,
qs: array<f16, 32>,
qh: array<f16, 4>,
signs: array<f16, 16>,
scales: array<f16, 2>
};
#endif
#ifdef IQ1_S_T
struct iq1_s {
d: f16,
qs: array<f16, 16>,
qh: array<f16, 8>
};
#endif
#ifdef IQ1_M_T
struct iq1_m {
qs: array<u32, 8>,
@@ -198,17 +127,9 @@ struct iq1_m {
};
#endif
#ifdef IQ4_NL_T
struct iq4_nl {
d: f16,
qs: array<f16, 8>,
};
#endif
#ifdef IQ4_XS_T
struct iq4_xs {
d: f16,
scales_h: f16,
d_scales_h: u32,
scales_l: u32,
qs: array<u32, 32>
};

View File

@@ -369,35 +369,35 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
#endif
for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
let inter_offset = kv_block * SG_MAT_N;
var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
var acc: subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M>>(&inter_shmem, inter_offset, false, KV_TILE);
var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, 0u, false, HEAD_DIM_QK);
var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, 0u, false, HEAD_DIM_QK);
#ifdef KV_DIRECT
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + 0u, true, params.stride_k1);
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + 0u, true, params.stride_k1);
#else
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
#endif
var t: u32 = 1u;
for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) {
let h0 = t * SG_MAT_K;
var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h0, false, HEAD_DIM_QK);
var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h0, false, HEAD_DIM_QK);
#ifdef KV_DIRECT
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h0, true, params.stride_k1);
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h0, true, params.stride_k1);
#else
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
#endif
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
q_cur = q0;
k_cur = k0;
let h1 = (t + 1u) * SG_MAT_K;
var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h1, false, HEAD_DIM_QK);
var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h1, false, HEAD_DIM_QK);
#ifdef KV_DIRECT
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h1, true, params.stride_k1);
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h1, true, params.stride_k1);
#else
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
#endif
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
q_cur = q1g;
@@ -407,11 +407,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
// handle odd tail
if (t < HEAD_DIM_QK / SG_MAT_K) {
let h = t * SG_MAT_K;
var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h, false, HEAD_DIM_QK);
var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h, false, HEAD_DIM_QK);
#ifdef KV_DIRECT
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h, true, params.stride_k1);
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h, true, params.stride_k1);
#else
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
#endif
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
q_cur = qn;
@@ -566,7 +566,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
head_dim_block < HEAD_DIM_V;
head_dim_block += num_subgroups * SG_MAT_N) {
// load O submatrix from shared memory
var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(
var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M>>(
&o_shmem,
head_dim_block,
false,
@@ -574,7 +574,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
);
for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
let p_offset = kv_block * SG_MAT_N;
var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(
&inter_shmem,
p_offset,
false,
@@ -585,7 +585,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
#ifdef KV_DIRECT
let v_block_row = kv_tile + kv_block * SG_MAT_N;
let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block;
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(
&V,
v_global_offset,
false,
@@ -593,7 +593,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
);
#else
let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V;
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(
&kv_shmem,
v_block_offset + head_dim_block,
false,

View File

@@ -27,17 +27,18 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef Q4_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q4_0 = src[src_base + offset];
let d = f32(block_q4_0.d);
for (var j: u32 = 0; j < 4; j++) {
let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1]));
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
for (var j: u32 = 0u; j < 4; j++) {
let q_byte_offset = block_byte_base + 2 + j * 4;
let q_packed = load_u32_at(&src, q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
let q_lo = (f32(q_byte & 0xF) - 8.0f) * d;
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
let q_lo = (f32(q_byte & 0xFu) - 8.0) * d;
let dst_offset = dst_base + offset * 32 + j * 4 + k;
dst[dst_offset] = q_lo;
dst[dst_offset + 16] = q_hi;
dst[dst_offset + 16u] = q_hi;
}
}
}
@@ -64,17 +65,22 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef Q5_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q5_0 = src[src_base + offset];
let d = f32(block_q5_0.d);
let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));
let block_byte_base = (src_base + offset) * 22; // Block stride: 22 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let qh_packed = load_u32_at(&src, block_byte_base + 2);
for (var j: u32 = 0; j < 4; j++) {
let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));
let q_byte_offset = block_byte_base + 6 + j * 4;
let q_packed = load_u32_at(&src, q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;
let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
let dst_offset = dst_base + offset * 32 + j * 4 + k;
dst[dst_offset] = q_lo;
dst[dst_offset + 16] = q_hi;
@@ -106,14 +112,15 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef Q8_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q8_0 = src[src_base + offset];
let d = f32(block_q8_0.d);
for (var j: u32 = 0; j < 8; j++) {
let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1]));
for (var k: u32 = 0; k < 4; k++) {
let block_byte_base = (src_base + offset) * 34; // Block stride: 34 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
for (var j: u32 = 0u; j < 8u; j++) {
let q_byte_offset = block_byte_base + 2u + j * 4u;
let q_packed = load_u32_at(&src, q_byte_offset);
for (var k: u32 = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
let dst_offset = dst_base + offset * 32 + j * 4 + k;
let dst_offset = dst_base + offset * 32u + j * 4u + k;
dst[dst_offset] = q_val;
}
}
@@ -152,36 +159,42 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef Q3_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
// extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
// and 2-bits from the last 4 bytes
// Bytes 108-109: f16 scale 'd'
let d = load_f16_as_f32_at(&src, block_byte_base + 108);
// Bytes 96-107: 12 bytes of scales (3 u32s)
let kmask1: u32 = 0x03030303;
let kmask2: u32 = 0x0f0f0f0f;
var scale_vals: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
}
scale_vals[0] = load_u32_at(&src, block_byte_base + 96);
scale_vals[1] = load_u32_at(&src, block_byte_base + 100);
scale_vals[2] = load_u32_at(&src, block_byte_base + 104);
var tmp: u32 = scale_vals[2];
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
// convert arrays of f16 -> u32
// Bytes 0-31: 32 bytes of hmask (8 u32s)
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
hmask_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
}
// Bytes 32-95: 64 bytes of qs (16 u32s)
var qs_vals: array<u32, 16>;
for (var i: u32 = 0; i < 16; i++) {
qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
for (var i: u32 = 0u; i < 16; i++) {
qs_vals[i] = load_u32_at(&src, block_byte_base + 32 + i * 4);
}
var dst_i = dst_base + offset * 256;
var is: u32 = 0;
var m: u32 = 1;
// 2 halves of the block (128 elements each)
for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
// 4 groups (each group has 2 blocks of 16 elements)
@@ -191,11 +204,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let sc = get_byte(scale_vals[is / 4], is % 4);
is++;
let dl = d * (f32(sc) - 32.0);
for (var l: u32 = 0u; l < 16u; l++) {
for (var l: u32 = 0; l < 16; l++) {
let q_idx = q_b_idx + k + l;
let hm_idx = k + l;
let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);
let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
let qs_val = (q_byte >> shift) & 3;
dst[dst_i] = (f32(qs_val) - hm) * dl;
@@ -268,21 +283,27 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef Q6_K
// 16 blocks of 16 elements each
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
let block_byte_base = (src_base + offset) * 210; // Block stride: 210 bytes
// convert arrays of f16 -> u32
// Bytes 208-209: f16 scale 'd'
let d = load_f16_as_f32_at(&src, block_byte_base + 208);
// Bytes 0-127: 128 bytes of ql (32 u32s)
var ql_vals: array<u32, 32>;
for (var i: u32 = 0; i < 32; i++) {
ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
ql_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
}
// Bytes 128-191: 64 bytes of qh (16 u32s)
var qh_vals: array<u32, 16>;
for (var i: u32 = 0; i < 16; i++) {
qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
for (var i: u32 = 0; i < 16u; i++) {
qh_vals[i] = load_u32_at(&src, block_byte_base + 128 + i * 4u);
}
// Bytes 192-207: 16 bytes of scales (4 u32s)
var scale_vals: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
scale_vals[i] = load_u32_at(&src, block_byte_base + 192 + i * 4);
}
var dst_i = dst_base + offset * 256;
@@ -323,12 +344,14 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ2_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
let block_byte_base = (src_base + offset) * 66; // Block stride: 66 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
var dst_i = dst_base + offset * 256;
for (var ib: u32 = 0; ib < 32; ib += 4) {
let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));
let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));
let aux0_offset = block_byte_base + 2 + ib * 2;
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
let aux0 = load_u32_at(&src, aux0_offset);
let aux1 = load_u32_at(&src, aux1_offset);
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
for (var l: u32 = 0; l < 4; l++) {
let ig = get_byte(aux0, l) * 8;
@@ -345,15 +368,19 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
#endif
#ifdef IQ2_XS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
let block_byte_base = (src_base + offset) * 74; // Block stride: 74 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
var dst_i = dst_base + offset * 256;
var scale_vals = array<u32, 2>(
bitcast<u32>(vec2(block.scales[0], block.scales[1])),
bitcast<u32>(vec2(block.scales[2], block.scales[3]))
load_u32_at(&src, block_byte_base + 66),
load_u32_at(&src, block_byte_base + 70)
);
for (var ib: u32 = 0; ib < 32; ib += 4) {
let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
let db = array<f32, 2>(
@@ -361,7 +388,8 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
d * (0.5 + f32(s >> 4)) * 0.25
);
for (var l: u32 = 0; l < 4; l++) {
let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
let qs_val = load_u32_at(&src, qs_offset) & 0xFFFF;
let ig = (qs_val & 511) * 8;
let is = qs_val >> 9;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
@@ -379,21 +407,23 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ2_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
let block_byte_base = (src_base + offset) * 82; // Block stride: 82 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
var dst_i = dst_base + offset * 256;
var qs_vals : array<u32, 16>;
for (var i: u32 = 0; i < 16; i++) {
qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
qs_vals[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
}
var qh_vals = array<u32, 2>(
bitcast<u32>(vec2(block.qh[0], block.qh[1])),
bitcast<u32>(vec2(block.qh[2], block.qh[3]))
);
var scale_vals = array<u32, 2>(
bitcast<u32>(vec2(block.scales[0], block.scales[1])),
bitcast<u32>(vec2(block.scales[2], block.scales[3]))
);
var qh_vals: array<u32, 2>;
qh_vals[0] = load_u32_at(&src, block_byte_base + 66);
qh_vals[1] = load_u32_at(&src, block_byte_base + 70);
var scale_vals: array<u32, 2>;
scale_vals[0] = load_u32_at(&src, block_byte_base + 74);
scale_vals[1] = load_u32_at(&src, block_byte_base + 78);
for (var ib: u32 = 0; ib < 8; ib ++) {
let s = get_byte(scale_vals[ib / 4], ib % 4);
let db = array<f32, 2>(
@@ -419,16 +449,17 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ3_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
let block_byte_base = (src_base + offset) * 98; // Block stride: 98 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
var dst_i = dst_base + offset * 256;
for (var ib: u32 = 0; ib < 16; ib += 2) {
let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
let sc_sign = load_u32_at(&src, sc_sign_offset);
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
for (var l: u32 = 0; l < 4; l++) {
let is = (sc_sign >> (7 * l)) & 127;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));
let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0);
let ig2 = get_byte(ig_val, 1);
for (var j: u32 = 0; j < 4; j++) {
@@ -448,18 +479,22 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ3_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
var dst_i = dst_base + offset * 256;
var qh_vals = array<u32, 2>(
bitcast<u32>(vec2(block.qh[0], block.qh[1])),
bitcast<u32>(vec2(block.qh[2], block.qh[3]))
load_u32_at(&src, block_byte_base + 66),
load_u32_at(&src, block_byte_base + 70)
);
var sign_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));
sign_vals[i] = load_u32_at(&src, block_byte_base + 74 + i * 4);
}
var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));
var scale_vals = load_u32_at(&src, block_byte_base + 106);
for (var ib: u32 = 0; ib < 4; ib++) {
let s = get_byte(scale_vals, ib);
let db = array<f32, 2>(
@@ -472,7 +507,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let sign_w = sign_vals[ib * 2 + k];
for (var l: u32 = 0; l < 4; l++) {
let signs = get_byte(sign_w, l);
let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));
let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
for (var j: u32 = 0; j < 4; j++) {
@@ -493,14 +528,14 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ1_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
let block_byte_base = (src_base + offset) * 50; // Block stride: 50 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
var dst_i = dst_base + offset * 256;
for (var ib: u32 = 0; ib < 8; ib++) {
let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));
let dl = d * (2 * f32((qh >> 12) & 7) + 1);
let qh = load_u32_at(&src, block_byte_base + 34 + ib * 2) & 0xFFFF;
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));
let qs_w = load_u32_at(&src, block_byte_base + 2 + ib * 4);
for (var l: u32 = 0; l < 4; l++) {
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
for (var j: u32 = 0; j < 8; j++) {
@@ -560,12 +595,12 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ4_NL
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
var dst_i = dst_base + offset * 32;
var qs: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
qs[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
}
for (var j: u32 = 0; j < 16; j++) {
let qsb = get_byte(qs[j / 4], j % 4);
@@ -579,8 +614,8 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ4_XS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0));
let d = unpack2x16float(block.d_scales_h)[0];
let scales_h = block.d_scales_h >> 16;
var dst_i = dst_base + offset * 256;
for (var ib: u32 = 0; ib < 8; ib++) {
let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4);

View File

@@ -20,11 +20,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef Q4_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q4_0 = src0[src0_idx_base + offset];
let d = f32(block_q4_0.d);
let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
var sum: f32 = 0.0;
for (var j: u32 = 0; j < 4; j++) {
let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1]));
let q_byte_offset = block_byte_base + 2 + j * 4;
let q_packed = load_u32_at(&src0, q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
@@ -61,12 +62,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef Q5_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q5_0 = src0[src0_idx_base + offset];
let d = f32(block_q5_0.d);
let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
var sum: f32 = 0.0;
let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));
let qh_packed = load_u32_at(&src0, block_byte_base + 2);
for (var j: u32 = 0; j < 4; j++) {
let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));
let q_byte_offset = block_byte_base + 6 + j * 4;
let q_packed = load_u32_at(&src0, q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
@@ -107,12 +109,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef Q8_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q8_0 = src0[src0_idx_base + offset];
let d = f32(block_q8_0.d);
let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
var sum: f32 = 0.0;
for (var j: u32 = 0; j < 8; j++) {
let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1]));
for (var k: u32 = 0; k < 4; k++) {
let q_byte_offset = block_byte_base + 2 + j * 4;
let q_packed = load_u32_at(&src0, q_byte_offset);
for (var k: u32 = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
@@ -178,31 +181,37 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef Q3_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
// Bytes 108-109: f16 scale 'd'
let d = load_f16_as_f32_at(&src0, block_byte_base + 108);
// extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
// and 2-bits from the last 4 bytes
// Bytes 96-107: 12 bytes of scales (3 u32s)
let kmask1: u32 = 0x03030303;
let kmask2: u32 = 0x0f0f0f0f;
var scale_vals: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
}
scale_vals[0] = load_u32_at(&src0, block_byte_base + 96);
scale_vals[1] = load_u32_at(&src0, block_byte_base + 100);
scale_vals[2] = load_u32_at(&src0, block_byte_base + 104);
var tmp: u32 = scale_vals[2];
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
// convert arrays of f16 -> u32
// Bytes 0-31: 32 bytes of hmask (8 u32s)
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
hmask_vals[i] = load_u32_at(&src0, block_byte_base + i * 4);
}
// Bytes 32-95: 64 bytes of qs (16 u32s)
var qs_vals: array<u32, 16>;
for (var i: u32 = 0; i < 16; i++) {
qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
for (var i: u32 = 0u; i < 16; i++) {
qs_vals[i] = load_u32_at(&src0, block_byte_base + 32 + i * 4);
}
var sum = 0.0;
@@ -301,21 +310,27 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef Q6_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes
// convert arrays of f16 -> u32
// Bytes 208-209: f16 scale 'd'
let d = load_f16_as_f32_at(&src0, block_byte_base + 208);
// Bytes 0-127: 128 bytes of ql (32 u32s)
var ql_vals: array<u32, 32>;
for (var i: u32 = 0; i < 32; i++) {
ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
ql_vals[i] = load_u32_at(&src0, block_byte_base + i * 4);
}
// Bytes 128-191: 64 bytes of qh (16 u32s)
var qh_vals: array<u32, 16>;
for (var i: u32 = 0; i < 16; i++) {
qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
qh_vals[i] = load_u32_at(&src0, block_byte_base + 128 + i * 4);
}
// Bytes 192-207: 16 bytes of scales (4 u32s)
var scale_vals: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
scale_vals[i] = load_u32_at(&src0, block_byte_base + 192 + i * 4);
}
var sum = 0.0;
@@ -358,13 +373,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ2_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var sum = 0.0;
for (var ib: u32 = 0; ib < 32; ib += 4) {
let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));
let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));
let aux0_offset = block_byte_base + 2 + ib * 2;
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
let aux0 = load_u32_at(&src0, aux0_offset);
let aux1 = load_u32_at(&src0, aux1_offset);
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
for (var l: u32 = 0; l < 4; l++) {
let ig = get_byte(aux0, l) * 8;
@@ -384,13 +401,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ2_XS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var scale_vals = array<u32, 2>(
bitcast<u32>(vec2(block.scales[0], block.scales[1])),
bitcast<u32>(vec2(block.scales[2], block.scales[3]))
load_u32_at(&src0, block_byte_base + 66),
load_u32_at(&src0, block_byte_base + 70)
);
var sum = 0.0;
for (var ib: u32 = 0; ib < 32; ib += 4) {
let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
@@ -399,7 +418,8 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
d * (0.5 + f32(s >> 4)) * 0.25
);
for (var l: u32 = 0; l < 4; l++) {
let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
let qs_val = load_u32_at(&src0, qs_offset) & 0xFFFF;
let ig = (qs_val & 511) * 8;
let is = qs_val >> 9;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
@@ -418,21 +438,23 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ2_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var qs_vals : array<u32, 16>;
for (var i: u32 = 0; i < 16; i++) {
qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
qs_vals[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4);
}
var qh_vals = array<u32, 2>(
bitcast<u32>(vec2(block.qh[0], block.qh[1])),
bitcast<u32>(vec2(block.qh[2], block.qh[3]))
);
var scale_vals = array<u32, 2>(
bitcast<u32>(vec2(block.scales[0], block.scales[1])),
bitcast<u32>(vec2(block.scales[2], block.scales[3]))
);
var qh_vals: array<u32, 2>;
qh_vals[0] = load_u32_at(&src0, block_byte_base + 66);
qh_vals[1] = load_u32_at(&src0, block_byte_base + 70);
var scale_vals: array<u32, 2>;
scale_vals[0] = load_u32_at(&src0, block_byte_base + 74);
scale_vals[1] = load_u32_at(&src0, block_byte_base + 78);
var sum = 0.0;
for (var ib: u32 = 0; ib < 8; ib ++) {
let s = get_byte(scale_vals[ib / 4], ib % 4);
@@ -460,17 +482,18 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ3_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var sum = 0.0;
for (var ib: u32 = 0; ib < 16; ib += 2) {
let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
let sc_sign = load_u32_at(&src0, sc_sign_offset);
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
for (var l: u32 = 0; l < 4; l++) {
let is = (sc_sign >> (7 * l)) & 127;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));
let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0);
let ig2 = get_byte(ig_val, 1);
for (var j: u32 = 0; j < 4; j++) {
@@ -491,18 +514,22 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ3_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var qh_vals = array<u32, 2>(
bitcast<u32>(vec2(block.qh[0], block.qh[1])),
bitcast<u32>(vec2(block.qh[2], block.qh[3]))
load_u32_at(&src0, block_byte_base + 66),
load_u32_at(&src0, block_byte_base + 70)
);
var sign_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));
sign_vals[i] = load_u32_at(&src0, block_byte_base + 74 + i * 4);
}
var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));
var scale_vals = load_u32_at(&src0, block_byte_base + 106);
var sum = 0.0;
for (var ib: u32 = 0; ib < 4; ib++) {
let s = get_byte(scale_vals, ib);
@@ -516,7 +543,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let sign_w = sign_vals[ib * 2 + k];
for (var l: u32 = 0; l < 4; l++) {
let signs = get_byte(sign_w, l);
let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));
let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
for (var j: u32 = 0; j < 4; j++) {
@@ -538,15 +565,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ1_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var sum = 0.0;
for (var ib: u32 = 0; ib < 8; ib++) {
let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));
let dl = d * (2 * f32((qh >> 12) & 7) + 1);
let qh = load_u32_at(&src0, block_byte_base + 34 + ib * 2) & 0xFFFF;
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));
let qs_w = load_u32_at(&src0, block_byte_base + 2 + ib * 4);
for (var l: u32 = 0; l < 4; l++) {
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
for (var j: u32 = 0; j < 8; j++) {
@@ -610,13 +637,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ4_NL
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
var src1_i = src1_idx_base + offset * 32;
var sum = 0.0;
var qs: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
qs[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4);
}
for (var j: u32 = 0; j < 16; j++) {
let qsb = get_byte(qs[j / 4], j % 4);
@@ -631,8 +658,8 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ4_XS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0));
let d = unpack2x16float(block.d_scales_h)[0];
let scales_h = block.d_scales_h >> 16;
var src1_i = src1_idx_base + offset * 256;
var sum = 0.0;
for (var ib: u32 = 0; ib < 8; ib++) {

View File

@@ -84,11 +84,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
let d = load_f16_at(&src0, block_byte_base);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
let q_packed = load_u32_at(&src0, q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -125,12 +125,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
let m = load_src0_f16_at(block_byte_base + 2u);
let d = load_f16_at(&src0, block_byte_base);
let m = load_f16_at(&src0, block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
let q_packed = load_u32_at(&src0, q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_lo = f16(q_byte & 0xF) * d + m;
@@ -171,12 +171,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
let qh_packed = load_src0_u32_at(block_byte_base + 2u);
let d = load_f16_at(&src0, block_byte_base);
let qh_packed = load_u32_at(&src0, block_byte_base + 2u);
for (var j = 0u; j < 2; j++) {
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
let q_packed = load_src0_u32_at(q_byte_offset);
let q_packed = load_u32_at(&src0, q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@@ -225,14 +225,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
let m = load_src0_f16_at(block_byte_base + 2u);
let qh_packed = load_src0_u32_at(block_byte_base + 4u);
let d = load_f16_at(&src0, block_byte_base);
let m = load_f16_at(&src0, block_byte_base + 2u);
let qh_packed = load_u32_at(&src0, block_byte_base + 4u);
for (var j = 0u; j < 2; j++) {
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
let q_packed = load_src0_u32_at(q_byte_offset);
let q_packed = load_u32_at(&src0, q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@@ -277,11 +277,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
let d = load_f16_at(&src0, block_byte_base);
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
let q_packed = load_u32_at(&src0, q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
@@ -317,12 +317,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
let m = load_src0_f16_at(block_byte_base + 2u);
let d = load_f16_at(&src0, block_byte_base);
let m = load_f16_at(&src0, block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
let q_packed = load_u32_at(&src0, q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
@@ -359,8 +359,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base + 80u);
let dmin = load_src0_f16_at(block_byte_base + 82u);
let d = load_f16_at(&src0, block_byte_base + 80u);
let dmin = load_f16_at(&src0, block_byte_base + 82u);
// Decode the element at position k_in_block
let block_of_32 = k_in_block / 32u;
@@ -373,14 +373,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let is = k_in_block / 16u;
let sc_packed = load_src0_u32_at(block_byte_base + 4u * (is / 4u));
let sc_packed = load_u32_at(&src0, block_byte_base + 4u * (is / 4u));
let sc = get_byte(sc_packed, is % 4u);
let dl = d * f16(sc & 0xFu);
let ml = dmin * f16(sc >> 4u);
let q_idx = q_b_idx + k + l;
let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 3u;
@@ -413,7 +413,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base + 108u);
let d = load_f16_at(&src0, block_byte_base + 108u);
// Load and unpack scales
let kmask1: u32 = 0x03030303u;
@@ -421,7 +421,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
var scale_vals: array<u32, 4>;
for (var i: u32 = 0u; i < 4u; i++) {
scale_vals[i] = load_src0_u32_at(block_byte_base + 96u + 4u * i);
scale_vals[i] = load_u32_at(&src0, block_byte_base + 96u + 4u * i);
}
var tmp: u32 = scale_vals[2];
@@ -433,12 +433,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load hmask and qs arrays
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0u; i < 8u; i++) {
hmask_vals[i] = load_src0_u32_at(block_byte_base + 4u * i);
hmask_vals[i] = load_u32_at(&src0, block_byte_base + 4u * i);
}
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16u; i++) {
qs_vals[i] = load_src0_u32_at(block_byte_base + 32u + 4u * i);
qs_vals[i] = load_u32_at(&src0, block_byte_base + 32u + 4u * i);
}
let half = k_in_block / 128u; // 0 or 1
@@ -499,13 +499,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
let dmin = load_src0_f16_at(block_byte_base + 2u);
let d = load_f16_at(&src0, block_byte_base);
let dmin = load_f16_at(&src0, block_byte_base + 2u);
// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i);
}
// Map k_in_block to loop structure:
@@ -541,7 +541,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 0xFu;
@@ -575,13 +575,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
let dmin = load_src0_f16_at(block_byte_base + 2u);
let d = load_f16_at(&src0, block_byte_base);
let dmin = load_f16_at(&src0, block_byte_base + 2u);
// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i);
}
// The original loop processes elements in groups of 64
@@ -621,11 +621,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_packed = load_src0_u32_at(block_byte_base + 48u + 4u * (q_idx / 4u));
let q_packed = load_u32_at(&src0, block_byte_base + 48u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qh_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (l / 4u));
let qh_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (l / 4u));
let qh_byte = get_byte(qh_packed, l % 4u);
@@ -673,17 +673,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load only ql13 word needed
let ql13_flat = ql_b_idx + l;
let ql13 = load_src0_u32_at(block_byte_base + ql13_flat);
let ql13 = load_u32_at(&src0, block_byte_base + ql13_flat);
let ql13_b = get_byte(ql13, 0u);
// Load only ql24 word needed
let ql24_flat = ql_b_idx + l + 32u;
let ql24 = load_src0_u32_at(block_byte_base + ql24_flat);
let ql24 = load_u32_at(&src0, block_byte_base + ql24_flat);
let ql24_b = get_byte(ql24, 0u);
// Load only qh word needed
let qh_flat = qh_b_idx + l;
let qh = load_src0_u32_at(block_byte_base + 128u + qh_flat);
let qh = load_u32_at(&src0, block_byte_base + 128u + qh_flat);
let qh_b = get_byte(qh, 0u);
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
@@ -694,10 +694,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load only the scale word needed
let is = l / 16u;
let sc_idx = sc_b_idx + is + quarter * 2u;
let sc = load_src0_u32_at(block_byte_base + 192u + sc_idx);
let sc = load_u32_at(&src0, block_byte_base + 192u + sc_idx);
let sc_val = get_byte_i32(sc, 0u);
let d = load_src0_f16_at(block_byte_base + 208u);
let d = load_f16_at(&src0, block_byte_base + 208u);
var q_val: f16;
if (quarter == 0u) {

View File

@@ -65,10 +65,10 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_src0_f16_at(block_byte_base));
let d = f32(load_f16_at(&src0, block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
let q_packed = load_u32_at(&src0, q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -98,11 +98,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_src0_f16_at(block_byte_base));
let m = f32(load_src0_f16_at(block_byte_base + 2u));
let d = f32(load_f16_at(&src0, block_byte_base));
let m = f32(load_f16_at(&src0, block_byte_base + 2u));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
let q_packed = load_u32_at(&src0, q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
@@ -132,12 +132,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_src0_f16_at(block_byte_base));
let qh_packed = load_src0_u32_at(block_byte_base + 2u);
let d = f32(load_f16_at(&src0, block_byte_base));
let qh_packed = load_u32_at(&src0, block_byte_base + 2u);
for (var j = 0u; j < 2; j++) {
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
let q_packed = load_src0_u32_at(q_byte_offset);
let q_packed = load_u32_at(&src0, q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@@ -176,13 +176,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_src0_f16_at(block_byte_base));
let m = load_src0_f16_at(block_byte_base + 2u);
let qh_packed = load_src0_u32_at(block_byte_base + 4u);
let d = f32(load_f16_at(&src0, block_byte_base));
let m = load_f16_at(&src0, block_byte_base + 2u);
let qh_packed = load_u32_at(&src0, block_byte_base + 4u);
for (var j = 0u; j < 2; j++) {
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
let q_packed = load_src0_u32_at(q_byte_offset);
let q_packed = load_u32_at(&src0, q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@@ -221,11 +221,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_src0_f16_at(block_byte_base));
let d = f32(load_f16_at(&src0, block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
let q_packed = load_u32_at(&src0, q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
@@ -254,12 +254,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_src0_f16_at(block_byte_base));
let m = load_src0_f16_at(block_byte_base + 2u);
let d = f32(load_f16_at(&src0, block_byte_base));
let m = load_f16_at(&src0, block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
let q_packed = load_u32_at(&src0, q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d + f32(m);
@@ -309,13 +309,13 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = ix; i < nb; i += 2u) {
let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES;
let d = f32(load_src0_f16_at(bbase + 208u));
let d = f32(load_f16_at(&src0, bbase + 208u));
let ql1_u32 = load_src0_u32_at(bbase + q_offset_l);
let ql2_u32 = load_src0_u32_at(bbase + q_offset_l + 32u);
let qh_u32 = load_src0_u32_at(bbase + 128u + q_offset_h);
let sc_u32_0 = load_src0_u32_at(bbase + sc_base_byte);
let sc_u32_1 = load_src0_u32_at(bbase + sc_base_byte + 4u);
let ql1_u32 = load_u32_at(&src0, bbase + q_offset_l);
let ql2_u32 = load_u32_at(&src0, bbase + q_offset_l + 32u);
let qh_u32 = load_u32_at(&src0, bbase + 128u + q_offset_h);
let sc_u32_0 = load_u32_at(&src0, bbase + sc_base_byte);
let sc_u32_1 = load_u32_at(&src0, bbase + sc_base_byte + 4u);
let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);

View File

@@ -107,7 +107,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx]));
#endif
#ifdef EXP
let res = exp(src[params.offset_src + src_idx]);
let src_f32 = f32(src[params.offset_src + src_idx]);
let res = TYPE(exp(src_f32));
#endif
#ifdef LOG
let res = TYPE(log(f32(src[params.offset_src + src_idx])));
@@ -161,7 +162,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0));
#endif
#ifdef EXPM1
let res = exp(src[params.offset_src + src_idx]) - 1.0;
let src_f32 = f32(src[params.offset_src + src_idx]);
let res = TYPE(exp(src_f32) - 1.0);
#endif
#ifdef FLOOR
let res = floor(src[params.offset_src + src_idx]);
@@ -181,7 +183,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let res = src[params.offset_src + src_idx] * src[params.offset_src + src_idx];
#endif
#ifdef SQRT
let res = sqrt(src[params.offset_src + src_idx]);
let res = TYPE(sqrt(f32(src[params.offset_src + src_idx])));
#endif
#ifdef SIN
let res_f32 = sin(f32(src[params.offset_src + src_idx]));

View File

@@ -4623,17 +4623,18 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const int64_t n_embd_head = hparams.n_embd_head_k(i);
const int64_t n_embd_k = hparams.n_embd_k_gqa(i);
const int64_t n_embd_v = hparams.n_embd_v_gqa(i);
const int kv_flags = hparams.has_kv(i) ? 0 : TENSOR_NOT_REQUIRED;
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
// note: use_alternative_attention (v_proj is optional, if it's not present, use k_proj)
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k}, kv_flags);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v}, TENSOR_NOT_REQUIRED);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head * n_head, n_embd}, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head}, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head}, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head}, kv_flags);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1u}, TENSOR_NOT_REQUIRED);

View File

@@ -7265,6 +7265,7 @@ static const ggml_type all_types[] = {
static const ggml_type base_types[] = {
GGML_TYPE_F32, GGML_TYPE_F16,
GGML_TYPE_Q8_0, // for I8MM tests
GGML_TYPE_Q1_0,
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_1, // for I8MM tests
GGML_TYPE_Q4_K,

View File

@@ -1988,6 +1988,13 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.expect(message_assist_thoughts)
.run();
// Empty reasoning (budget=0: sampler forces end tag before newline)
tst.test(
"<|channel>thought<channel|>Hello, world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.expect(simple_assist_msg("Hello, world!\nWhat's up?", ""))
.run();
// Reasoning and content with reasoning_format = none
tst.test(
"<|channel>thought\nI'm\nthinking<channel|>Hello, world!\nWhat's up?")

View File

@@ -1014,7 +1014,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
model.hf_file = params.hf_file[i];
}
auto download_result = common_download_model(model, params.hf_token);
common_download_opts opts;
opts.bearer_token = params.hf_token;
auto download_result = common_download_model(model, opts);
if (download_result.model_path.empty()) {
fprintf(stderr, "error: failed to download model from HuggingFace\n");
exit(1);

File diff suppressed because one or more lines are too long

View File

@@ -1,5 +1,5 @@
<!--
This is a single file build of the frontend.
This is a static build of the frontend.
It is automatically generated by the build process.
Do not edit this file directly.
To make changes, refer to the "Web UI" section in the README.
@@ -18,7 +18,7 @@
<div style="display: contents">
<script>
{
__sveltekit_1ao0o9h = {
__sveltekit__ = {
base: new URL('.', location).pathname.slice(0, -1)
};

View File

@@ -98,6 +98,7 @@ static void unset_reserved_args(common_preset & preset, bool unset_model_args) {
if (unset_model_args) {
preset.unset_option("LLAMA_ARG_MODEL");
preset.unset_option("LLAMA_ARG_MMPROJ");
preset.unset_option("LLAMA_ARG_ALIAS");
preset.unset_option("LLAMA_ARG_HF_REPO");
}
}

View File

@@ -1,11 +1,11 @@
{
"name": "webui",
"name": "llama-server-webui",
"version": "1.0.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "webui",
"name": "llama-server-webui",
"version": "1.0.0",
"dependencies": {
"@modelcontextprotocol/sdk": "^1.25.1",

View File

@@ -1,5 +1,5 @@
{
"name": "webui",
"name": "llama-server-webui",
"private": true,
"version": "1.0.0",
"type": "module",

View File

@@ -0,0 +1,84 @@
import { readFileSync, writeFileSync, existsSync, readdirSync, copyFileSync } from 'fs';
import { resolve } from 'path';
import type { Plugin } from 'vite';
const GUIDE_FOR_FRONTEND = `
<!--
This is a static build of the frontend.
It is automatically generated by the build process.
Do not edit this file directly.
To make changes, refer to the "Web UI" section in the README.
-->
`.trim();
export function llamaCppBuildPlugin(): Plugin {
return {
name: 'llamacpp:build',
apply: 'build',
closeBundle() {
// Ensure the SvelteKit adapter has finished writing to ../public
setTimeout(() => {
try {
const indexPath = resolve('../public/index.html');
if (!existsSync(indexPath)) return;
let content = readFileSync(indexPath, 'utf-8');
const faviconPath = resolve('static/favicon.svg');
if (existsSync(faviconPath)) {
const faviconContent = readFileSync(faviconPath, 'utf-8');
const faviconBase64 = Buffer.from(faviconContent).toString('base64');
const faviconDataUrl = `data:image/svg+xml;base64,${faviconBase64}`;
content = content.replace(/href="[^"]*favicon\.svg"/g, `href="${faviconDataUrl}"`);
console.log('✓ Inlined favicon.svg as base64 data URL');
}
content = content.replace(/\r/g, '');
content = GUIDE_FOR_FRONTEND + '\n' + content;
content = content.replace(/\/_app\/immutable\/bundle\.[^"]+\.js/g, './bundle.js');
content = content.replace(
/\/_app\/immutable\/assets\/bundle\.[^"]+\.css/g,
'./bundle.css'
);
content = content.replace(/__sveltekit_[a-z0-9]+/g, '__sveltekit__');
writeFileSync(indexPath, content, 'utf-8');
console.log('✓ Updated index.html');
// Copy bundle.*.js -> ../public/bundle.js
const immutableDir = resolve('../public/_app/immutable');
const bundleDir = resolve('../public/_app/immutable/assets');
if (existsSync(immutableDir)) {
const jsFiles = readdirSync(immutableDir).filter((f) => f.match(/^bundle\..+\.js$/));
if (jsFiles.length > 0) {
copyFileSync(resolve(immutableDir, jsFiles[0]), resolve('../public/bundle.js'));
// Normalize __sveltekit_<hash> to __sveltekit__ in bundle.js
const bundleJsPath = resolve('../public/bundle.js');
let bundleJs = readFileSync(bundleJsPath, 'utf-8');
bundleJs = bundleJs.replace(/__sveltekit_[a-z0-9]+/g, '__sveltekit__');
writeFileSync(bundleJsPath, bundleJs, 'utf-8');
console.log(`✓ Copied ${jsFiles[0]} -> bundle.js`);
}
}
// Copy bundle.*.css -> ../public/bundle.css
if (existsSync(bundleDir)) {
const cssFiles = readdirSync(bundleDir).filter((f) => f.match(/^bundle\..+\.css$/));
if (cssFiles.length > 0) {
copyFileSync(resolve(bundleDir, cssFiles[0]), resolve('../public/bundle.css'));
console.log(`✓ Copied ${cssFiles[0]} -> bundle.css`);
}
}
} catch (error) {
console.error('Failed to update index.html:', error);
}
}, 100);
}
};
}

View File

@@ -22,7 +22,8 @@
</p>
{:else}
<p class="text-xs text-muted-foreground">
Press <kbd class="rounded bg-muted px-1 py-0.5 font-mono text-xs">{modKey} + Enter</kbd> to send,
Press <kbd class="rounded bg-muted px-1 py-0.5 font-mono text-xs">{modKey} + Enter</kbd> to
send,
<kbd class="rounded bg-muted px-1 py-0.5 font-mono text-xs">Enter</kbd> for new line
</p>
{/if}

View File

@@ -25,6 +25,9 @@ const config = {
},
alias: {
$styles: 'src/styles'
},
version: {
name: 'llama-server-webui'
}
},

View File

@@ -1,108 +1,33 @@
import tailwindcss from '@tailwindcss/vite';
import { sveltekit } from '@sveltejs/kit/vite';
import { readFileSync, writeFileSync, existsSync, readdirSync, copyFileSync } from 'fs';
import { dirname, resolve } from 'path';
import { fileURLToPath } from 'url';
import { defineConfig, searchForWorkspaceRoot } from 'vite';
import devtoolsJson from 'vite-plugin-devtools-json';
import { storybookTest } from '@storybook/addon-vitest/vitest-plugin';
import { llamaCppBuildPlugin } from './scripts/vite-plugin-llama-cpp-build';
const __dirname = dirname(fileURLToPath(import.meta.url));
const GUIDE_FOR_FRONTEND = `
<!--
This is a single file build of the frontend.
It is automatically generated by the build process.
Do not edit this file directly.
To make changes, refer to the "Web UI" section in the README.
-->
`.trim();
/**
* the maximum size of an embedded asset in bytes,
* e.g. maximum size of embedded font (see node_modules/katex/dist/fonts/*.woff2)
*/
const MAX_ASSET_SIZE = 32000;
/** public/index.html minified flag */
const ENABLE_JS_MINIFICATION = true;
function llamaCppBuildPlugin() {
return {
name: 'llamacpp:build',
apply: 'build' as const,
closeBundle() {
// Ensure the SvelteKit adapter has finished writing to ../public
setTimeout(() => {
try {
const indexPath = resolve('../public/index.html');
if (!existsSync(indexPath)) {
return;
}
let content = readFileSync(indexPath, 'utf-8');
const faviconPath = resolve('static/favicon.svg');
if (existsSync(faviconPath)) {
const faviconContent = readFileSync(faviconPath, 'utf-8');
const faviconBase64 = Buffer.from(faviconContent).toString('base64');
const faviconDataUrl = `data:image/svg+xml;base64,${faviconBase64}`;
content = content.replace(/href="[^"]*favicon\.svg"/g, `href="${faviconDataUrl}"`);
console.log('✓ Inlined favicon.svg as base64 data URL');
}
content = content.replace(/\r/g, '');
content = GUIDE_FOR_FRONTEND + '\n' + content;
content = content.replace(/\/_app\/immutable\/bundle\.[^"]+\.js/g, './bundle.js');
content = content.replace(
/\/_app\/immutable\/assets\/bundle\.[^"]+\.css/g,
'./bundle.css'
);
writeFileSync(indexPath, content, 'utf-8');
console.log('✓ Updated index.html');
// Copy bundle.*.js -> ../public/bundle.js
const immutableDir = resolve('../public/_app/immutable');
const bundleDir = resolve('../public/_app/immutable/assets');
if (existsSync(immutableDir)) {
const jsFiles = readdirSync(immutableDir).filter((f) => f.match(/^bundle\..+\.js$/));
if (jsFiles.length > 0) {
copyFileSync(resolve(immutableDir, jsFiles[0]), resolve('../public/bundle.js'));
console.log(`✓ Copied ${jsFiles[0]} -> bundle.js`);
}
}
// Copy bundle.*.css -> ../public/bundle.css
if (existsSync(bundleDir)) {
const cssFiles = readdirSync(bundleDir).filter((f) => f.match(/^bundle\..+\.css$/));
if (cssFiles.length > 0) {
copyFileSync(resolve(bundleDir, cssFiles[0]), resolve('../public/bundle.css'));
console.log(`✓ Copied ${cssFiles[0]} -> bundle.css`);
}
}
} catch (error) {
console.error('Failed to update index.html:', error);
}
}, 100);
}
};
}
export default defineConfig({
resolve: {
alias: {
'katex-fonts': resolve('node_modules/katex/dist/fonts')
}
},
build: {
assetsInlineLimit: MAX_ASSET_SIZE,
assetsInlineLimit: 32000,
chunkSizeWarningLimit: 3072,
minify: ENABLE_JS_MINIFICATION
minify: true
},
esbuild: {
lineLimit: 500,
minifyIdentifiers: false
},
css: {
preprocessorOptions: {
scss: {
@@ -114,7 +39,9 @@ export default defineConfig({
}
}
},
plugins: [tailwindcss(), sveltekit(), devtoolsJson(), llamaCppBuildPlugin()],
test: {
projects: [
{
@@ -131,6 +58,7 @@ export default defineConfig({
setupFiles: ['./vitest-setup-client.ts']
}
},
{
extends: './vite.config.ts',
test: {
@@ -139,6 +67,7 @@ export default defineConfig({
include: ['tests/unit/**/*.{test,spec}.{js,ts}']
}
},
{
extends: './vite.config.ts',
test: {