Compare commits

...

20 Commits
b8022 ... b8042

Author SHA1 Message Date
Jeff Bolz
53aef25a88 vulkan: support GGML_OP_SET (#19584) 2026-02-14 06:36:38 +01:00
Sophon
2dec548094 vulkan: Add vendor id for Qualcomm drivers (#19569)
This commit allows Qualcomm native vulkan driver to be used on Windows
instead of Mesa Dozen.
2026-02-14 06:29:17 +01:00
Max Krasnyansky
0ccbfdef3e hexagon: further optimizations and refactoring for flash attention (#19583)
* ggml-hexagon: fa improvements

ggml-hexagon: optimize flash attention calculations with improved variable handling

ggml-hexagon: streamline flash attention operations by removing redundant checks for FP32

ggml-hexagon: optimize hvx_dot_f16_f16_aa_rx2 by simplifying variable handling for unused elements

ggml-hexagon: optimize flash attention by changing slope vector type to F16

* hexfa: fixed test-backend-ops failurs due to leftover element handling

* hexagon: refactor and optimize fa to use local context struct

* ggml-hexagon: optimize flash-attention using hvx_vec_expf

Use HVX for online softmax.

---------

Co-authored-by: chraac <chraac@gmail.com>
2026-02-13 16:27:30 -08:00
Mengsheng Wu
94a602db66 github : add missing backends to issue templates (#19603) 2026-02-14 00:56:53 +01:00
Jeff Bolz
05a6f0e894 vulkan: restore -inf check in FA shaders (#19582) 2026-02-13 13:35:29 -06:00
Adrien Gallouët
b48e80f677 common : update download code (#19573)
* common : remove legacy .json to .etag migration code

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* common : simplify common_download_file_single_online

This commit also force a redownload if the file exists
but has no .etag file.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-13 15:10:46 +01:00
Xuan-Son Nguyen
752584d5f5 model: support GLM MoE DSA arch (NOTE: indexer is not yet supported) (#19460)
* model: support GLM MoE DSA arch

* working version

* pyright

* keep indexer tensors

* add indexer gguf params

* loaded now

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* update

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* minor fix and cleanup

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-13 14:56:53 +01:00
Alberto Cabrera Pérez
cc2aa81513 Fix wrong memcpy length for block_interleave == 4 (#19575) 2026-02-13 20:32:14 +08:00
ymcki
0e21991472 fix vulkan ggml_acc only works in 3d but not 4d (#19426)
* fix vulkan ggml_acc only works in 3d but not 4d

* removed clamp in test_acc_block

* use the correct stride and its test case

* cuda : fix "supports op" condition

* change src0 to src1 in ggml_vk_acc. Update acc.comp with jeffbolznv\'s suggestion except to keep the boundary check

* version without boundary check

* revert back to boundary check version

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-02-13 13:31:37 +01:00
Sigbjørn Skjæret
b2ecc0cdb4 support --verbose-prompt (#19576) 2026-02-13 12:49:10 +01:00
Aman Gupta
5065da554e CUDA: loop over ne2*ne3 in case it overflows (#19538)
* CUDA: loop over ne2*ne3 in case it overflows

* use fastdiv
2026-02-13 17:01:40 +05:30
Aleksander Grygier
5174d7206f webui: UI and routing fixes (#19586)
* chore: update webui build output

* chore: update webui build output

* fix: Scroll issues in DropdownMenuSearchable

* webui: fix redirect to root ignoring base path

* fix: Word wrapping

* fix: remove obsolete modality UI tests causing CI failures

- Remove VisionModality/AudioModality test stories
- Remove mockServerProps usage and imports
- Simplify Default test (remove dropdown interaction checks)
- Simplify FileAttachments test (remove mocks)

* feat: Improve formatting performance time

---------

Co-authored-by: Pascal <admin@serveurperso.com>
2026-02-13 12:31:00 +01:00
Oliver Simons
43919b7f4f CUDA: Do not mutate cgraph for fused ADDs (#19566)
* Do not mutate cgraph for fused ADDs

1. We should try to minimize in-place changes to the incoming
   ggml_cgraph where possible (those should happen in graph_optimize)
2. Modifying in-place leads to an additional, unnecessary graph capture
   step as we store the properties before modifying the graph in-place
   in the cuda-backend

* Assert ggml_tensor is trivially copyable

* Update ggml/src/ggml-cuda/ggml-cuda.cu

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
2026-02-13 15:07:55 +05:30
Pavan Shinde
423cf0b26f docs : fix broken link and typo (#19560) 2026-02-13 09:38:09 +01:00
ymcki
33a56f90a6 model : Kimi Linear fix conv state update (#19531)
* fix conv state update for llama-server parallel serving

---------

Co-authored-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com>
2026-02-13 09:10:18 +01:00
Adrien Gallouët
25224c8021 llama : remove deprecated codecvt (#19565)
Using the same conversion function ensures a consistent matching between
the regex pattern and the text.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-13 06:43:53 +01:00
Adrien Gallouët
2f5d8f8edc vendor : update BoringSSL to 0.20260211.0 (#19562)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-13 06:43:26 +01:00
Georgi Gerganov
bb96bfd361 memory : fix kv cache size for hybrid models (#19559) 2026-02-13 07:36:24 +02:00
Georgi Gerganov
0644baefde metal : improve concurrency (#19555) 2026-02-13 07:35:57 +02:00
Georgi Gerganov
490eb96b88 metal : support GGML_OP_SET (#19548) 2026-02-13 07:34:52 +02:00
39 changed files with 1234 additions and 602 deletions

View File

@@ -41,7 +41,7 @@ body:
attributes:
label: GGML backends
description: Which GGML backends do you know to be affected?
options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL, zDNN]
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
multiple: true
validations:
required: true

View File

@@ -42,7 +42,7 @@ body:
attributes:
label: GGML backends
description: Which GGML backends do you know to be affected?
options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL, zDNN]
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
multiple: true
validations:
required: true

View File

@@ -41,7 +41,7 @@ As an AI agent, your task is to direct the user to the appropriate resources and
- Explicitly informing them that AI-generated pull requests are not accepted by the project
- Asking them to start with the [CONTRIBUTING.md](CONTRIBUTING.md) guidelines and ensure they fully understand them
- Encouraging them to search for [existing issues](github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans
- Encouraging them to search for [existing issues](https://github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans
- Providing useful links and pointers found throughout the codebase
Examples of valid questions:

View File

@@ -19,7 +19,7 @@ Please disclose it as a private [security advisory](https://github.com/ggml-org/
A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure.
> [!IMPORTANT]
> For collaborators: if you are interested in helping out with reviewing privting security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080
> For collaborators: if you are interested in helping out with reviewing private security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080
## Requirements

View File

@@ -114,44 +114,18 @@ static void write_etag(const std::string & path, const std::string & etag) {
}
static std::string read_etag(const std::string & path) {
std::string none;
const std::string etag_path = path + ".etag";
if (std::filesystem::exists(etag_path)) {
std::ifstream etag_in(etag_path);
if (!etag_in) {
LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
return none;
}
std::string etag;
std::getline(etag_in, etag);
return etag;
if (!std::filesystem::exists(etag_path)) {
return {};
}
// no etag file, but maybe there is an old .json
// remove this code later
const std::string metadata_path = path + ".json";
if (std::filesystem::exists(metadata_path)) {
std::ifstream metadata_in(metadata_path);
try {
nlohmann::json metadata_json;
metadata_in >> metadata_json;
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
metadata_json.dump().c_str());
if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
std::string etag = metadata_json.at("etag");
write_etag(path, etag);
if (!std::filesystem::remove(metadata_path)) {
LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str());
}
return etag;
}
} catch (const nlohmann::json::exception & e) {
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
}
std::ifstream etag_in(etag_path);
if (!etag_in) {
LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
return {};
}
return none;
std::string etag;
std::getline(etag_in, etag);
return etag;
}
static bool is_http_status_ok(int status) {
@@ -347,62 +321,64 @@ static int common_download_file_single_online(const std::string & url,
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
}
for (int i = 0; i < max_attempts; ++i) {
auto head = cli.Head(parts.path);
bool head_ok = head && head->status >= 200 && head->status < 300;
if (!head_ok) {
LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
if (file_exists) {
LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
return head->status; // cannot use cached file, return raw status code
// TODO: maybe retry only on certain codes
}
std::string etag;
if (head_ok && head->has_header("ETag")) {
etag = head->get_header_value("ETag");
}
size_t total_size = 0;
if (head_ok && head->has_header("Content-Length")) {
try {
total_size = std::stoull(head->get_header_value("Content-Length"));
} catch (const std::exception& e) {
LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what());
}
}
bool supports_ranges = false;
if (head_ok && head->has_header("Accept-Ranges")) {
supports_ranges = head->get_header_value("Accept-Ranges") != "none";
}
bool should_download_from_scratch = false;
if (!last_etag.empty() && !etag.empty() && last_etag != etag) {
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
last_etag.c_str(), etag.c_str());
should_download_from_scratch = true;
}
auto head = cli.Head(parts.path);
if (!head || head->status < 200 || head->status >= 300) {
LOG_WRN("%s: HEAD failed, status: %d\n", __func__, head ? head->status : -1);
if (file_exists) {
if (!should_download_from_scratch) {
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return -1;
}
LOG_INF("%s: using cached file (HEAD failed): %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
return head ? head->status : -1;
}
std::string etag;
if (head->has_header("ETag")) {
etag = head->get_header_value("ETag");
}
size_t total_size = 0;
if (head->has_header("Content-Length")) {
try {
total_size = std::stoull(head->get_header_value("Content-Length"));
} catch (const std::exception& e) {
LOG_WRN("%s: invalid Content-Length in HEAD response: %s\n", __func__, e.what());
}
}
bool supports_ranges = false;
if (head->has_header("Accept-Ranges")) {
supports_ranges = head->get_header_value("Accept-Ranges") != "none";
}
if (file_exists) {
if (etag.empty()) {
LOG_INF("%s: using cached file (no server etag): %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
if (!last_etag.empty() && last_etag == etag) {
LOG_INF("%s: using cached file (same etag): %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return -1;
}
}
const std::string path_temporary = path + ".downloadInProgress";
int delay = retry_delay_seconds;
for (int i = 0; i < max_attempts; ++i) {
if (i) {
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
std::this_thread::sleep_for(std::chrono::seconds(delay));
delay *= retry_delay_seconds;
}
const std::string path_temporary = path + ".downloadInProgress";
size_t existing_size = 0;
if (std::filesystem::exists(path_temporary)) {
if (supports_ranges && !should_download_from_scratch) {
if (supports_ranges) {
existing_size = std::filesystem::file_size(path_temporary);
} else if (remove(path_temporary.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
@@ -410,32 +386,23 @@ static int common_download_file_single_online(const std::string & url,
}
}
// start the download
LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
__func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
if (!was_pull_successful) {
if (i + 1 < max_attempts) {
const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
} else {
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
LOG_INF("%s: downloading from %s to %s (etag:%s)...\n",
__func__, common_http_show_masked_url(parts).c_str(),
path_temporary.c_str(), etag.c_str());
if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size)) {
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return -1;
}
continue;
if (!etag.empty()) {
write_etag(path, etag);
}
return head->status;
}
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return -1;
}
if (!etag.empty()) {
write_etag(path, etag);
}
return head->status; // TODO: use actual GET status?
}
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
return -1; // max attempts reached
}

View File

@@ -1608,6 +1608,23 @@ class TextModel(ModelBase):
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab.add_to_gguf(self.gguf_writer)
def _set_vocab_glm(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
# Special tokens
# Note: Using <|endoftext|> (151329) for eot causes endless generation
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
special_vocab.add_to_gguf(self.gguf_writer)
def _set_vocab_interns1(self):
tokens: list[str] = []
toktypes: list[int] = []
@@ -7710,6 +7727,9 @@ class DeepseekModel(TextModel):
class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
# TODO @ngxson : remove this when we support MTP for deepseek models
skip_mtp = True
def set_vocab(self):
try:
self._set_vocab_gpt2()
@@ -7841,10 +7861,11 @@ class DeepseekV2Model(TextModel):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
# skip Multi-Token Prediction (MTP) layers
block_count = self.hparams["num_hidden_layers"]
match = re.match(r"model.layers.(\d+)", name)
if match and int(match.group(1)) >= block_count:
return
if self.skip_mtp:
block_count = self.hparams["num_hidden_layers"]
match = re.match(r"model.layers.(\d+)", name)
if match and int(match.group(1)) >= block_count:
return
# process the experts separately
if name.find("mlp.experts") != -1:
@@ -8684,24 +8705,7 @@ class Glm4MoeModel(TextModel):
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
def set_vocab(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
# Special tokens
# Note: Using <|endoftext|> (151329) for eot causes endless generation
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
special_vocab.add_to_gguf(self.gguf_writer)
return self._set_vocab_glm()
def set_gguf_parameters(self):
super().set_gguf_parameters()
@@ -8801,26 +8805,38 @@ class Glm4MoeModel(TextModel):
class Glm4MoeLiteModel(DeepseekV2Model):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
# copied from Glm4MoeModel
def set_vocab(self):
from transformers import AutoTokenizer
return self._set_vocab_glm()
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
# Special tokens
# Note: Using <|endoftext|> (151329) for eot causes endless generation
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
@ModelBase.register("GlmMoeDsaForCausalLM")
class GlmMoeDsaModel(DeepseekV2Model):
model_arch = gguf.MODEL_ARCH.GLM_DSA
skip_mtp = False
special_vocab.add_to_gguf(self.gguf_writer)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0)
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
def set_vocab(self):
return self._set_vocab_glm()
def set_gguf_parameters(self):
super().set_gguf_parameters()
rope_dim = self.hparams["qk_rope_head_dim"]
partial_rotary_factor = self.hparams.get("partial_rotary_factor", 1.0)
self.gguf_writer.add_rope_dimension_count(int(rope_dim * partial_rotary_factor))
# NextN/MTP prediction layers
if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers)
# DSA indexer parameters
self.gguf_writer.add_indexer_head_count(self.hparams["index_n_heads"])
self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"])
self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"])
@ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")

View File

@@ -1916,9 +1916,10 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in
int src_offset = (i / 8) * blck_size_interleave;
int dst_offset = i * blck_size_interleave;
// buffer large enough for the max interleave block size (8 bytes)
uint64_t elems;
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
memcpy(&elems, &in[src_id].qs[src_offset], blck_size_interleave);
memcpy(&out.qs[dst_offset], &elems, blck_size_interleave);
}
// The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K

View File

@@ -7,7 +7,8 @@
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t ne00, const int64_t ne01,
const int64_t ne0203, const uint3 ne02,
const int64_t s01, const int64_t s02, const int64_t s03) {
const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
@@ -16,23 +17,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
}
const int64_t i01 = blockIdx.y;
const int64_t i02 = blockIdx.z % ne02;
const int64_t i03 = blockIdx.z / ne02;
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
const int64_t i02 = dm.y;
const int64_t i03 = dm.x;
const int64_t ib = ibx0 + i00/qk; // block index
const int64_t iqs = (i00%qk)/qr; // quant index
const int64_t iybs = i00 - i00%qk; // y block start index
const int64_t y_offset = qr == 1 ? 1 : qk/2;
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
// dequantize
float2 v;
dequantize_kernel(vx, ib, iqs, v);
const int64_t ib = ibx0 + i00/qk; // block index
const int64_t iqs = (i00%qk)/qr; // quant index
const int64_t iybs = i00 - i00%qk; // y block start index
const int64_t y_offset = qr == 1 ? 1 : qk/2;
const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
// dequantize
float2 v;
dequantize_kernel(vx, ib, iqs, v);
const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
}
}
template <bool need_check>
@@ -485,9 +490,11 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * vx, dst_t * y,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
const int64_t ne0203 = ne02*ne03;
const uint3 ne02_fdv = init_fastdiv_values(ne02);
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535));
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
(vx, y, ne00, ne01, ne02, s01, s02, s03);
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@@ -612,7 +619,8 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t
template <typename src_t, typename dst_t>
static __global__ void convert_unary(
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
const int64_t ne0203, const uint3 ne02,
const int64_t s01, const int64_t s02, const int64_t s03) {
const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
@@ -621,23 +629,29 @@ static __global__ void convert_unary(
}
const int64_t i01 = blockIdx.y;
const int64_t i02 = blockIdx.z % ne02;
const int64_t i03 = blockIdx.z / ne02;
const src_t * x = (const src_t *) vx;
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
const int64_t i02 = dm.y;
const int64_t i03 = dm.x;
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
}
}
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * vx, dst_t * y,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03);
const int64_t ne0203 = ne02*ne03;
const uint3 ne02_fdv = init_fastdiv_values(ne02);
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535));
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
(vx, y, ne00, ne01, ne02, s01, s02, s03);
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
}
template <typename src_t, typename dst_t>

View File

@@ -3640,11 +3640,13 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
n_fuse++;
if (n_fuse > 1) {
ggml_tensor fused_add_node;
memcpy(&fused_add_node, node, sizeof(ggml_tensor));
for (int j = 0; j < n_fuse - 1; ++j) {
node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
}
cgraph->nodes[i + n_fuse - 1]->data = node->data;
ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data;
ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse);
i += n_fuse - 1;
continue;
@@ -4820,8 +4822,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:
case GGML_OP_ACC:
return true;
case GGML_OP_ACC:
// TODO: extend support like so:
//return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]);
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
case GGML_OP_SUM:
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_TOP_K:

View File

@@ -17,121 +17,6 @@
#include "htp-msg.h"
#include "htp-ops.h"
static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) {
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements
return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
}
// Dot product of FP32 and FP16 vectors, accumulating to float
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vsplat_R(0);
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x_hf = vx[i];
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x_hf = vx[i];
// Zero-out unused elements
// Note that we need to clear both x and y because they may contain NANs
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
x_hf = Q6_V_vand_QV(bmask, x_hf);
y_hf = Q6_V_vand_QV(bmask, y_hf);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
}
rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
}
// Dot product of FP32 and FP16 vectors, accumulating to float
static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r,
const void * restrict y,
const void * restrict x0,
const void * restrict x1,
unsigned int n,
float s) {
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x0_hf = vx0[i];
HVX_Vector x1_hf = vx1[i];
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x0_hf = vx0[i];
HVX_Vector x1_hf = vx1[i];
// Zero-out unused elements
// Note that we need to clear both x and y because they may contain NANs
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
x0_hf = Q6_V_vand_QV(bmask, x0_hf);
x1_hf = Q6_V_vand_QV(bmask, x1_hf);
y_hf = Q6_V_vand_QV(bmask, y_hf);
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
}
HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
}
// Dot product of two F16 vectors, accumulating to float
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
@@ -140,8 +25,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vsplat_R(0);
uint32_t i = 0;
@@ -156,11 +40,10 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
}
if (nloe) {
HVX_Vector y_hf = vy[i];
// Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
@@ -181,12 +64,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
uint32_t i = 0;
@@ -204,12 +86,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
}
if (nloe) {
HVX_Vector y_hf = vy[i];
// Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
@@ -222,7 +103,7 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
}
// MAD: y (F32) += x (F16) * s (float)
// MAD: y (F32) += x (F16) * s (F32)
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
@@ -259,15 +140,125 @@ static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict
}
}
// MAD: y (F32) += x0 (F16) * s0 (F32) + x1 (F16) * s1 (F32)
static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y,
const void * restrict x0,
const void * restrict x1,
float s0,
float s1,
int n) {
const HVX_Vector * restrict ptr_x0 = (const HVX_Vector *) x0;
const HVX_Vector * restrict ptr_x1 = (const HVX_Vector *) x1;
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector S0 = hvx_vec_splat_f16(s0);
HVX_Vector S1 = hvx_vec_splat_f16(s1);
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; ++i) {
// Multiply x * s -> pair of F32 vectors
HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0);
HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1);
HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
HVX_Vector xs_p_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
ptr_y[i * 2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_lo, ptr_y[i * 2]));
ptr_y[i * 2 + 1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_hi, ptr_y[i * 2 + 1]));
}
if (nloe) {
HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0);
HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1);
HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
HVX_Vector xs = xs_p_lo;
i = 2 * i; // index for ptr_y
if (nloe >= 32) {
ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
nloe -= 32; ++i;
xs = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
}
if (nloe) {
HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
hvx_vec_store_a(&ptr_y[i], nloe * 4, xy);
}
}
}
#define FLASH_ATTN_BLOCK_SIZE 128
static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
struct htp_fa_context {
const struct htp_ops_context * octx;
struct fastdiv_values src0_div21;
struct fastdiv_values src0_div1;
struct fastdiv_values broadcast_rk2;
struct fastdiv_values broadcast_rk3;
struct fastdiv_values broadcast_rv2;
struct fastdiv_values broadcast_rv3;
struct fastdiv_values src3_div2;
struct fastdiv_values src3_div3;
float scale;
float max_bias;
float logit_softcap;
uint32_t n_head_log2;
float m0;
float m1;
uint32_t n_blocks;
size_t size_q_row_padded;
size_t size_k_row_padded;
size_t size_v_row_padded;
size_t size_k_block;
size_t size_v_block;
size_t size_m_block;
bool is_q_fp32;
};
static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) {
assert((size_t) dst % 128 == 0);
assert((size_t) src % 128 == 0);
const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src;
HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst;
const uint32_t nvec = n / VLEN_FP32;
const uint32_t nloe = n % VLEN_FP32;
uint32_t i = 0;
#pragma unroll(4)
for (; i < nvec; ++i) {
vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs));
}
if (nloe) {
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v));
}
}
static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) {
struct htp_fa_context * factx = (struct htp_fa_context *) data;
const struct htp_ops_context * octx = factx->octx;
const struct htp_tensor * q = &octx->src0;
const struct htp_tensor * k = &octx->src1;
const struct htp_tensor * v = &octx->src2;
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
struct htp_tensor * dst = &octx->dst;
const struct htp_tensor * dst = &octx->dst;
const uint32_t neq0 = q->ne[0];
const uint32_t neq1 = q->ne[1];
@@ -304,18 +295,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
const uint32_t nb2 = dst->nb[2];
const uint32_t nb3 = dst->nb[3];
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
if (logit_softcap != 0) {
scale /= logit_softcap;
}
// total rows in q
const uint32_t nr = neq1*neq2*neq3;
@@ -331,18 +310,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
const uint32_t DV = nev0;
const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
const size_t size_q_row_padded = hex_round_up(size_q_row, 128);
const size_t size_k_row = DK * sizeof(__fp16);
const size_t size_v_row = DV * sizeof(__fp16);
const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
const size_t size_k_row_padded = hex_round_up(size_k_row, 128);
const size_t size_v_row_padded = hex_round_up(size_v_row, 128);
const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
const size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
// Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
@@ -351,31 +320,28 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;
const uint32_t n_head = neq2;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);
for (uint32_t ir = ir0; ir < ir1; ++ir) {
const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);
const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
const uint32_t ik3 = fastdiv(iq3, &factx->broadcast_rk3);
const uint32_t ik2 = fastdiv(iq2, &factx->broadcast_rk2);
const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
const uint32_t iv3 = fastdiv(iq3, &factx->broadcast_rv3);
const uint32_t iv2 = fastdiv(iq2, &factx->broadcast_rv2);
// Fetch Q row
const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1);
const uint32_t h = iq2; // head index
const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f;
float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value
HVX_Vector S_vec = hvx_vec_splat_f32(0.0f);
HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY);
// Clear accumulator
hvx_splat_f32_a(spad_a, 0, DV);
@@ -383,40 +349,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
const __fp16 * mp_base = NULL;
if (mask) {
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &factx->src3_div2);
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &factx->src3_div3);
mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
}
const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
// Prefetch first two blocks
for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
for (uint32_t ib = 0; ib < MIN(factx->n_blocks, 2); ++ib) {
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
// K
const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
uint8_t * k_dst = spad_k + (ib % 2) * factx->size_k_block;
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), factx->size_k_row_padded, nbk1, size_k_row, current_block_size);
// V
const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
uint8_t * v_dst = spad_v + (ib % 2) * factx->size_v_block;
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), factx->size_v_row_padded, nbv1, size_v_row, current_block_size);
// Mask
if (mask) {
const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block;
// Mask is 1D contiguous for this row
dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
}
}
const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
if (factx->is_q_fp32) {
hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16
}
for (uint32_t ib = 0; ib < n_blocks; ++ib) {
const HVX_Vector slope_vec = hvx_vec_splat_f16(slope);
for (uint32_t ib = 0; ib < factx->n_blocks; ++ib) {
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
@@ -428,8 +396,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
// Inner loop processing the block from VTCM
uint32_t ic = 0;
const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
// Process in blocks of 32 (VLEN_FP32)
static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
HVX_Vector_x4 scores_x4;
@@ -437,22 +403,18 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
// 1. Compute scores
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
for (int j = 0; j < VLEN_FP32; j += 2) {
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
const uint32_t cur_ic = ic + j;
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
if (is_q_fp32) {
hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
} else {
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
}
const uint8_t * k_ptr = k_base + cur_ic * factx->size_k_row_padded;
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + factx->size_k_row_padded, DK, factx->scale);
}
HVX_Vector scores = *(HVX_Vector *) scores_arr;
// 2. Softcap
if (logit_softcap != 0.0f) {
if (factx->logit_softcap != 0.0f) {
scores = hvx_vec_tanh_f32(scores);
scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap));
scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap);
scores = Q6_Vsf_equals_Vqf32(scores);
}
@@ -460,70 +422,59 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
if (mask) {
const __fp16 * mp = m_base + ic;
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00);
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16);
HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair));
HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec);
scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores);
scores = Q6_Vsf_equals_Vqf32(scores);
}
scores_x4.v[iv] = scores;
v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max);
v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max
}
{
// 4. Online Softmax Update
v_max = hvx_vec_reduce_max_f32(v_max);
float m_block = hvx_vec_get_f32(v_max);
float M_old = M;
float M_new = (m_block > M) ? m_block : M;
M = M_new;
HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec);
HVX_Vector diff_vec = Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec);
HVX_Vector ms_vec = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(diff_vec));
M_vec = M_new_vec;
const float ms = expf(M_old - M_new);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
HVX_Vector scores = scores_x4.v[iv];
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec);
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
// 5. Accumulate V
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
*(HVX_Vector*)p_arr = P;
*(HVX_Vector *) p_arr = P;
for (int j = 0; j < VLEN_FP32; ++j) {
const uint32_t cur_ic = ic2 + j;
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
const uint32_t cur_ic = ic2 + j;
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, p_arr[j], p_arr[j + 1], DV);
}
}
p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
S = S * ms + hvx_vec_get_f32(p_sum_vec);
S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec));
}
// Sync scalars for leftover/next block if needed
float M = hvx_vec_get_f32(M_vec);
float S = hvx_vec_get_f32(S_vec);
// Leftover
for (; ic < current_block_size; ++ic) {
float s_val;
const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
if (is_q_fp32) {
hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
} else {
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
}
if (logit_softcap != 0.0f) {
s_val = logit_softcap * tanhf(s_val);
const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded;
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale);
if (factx->logit_softcap != 0.0f) {
s_val = factx->logit_softcap * tanhf(s_val);
}
if (mask) {
@@ -532,37 +483,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
}
const float Mold = M;
float ms = 1.0f;
float vs = 1.0f;
if (s_val > M) {
M = s_val;
ms = expf(Mold - M);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M);
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
float ms = hvx_vec_get_f32(ms_vec);
S = S * ms + vs;
} else {
vs = expf(s_val - M);
HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M);
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
S += vs;
}
const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
S = S * ms + vs;
}
M_vec = hvx_vec_splat_f32(M);
S_vec = hvx_vec_splat_f32(S);
// Issue DMA for next+1 block (if exists)
if (ib + 2 < n_blocks) {
if (ib + 2 < factx->n_blocks) {
const uint32_t next_ib = ib + 2;
const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
// K
const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
dma_queue_push(dma, dma_make_ptr(k_base, k_src), factx->size_k_row_padded, nbk1, size_k_row, next_block_size);
// V
const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
dma_queue_push(dma, dma_make_ptr(v_base, v_src), factx->size_v_row_padded, nbv1, size_v_row, next_block_size);
// Mask
if (mask) {
@@ -573,20 +529,26 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
}
// sinks
float M = hvx_vec_get_f32(M_vec);
float S = hvx_vec_get_f32(S_vec);
if (sinks) {
const float s = ((float *)((char *) sinks->data))[h];
float ms = 1.0f;
float vs = 1.0f;
if (s > M) {
ms = expf(M - s);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
} else {
vs = expf(s - M);
}
HVX_Vector diff_vec = hvx_vec_splat_f32(M - s);
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
S = S * ms + vs;
float ms = hvx_vec_get_f32(ms_vec);
S = S * ms + vs;
} else {
HVX_Vector diff_vec = hvx_vec_splat_f32(s - M);
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
S += vs;
}
}
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
@@ -609,53 +571,73 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
}
}
static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
flash_attn_ext_f16_thread(octx, i, n);
}
int op_flash_attn_ext(struct htp_ops_context * octx) {
const struct htp_tensor * q = &octx->src0;
const struct htp_tensor * k = &octx->src1;
const struct htp_tensor * v = &octx->src2;
const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
struct htp_tensor * dst = &octx->dst;
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
const struct htp_tensor * dst = &octx->dst;
// Check support
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
k->type != HTP_TYPE_F16 ||
v->type != HTP_TYPE_F16) {
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) {
return HTP_STATUS_NO_SUPPORT;
}
octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
octx->src0_div1 = init_fastdiv_values(q->ne[1]);
struct htp_fa_context factx;
factx.octx = octx;
octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
factx.src0_div1 = init_fastdiv_values(q->ne[1]);
factx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
factx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
factx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
factx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
if (mask) {
octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
factx.src3_div2 = init_fastdiv_values(mask->ne[2]);
factx.src3_div3 = init_fastdiv_values(mask->ne[3]);
}
size_t size_q_row_padded = hex_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
size_t size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
size_t size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
factx.size_q_row_padded = hex_round_up(q->ne[0] * (factx.is_q_fp32 ? 4 : 2), 128);
factx.size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
factx.size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
size_t size_q_block = size_q_row_padded * 1; // single row for now
size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
size_t size_q_block = factx.size_q_row_padded * 1; // single row for now
factx.size_k_block = factx.size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
factx.size_v_block = factx.size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
factx.size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
factx.n_blocks = (k->ne[1] + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
if (logit_softcap != 0.0f) {
scale /= logit_softcap;
}
factx.scale = scale;
factx.max_bias = max_bias;
factx.logit_softcap = logit_softcap;
uint32_t n_head = q->ne[2];
factx.n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2);
factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2);
size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
octx->src0_spad.size_per_thread = size_q_block * 1;
octx->src1_spad.size_per_thread = size_k_block * 2;
octx->src2_spad.size_per_thread = size_v_block * 2;
octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
octx->src1_spad.size_per_thread = factx.size_k_block * 2;
octx->src2_spad.size_per_thread = factx.size_v_block * 2;
octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0;
octx->dst_spad.size_per_thread = size_vkq_acc;
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
@@ -677,7 +659,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
}
return HTP_STATUS_OK;

View File

@@ -264,15 +264,25 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_L2_NORM:
case GGML_OP_SUM_ROWS:
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
case GGML_OP_CLAMP:
case GGML_OP_TRI:
case GGML_OP_DIAG:
case GGML_OP_MUL:
case GGML_OP_ADD:
case GGML_OP_DIV:
case GGML_OP_GLU:
case GGML_OP_SCALE:
case GGML_OP_UNARY:
case GGML_OP_GET_ROWS:
case GGML_OP_CPY:
case GGML_OP_SET_ROWS:
case GGML_OP_SET:
case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_REPEAT:
return true;
default:
return ggml_op_is_empty(op);
@@ -312,7 +322,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
h_add(mrs1, node0);
// that many nodes forward to search for a concurrent node
constexpr int N_FORWARD = 8;
constexpr int N_FORWARD = 64;
for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
if (used[i1]) {

View File

@@ -1159,6 +1159,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
return has_simdgroup_reduction;
case GGML_OP_SET:
case GGML_OP_CPY:
case GGML_OP_DUP:
case GGML_OP_CONT:

View File

@@ -426,6 +426,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
} break;
case GGML_OP_SET:
{
n_fuse = ggml_metal_op_set(ctx, idx);
} break;
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT:
@@ -1609,6 +1613,134 @@ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
const size_t pnb3 = ((const int32_t *) op->op_params)[2];
const size_t offs = ((const int32_t *) op->op_params)[3];
const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
if (!inplace) {
// run a separete kernel to cpy src->dst
// not sure how to avoid this
// TODO: make a simpler cpy_bytes kernel
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
ggml_metal_kargs_cpy args = {
/*.nk0 =*/ ne00,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
ggml_metal_op_concurrency_reset(ctx);
}
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
int64_t nk0 = ne10;
if (ggml_is_quantized(op->src[1]->type)) {
nk0 = ne10/16;
} else if (ggml_is_quantized(op->type)) {
nk0 = ne10/ggml_blck_size(op->type);
}
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
// when rows are small, we can batch them together in a single threadgroup
int nrptg = 1;
// TODO: relax this constraint in the future
if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
if (nth > nk0) {
nrptg = (nth + nk0 - 1)/nk0;
nth = nk0;
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nrptg--;
}
}
}
nth = std::min<int>(nth, nk0);
ggml_metal_kargs_cpy args = {
/*.nk0 =*/ nk0,
/*.ne00 =*/ ne10,
/*.ne01 =*/ ne11,
/*.ne02 =*/ ne12,
/*.ne03 =*/ ne13,
/*.nb00 =*/ nb10,
/*.nb01 =*/ nb11,
/*.nb02 =*/ nb12,
/*.nb03 =*/ nb13,
/*.ne0 =*/ ne10,
/*.ne1 =*/ ne11,
/*.ne2 =*/ ne12,
/*.ne3 =*/ ne13,
/*.nb0 =*/ ggml_element_size(op),
/*.nb1 =*/ pnb1,
/*.nb2 =*/ pnb2,
/*.nb3 =*/ pnb3,
};
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
bid_dst.offs += offs;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
return 1;
}
int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

View File

@@ -59,6 +59,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx);

View File

@@ -92,6 +92,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
#define VK_VENDOR_ID_APPLE 0x106b
#define VK_VENDOR_ID_INTEL 0x8086
#define VK_VENDOR_ID_NVIDIA 0x10de
#define VK_VENDOR_ID_QUALCOMM 0x5143
#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
@@ -687,6 +688,7 @@ struct vk_device_struct {
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_acc_f32;
vk_pipeline pipeline_set_f32;
// [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
vk_pipeline pipeline_add[2][2][2];
@@ -4181,7 +4183,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -5641,6 +5644,10 @@ static void ggml_vk_instance_init() {
driver_priorities[vk::DriverId::eMesaNvk] = 2;
#endif
break;
case VK_VENDOR_ID_QUALCOMM:
driver_priorities[vk::DriverId::eQualcommProprietary] = 1;
driver_priorities[vk::DriverId::eMesaTurnip] = 2;
break;
}
driver_priorities[vk::DriverId::eMesaDozen] = 100;
@@ -8422,6 +8429,8 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
const uint32_t acctype = f32acc ? 4 : 2;
const uint32_t f16vec4 = 8;
const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
const uint32_t qstride = hsk_pad / 4 + 2;
const uint32_t Qf = Br * qstride * f16vec4;
@@ -8438,7 +8447,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
const uint32_t slope = Br * acctype;
const uint32_t total_size = Qf + Psh + sfsh + ksh + slope;
const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
@@ -8815,6 +8824,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_acc_f32;
}
return nullptr;
case GGML_OP_SET:
if (src0->type == src1->type && src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) {
return ctx->device->pipeline_set_f32;
}
return nullptr;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
@@ -9801,16 +9816,16 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
int offset = dst->op_params[3] / 4; // offset in bytes
int nb1 = dst->op_params[0] / src0_type_size; // 4 bytes of float32
int nb2 = dst->op_params[1] / src0_type_size; // 4 bytes of float32
int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32
int offset = dst->op_params[3] / src0_type_size; // offset in bytes
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
0,
0.0f, 0.0f, offset,
});
@@ -12500,6 +12515,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_ACC:
case GGML_OP_SET:
ggml_vk_acc(ctx, compute_ctx, src0, src1, node);
break;
@@ -14960,7 +14976,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
return op->src[0]->type == GGML_TYPE_F32;
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_SET:
return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32);
case GGML_OP_CONCAT:
return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
case GGML_OP_ADD1:
@@ -15611,6 +15630,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_ACC) {
tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
} else if (tensor->op == GGML_OP_SET) {
tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
} else if (tensor->op == GGML_OP_NORM) {
tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_GROUP_NORM) {

View File

@@ -3,6 +3,9 @@
#include "types.glsl"
#include "generic_binary_head.glsl"
// false for SET, true for ACC
layout(constant_id = 1) const bool ACC = true;
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
@@ -13,17 +16,22 @@ void main() {
const uint offset = p.param3;
const uint src1_i = idx - offset;
const uint oz = src1_i / p.nb02;
const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
const uint ox = src1_i % p.nb01;
const uint i3 = src1_i / p.nb03;
const uint rem2 = src1_i - i3 * p.nb03;
const uint i2 = rem2 / p.nb02;
const uint rem1 = rem2 - i2 * p.nb02;
const uint i1 = rem1 / p.nb01;
const uint i0 = rem1 % p.nb01;
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) {
if (ACC) {
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
} else {
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
}
} else {
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]));
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]));
}
}

View File

@@ -130,6 +130,7 @@ void main() {
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
float max_mask = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
@@ -137,12 +138,25 @@ void main() {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
masksh[c][r] = m;
max_mask = max(max_mask, m);
} else {
masksh[c][r] = float(0);
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
float Sf[Br][cols_per_thread];
@@ -260,6 +274,9 @@ void main() {
barrier();
}
// prevent race on tmpsh
barrier();
// reduce across threads
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {

View File

@@ -42,6 +42,8 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
return elem;
}
shared float tmpsh[row_split];
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];
@@ -213,6 +215,19 @@ void main() {
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
}

View File

@@ -176,7 +176,14 @@ void main() {
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
// skip the block if the mask is entirely -inf
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
continue;
}
} else {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
// Don't clamp against nem1 when GQA is enabled
@@ -184,7 +191,14 @@ void main() {
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
// skip the block if the mask is entirely -inf
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
}
}

View File

@@ -181,6 +181,11 @@ class Keys:
SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"
TEMPERATURE_SCALE = "{arch}.attention.temperature_scale"
class Indexer:
HEAD_COUNT = "{arch}.attention.indexer.head_count"
KEY_LENGTH = "{arch}.attention.indexer.key_length"
TOP_K = "{arch}.attention.indexer.top_k"
class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
@@ -425,6 +430,7 @@ class MODEL_ARCH(IntEnum):
CHATGLM = auto()
GLM4 = auto()
GLM4_MOE = auto()
GLM_DSA = auto()
BITNET = auto()
T5 = auto()
T5ENCODER = auto()
@@ -670,6 +676,10 @@ class MODEL_TENSOR(IntEnum):
VISEXP_GATE = auto()
VISEXP_DOWN = auto()
VISEXP_UP = auto()
INDEXER_K_NORM = auto()
INDEXER_PROJ = auto()
INDEXER_ATTN_K = auto()
INDEXER_ATTN_Q_B = auto()
# vision
V_MMPROJ = auto()
V_MMPROJ_FC = auto()
@@ -858,6 +868,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.GLM4: "glm4",
MODEL_ARCH.GLM4_MOE: "glm4moe",
MODEL_ARCH.GLM_DSA: "glm-dsa",
MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.T5: "t5",
MODEL_ARCH.T5ENCODER: "t5encoder",
@@ -1101,6 +1112,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.VISEXP_GATE: "blk.{bid}.vis_gate",
MODEL_TENSOR.VISEXP_DOWN: "blk.{bid}.vis_down",
MODEL_TENSOR.VISEXP_UP: "blk.{bid}.vis_up",
MODEL_TENSOR.INDEXER_K_NORM: "blk.{bid}.indexer.k_norm",
MODEL_TENSOR.INDEXER_PROJ: "blk.{bid}.indexer.proj",
MODEL_TENSOR.INDEXER_ATTN_K: "blk.{bid}.indexer.attn_k",
MODEL_TENSOR.INDEXER_ATTN_Q_B: "blk.{bid}.indexer.attn_q_b",
# vision
MODEL_TENSOR.V_MMPROJ: "mm.{bid}",
MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc",
@@ -2677,6 +2692,47 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.GLM_DSA: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_A,
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_K_B,
MODEL_TENSOR.ATTN_V_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
MODEL_TENSOR.INDEXER_K_NORM,
MODEL_TENSOR.INDEXER_PROJ,
MODEL_TENSOR.INDEXER_ATTN_K,
MODEL_TENSOR.INDEXER_ATTN_Q_B,
# NextN/MTP tensors - preserved but unused
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.BITNET: [
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,

View File

@@ -771,6 +771,15 @@ class GGUFWriter:
def add_value_length_mla(self, length: int) -> None:
self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length)
def add_indexer_head_count(self, count: int) -> None:
self.add_uint32(Keys.Attention.Indexer.HEAD_COUNT.format(arch=self.arch), count)
def add_indexer_key_length(self, length: int) -> None:
self.add_uint32(Keys.Attention.Indexer.KEY_LENGTH.format(arch=self.arch), length)
def add_indexer_top_k(self, top_k: int) -> None:
self.add_uint32(Keys.Attention.Indexer.TOP_K.format(arch=self.arch), top_k)
def add_max_alibi_bias(self, bias: float) -> None:
self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)

View File

@@ -1206,6 +1206,22 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.vision_expert_query_key_value", # cogvlm
),
MODEL_TENSOR.INDEXER_K_NORM: (
"model.layers.{bid}.self_attn.indexer.k_norm", # DSA
),
MODEL_TENSOR.INDEXER_PROJ: (
"model.layers.{bid}.self_attn.indexer.weights_proj", # DSA
),
MODEL_TENSOR.INDEXER_ATTN_K: (
"model.layers.{bid}.self_attn.indexer.wk", # DSA
),
MODEL_TENSOR.INDEXER_ATTN_Q_B: (
"model.layers.{bid}.self_attn.indexer.wq_b", # DSA
),
############################################################################
# TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
MODEL_TENSOR.ENC_OUTPUT_NORM: (

View File

@@ -74,6 +74,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_GLM4, "glm4" },
{ LLM_ARCH_GLM4_MOE, "glm4moe" },
{ LLM_ARCH_GLM_DSA, "glm-dsa" },
{ LLM_ARCH_BITNET, "bitnet" },
{ LLM_ARCH_T5, "t5" },
{ LLM_ARCH_T5ENCODER, "t5encoder" },
@@ -225,6 +226,9 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
{ LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" },
{ LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" },
{ LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -516,6 +520,10 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_VISEXP_FFN_GATE, "blk.%d.vis_gate" },
{ LLM_TENSOR_VISEXP_FFN_DOWN, "blk.%d.vis_down" },
{ LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" },
{ LLM_TENSOR_INDEXER_K_NORM, "blk.%d.indexer.k_norm" },
{ LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" },
{ LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" },
{ LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" },
};
static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
@@ -1657,6 +1665,46 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
};
case LLM_ARCH_GLM_DSA:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_Q_A,
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B,
LLM_TENSOR_ATTN_K_B,
LLM_TENSOR_ATTN_V_B,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_GATE_SHEXP,
LLM_TENSOR_FFN_DOWN_SHEXP,
LLM_TENSOR_FFN_UP_SHEXP,
LLM_TENSOR_FFN_EXP_PROBS_B,
LLM_TENSOR_INDEXER_K_NORM,
LLM_TENSOR_INDEXER_PROJ,
LLM_TENSOR_INDEXER_ATTN_K,
LLM_TENSOR_INDEXER_ATTN_Q_B,
LLM_TENSOR_NEXTN_EH_PROJ,
LLM_TENSOR_NEXTN_EMBED_TOKENS,
LLM_TENSOR_NEXTN_ENORM,
LLM_TENSOR_NEXTN_HNORM,
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
};
case LLM_ARCH_BITNET:
return {
LLM_TENSOR_TOKEN_EMBD,
@@ -2643,6 +2691,10 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_VISEXP_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_VISEXP_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_VISEXP_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
// These tensors only exist in the last layer(s) and are treated as output tensors
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},

View File

@@ -78,6 +78,7 @@ enum llm_arch {
LLM_ARCH_CHATGLM,
LLM_ARCH_GLM4,
LLM_ARCH_GLM4_MOE,
LLM_ARCH_GLM_DSA,
LLM_ARCH_BITNET,
LLM_ARCH_T5,
LLM_ARCH_T5ENCODER,
@@ -229,6 +230,9 @@ enum llm_kv {
LLM_KV_ATTENTION_TEMPERATURE_SCALE,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
LLM_KV_ATTENTION_INDEXER_HEAD_COUNT,
LLM_KV_ATTENTION_INDEXER_KEY_LENGTH,
LLM_KV_ATTENTION_INDEXER_TOP_K,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -517,6 +521,10 @@ enum llm_tensor {
LLM_TENSOR_VISEXP_FFN_GATE,
LLM_TENSOR_VISEXP_FFN_DOWN,
LLM_TENSOR_VISEXP_FFN_UP,
LLM_TENSOR_INDEXER_K_NORM,
LLM_TENSOR_INDEXER_PROJ,
LLM_TENSOR_INDEXER_ATTN_K,
LLM_TENSOR_INDEXER_ATTN_Q_B,
LLM_TENSOR_NEXTN_EH_PROJ,
LLM_TENSOR_NEXTN_EMBED_TOKENS,
LLM_TENSOR_NEXTN_ENORM,

View File

@@ -193,6 +193,11 @@ struct llama_hparams {
std::array<float, LLAMA_MAX_LAYERS> xielu_beta;
std::array<float, LLAMA_MAX_LAYERS> xielu_eps;
// DSA (deepseek sparse attention)
uint32_t indexer_n_head = 0;
uint32_t indexer_head_size = 0;
uint32_t indexer_top_k = 0;
// qwen3vl deepstack
uint32_t n_deepstack_layers = 0;

View File

@@ -137,6 +137,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_300B_A47B: return "300B.A47B";
case LLM_TYPE_310B_A15B: return "310B.A15B";
case LLM_TYPE_355B_A32B: return "355B.A32B";
case LLM_TYPE_744B_A40B: return "744B.A40B";
case LLM_TYPE_E2B: return "E2B";
case LLM_TYPE_E4B: return "E4B";
default: return "?B";
@@ -1822,6 +1823,50 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_GLM_DSA:
{
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false);
// MoE parameters
ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert);
ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used);
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
// deepseek MLA parameters
ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
// DSA parameters
ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head);
ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size);
ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k);
// Expert gating function (GLM-4.5 uses sigmoid)
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
}
// NextN/MTP parameters
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
// TODO: when MTP is implemented, this should probably be updated if needed
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
switch (hparams.n_layer) {
case 79: type = LLM_TYPE_744B_A40B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_BITNET:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -5492,6 +5537,108 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
}
break;
case LLM_ARCH_GLM_DSA:
{
const bool is_mla = hparams.is_mla();
if (!is_mla) {
throw std::runtime_error("GLM_DSA architecture requires MLA");
}
// note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
const int64_t n_embd_head_qk_rope = hparams.n_rot;
const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;
const int64_t q_lora_rank = hparams.n_lora_q;
const int64_t kv_lora_rank = hparams.n_lora_kv;
const int64_t n_ff_exp = hparams.n_ff_exp;
const int64_t n_expert_shared = hparams.n_expert_shared;
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
// try to load output.weight, if not found, use token_embd (tied embeddings)
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
if (!output) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
for (int i = 0; i < n_layer; ++i) {
int flags = 0;
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
// skip all tensors in the NextN layers
// TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later
flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED;
}
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags);
layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, flags);
layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, flags);
layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, flags);
layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags);
layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, flags);
// note: only old legacy GGUF files will have the unsplit wkv_b tensor in
layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, flags);
layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, flags);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, flags);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags);
// DSA indexer
layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags);
layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags);
layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags);
layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags);
layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags);
if (i < (int) hparams.n_layer_dense_lead) {
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags);
} else {
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags);
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
if (n_expert == 0) {
throw std::runtime_error("n_expert must be > 0");
}
if (n_expert_used == 0) {
throw std::runtime_error("n_expert_used must be > 0");
}
// MoE branch
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags);
// Shared expert branch
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, flags);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags);
}
// NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
// Optional tensors
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED);
}
}
} break;
case LLM_ARCH_NEMOTRON:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -7765,7 +7912,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
}
if (arch == LLM_ARCH_DEEPSEEK2) {
if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA) {
LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv);
@@ -7965,7 +8112,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.n_seq_max,
nullptr);
} else if (llm_arch_is_hybrid(arch)) {
// The main difference between hybrid architectures is the
// layer filters, so pick the right one here
llama_memory_hybrid::layer_filter_cb filter_attn = nullptr;
@@ -7990,7 +8136,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* attn_type_v */ params.type_v,
/* attn_v_trans */ !cparams.flash_attn,
/* attn_swa_full */ params.swa_full,
/* attn_kv_size */ cparams.n_ctx,
/* attn_kv_size */ cparams.n_ctx_seq,
/* attn_n_ubatch */ cparams.n_ubatch,
/* attn_n_pad */ 1,
/* recurrent_type_r */ GGML_TYPE_F32,
@@ -8007,7 +8153,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* attn_type_k */ params.type_k,
/* attn_type_v */ params.type_v,
/* attn_v_trans */ !cparams.flash_attn,
/* attn_kv_size */ cparams.n_ctx,
/* attn_kv_size */ cparams.n_ctx_seq,
/* attn_n_pad */ 1,
/* attn_n_swa */ hparams.n_swa,
/* attn_swa_type */ hparams.swa_type,
@@ -8338,6 +8484,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
llm = std::make_unique<llm_build_deepseek>(*this, params);
} break;
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_GLM_DSA:
{
llm = std::make_unique<llm_build_deepseek2>(*this, params);
} break;
@@ -8739,6 +8886,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_MISTRAL3:
case LLM_ARCH_LLAMA_EMBED:
case LLM_ARCH_MAINCODER:
case LLM_ARCH_GLM_DSA:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2

View File

@@ -130,6 +130,7 @@ enum llm_type {
LLM_TYPE_300B_A47B, // Ernie MoE big
LLM_TYPE_310B_A15B, // /MiMo-V2-Flash
LLM_TYPE_355B_A32B, // GLM-4.5
LLM_TYPE_744B_A40B, // GLM-5
LLM_TYPE_E2B,
LLM_TYPE_E4B,
};
@@ -429,6 +430,13 @@ struct llama_layer {
struct ggml_tensor * ssm_g_b = nullptr;
struct ggml_tensor * ssm_o_norm = nullptr;
// DSA (deepseek sparse attention)
struct ggml_tensor * indexer_k_norm = nullptr;
struct ggml_tensor * indexer_k_norm_b = nullptr;
struct ggml_tensor * indexer_proj = nullptr;
struct ggml_tensor * indexer_attn_k = nullptr;
struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias
struct llama_layer_posnet posnet;
struct llama_layer_convnext convnext;

View File

@@ -45,7 +45,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
int effective_n_layers = hparams.n_layer - hparams.nextn_predict_layers;
for (int il = 0; il < effective_n_layers; ++il) {
ggml_tensor * inpSA = inpL;
// norm
@@ -188,7 +189,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
}
}
if (il == n_layer - 1 && inp_out_ids) {
if (il == effective_n_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}

View File

@@ -41,8 +41,11 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t
conv_x->nb[1], conv_x->nb[2], n_seq_tokens * conv_x->nb[0]);
ggml_build_forward_expand(gf,
ggml_cpy(ctx0, last_conv_x,
ggml_view_1d(ctx0, conv_states_all, conv_state_size * n_seqs,
(kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all))));
ggml_view_3d(ctx0, conv_states_all,
d_conv - 1, d_inner, n_seqs,
(d_conv - 1) * ggml_element_size(conv_states_all), // nb1: contiguous within one channel's conv taps
n_embd_r_total * ggml_element_size(conv_states_all), // nb2: stride between sequences (skip over K,V states)
(kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all)))); // offset to first seq's Q/K/V state
// Reshape conv weight: GGUF [d_conv, 1, d_inner, 1] -> ggml_ssm_conv expects [d_conv, d_inner]
// GGUF stores as [d_conv, 1, d_inner, 1] with memory layout w[conv_step + channel * d_conv]
// vLLM stores as [d_inner, d_conv] with memory layout w[channel * d_conv + conv_step]

View File

@@ -1,16 +1,10 @@
#if defined(_MSC_VER)
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
#endif
#include "unicode.h"
#include "unicode-data.h"
#include <algorithm>
#include <cassert>
#include <codecvt>
#include <cstddef>
#include <cstdint>
#include <locale>
#include <map>
#include <regex>
#include <stdexcept>
@@ -199,27 +193,6 @@ static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
return map;
}
static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
#if defined(__clang__)
// disable C++17 deprecation warning for std::codecvt_utf8
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
#elif defined(__GNUC__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
#if defined(__clang__)
# pragma clang diagnostic pop
#elif defined(__GNUC__)
# pragma GCC diagnostic pop
#endif
return conv.from_bytes(s);
}
static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
std::vector<std::string> bpe_encoded_words;
for (const auto & word : bpe_words) {
@@ -1028,10 +1001,10 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
break;
}
}
const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
if (use_collapsed) {
// sanity-check that the original regex does not contain any non-ASCII characters
const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
for (size_t i = 0; i < cpts_regex.size(); ++i) {
if (cpts_regex[i] >= 128) {
throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
@@ -1087,7 +1060,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
} else {
// no unicode category used, we can use std::wregex directly
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
std::wstring wregex_expr(cpts_regex.begin(), cpts_regex.end());
// std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
std::wstring wtext(cpts.begin(), cpts.end());

View File

@@ -2786,9 +2786,10 @@ struct test_set : public test_case {
const ggml_type type_dst;
const std::array<int64_t, 4> ne;
const int dim;
const bool inplace;
std::string vars() override {
return VARS_TO_STR4(type_src, type_dst, ne, dim);
return VARS_TO_STR5(type_src, type_dst, ne, dim, inplace);
}
size_t op_size(ggml_tensor * t) override {
@@ -2796,8 +2797,8 @@ struct test_set : public test_case {
}
test_set(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {6, 5, 4, 3}, int dim = 1)
: type_src(type_src), type_dst(type_dst), ne(ne), dim(dim) {}
std::array<int64_t, 4> ne = {6, 5, 4, 3}, int dim = 1, bool inplace = false)
: type_src(type_src), type_dst(type_dst), ne(ne), dim(dim), inplace(inplace) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
@@ -2808,7 +2809,7 @@ struct test_set : public test_case {
for (int i = 0; i < dim; ++i) {
ne_dst[i] *= 2;
}
ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());
ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());
ggml_set_param(dst);
ggml_set_name(dst, "dst");
@@ -2816,9 +2817,16 @@ struct test_set : public test_case {
for (int i = 0; i < dim; ++i) {
offset += ((ne_dst[i] - ne[i])/2)*dst->nb[i];
}
ggml_tensor * out = ggml_set(ctx, dst, src,
// The backward pass requires setting a contiguous region:
src->nb[1], src->nb[2], src->nb[3], offset);
ggml_tensor * out;
if (inplace) {
out = ggml_set_inplace(ctx, dst, src,
// The backward pass requires setting a contiguous region:
src->nb[1], src->nb[2], src->nb[3], offset);
} else {
out = ggml_set(ctx, dst, src,
// The backward pass requires setting a contiguous region:
src->nb[1], src->nb[2], src->nb[3], offset);
}
ggml_set_name(out, "out");
return out;
@@ -5839,26 +5847,46 @@ struct test_acc : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne_a;
const std::array<int64_t, 4> ne_b;
const int64_t stride_dim;
std::string vars() override {
return VARS_TO_STR3(type, ne_a, ne_b);
return VARS_TO_STR4(type, ne_a, ne_b, stride_dim);
}
test_acc(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_a = {256, 17, 1, 1},
std::array<int64_t, 4> ne_b = {256, 16, 1, 1})
: type(type), ne_a(ne_a), ne_b(ne_b) {}
std::array<int64_t, 4> ne_a = {256, 17, 2, 3},
std::array<int64_t, 4> ne_b = {256, 16, 2, 3},
uint64_t stride_dim = -1)
: type(type), ne_a(ne_a), ne_b(ne_b), stride_dim(stride_dim) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
ggml_set_param(b);
ggml_tensor * b;
if (stride_dim == 1 || stride_dim == 2 || stride_dim == 3) {
// Create a larger tensor and take a view at a non-zero offset.
// This tests that the backend correctly handles b's data offset
std::array<int64_t, 4> ne_b_pad = {ne_b[0], ne_b[1], ne_b[2], ne_b[3]};
ne_b_pad[stride_dim] += 1;
ggml_tensor * b_pad = ggml_new_tensor(ctx, type, 4, ne_b_pad.data());
ggml_set_param(b_pad);
ggml_set_name(b_pad, "b_pad");
// View that skips the first row, so b has a non-zero byte offset
b = ggml_view_4d(ctx, b_pad,
ne_b[0], ne_b[1], ne_b[2], ne_b[3],
b_pad->nb[1], b_pad->nb[2], b_pad->nb[3],
b_pad->nb[1]);
} else {
b = ggml_new_tensor(ctx, type, 4, ne_b.data());
ggml_set_param(b);
}
ggml_set_name(b, "b");
ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]);
// When ne_b[0] < ne_a[0], a->nb[1] != b->nb[1], so the stride
// parameters to ggml_acc don't match b's natural stride.
ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(out, "out");
return out;
@@ -7428,11 +7456,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {1, 2, 0, 3}));
for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim));
test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim, false));
test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim, true));
}
for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim, false));
test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim, true));
}
// same-type copy
@@ -8160,7 +8190,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {64, 64, 320, 1}));
test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {9, 9, 1280, 1}));
test_cases.emplace_back(new test_acc());
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 1, 1}, {256, 16, 1, 1}, -1));
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, -1));
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, -1));
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, 1));
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, 2));
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {64, 16, 2, 3}, 3));
test_cases.emplace_back(new test_pad());
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {33, 17, 2, 1}, 4, 3, true)); // circular
test_cases.emplace_back(new test_pad_ext());
@@ -8595,6 +8630,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate
// acc
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 1, 1}, {256, 16, 1, 1}, -1));
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, -1));
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, -1));
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, 1));
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, 2));
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {64, 16, 2, 3}, 3));
return test_cases;
}

View File

@@ -52,6 +52,7 @@ struct cli_context {
json messages = json::array();
std::vector<raw_buffer> input_files;
task_params defaults;
bool verbose_prompt;
// thread for showing "loading" animation
std::atomic<bool> loading_show;
@@ -66,6 +67,8 @@ struct cli_context {
defaults.stream = true; // make sure we always use streaming mode
defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way
// defaults.return_progress = true; // TODO: show progress
verbose_prompt = params.verbose_prompt;
}
std::string generate_completion(result_timings & out_timings) {
@@ -91,6 +94,12 @@ struct cli_context {
rd.post_task({std::move(task)});
}
if (verbose_prompt) {
console::set_display(DISPLAY_TYPE_PROMPT);
console::log("%s\n\n", chat_params.prompt.c_str());
console::set_display(DISPLAY_TYPE_RESET);
}
// wait for first result
console::spinner::start();
server_task_result_ptr result = rd.next(should_stop);

View File

@@ -1,5 +1,6 @@
<script lang="ts">
import { goto } from '$app/navigation';
import { base } from '$app/paths';
import {
chatStore,
pendingEditMessageId,
@@ -119,7 +120,7 @@
const conversationDeleted = await removeSystemPromptPlaceholder(message.id);
if (conversationDeleted) {
goto('/');
goto(`${base}/`);
}
return;
@@ -220,7 +221,7 @@
const conversationDeleted = await removeSystemPromptPlaceholder(message.id);
isEditing = false;
if (conversationDeleted) {
goto('/');
goto(`${base}/`);
}
return;
}

View File

@@ -3,6 +3,7 @@
import { BadgeChatStatistic } from '$lib/components/app';
import * as Tooltip from '$lib/components/ui/tooltip';
import { ChatMessageStatsView } from '$lib/enums';
import { formatPerformanceTime } from '$lib/utils/formatters';
interface Props {
predictedTokens?: number;
@@ -57,8 +58,8 @@
);
let tokensPerSecond = $derived(hasGenerationStats ? (predictedTokens! / predictedMs!) * 1000 : 0);
let timeInSeconds = $derived(
predictedMs !== undefined ? (predictedMs / 1000).toFixed(2) : '0.00'
let formattedTime = $derived(
predictedMs !== undefined ? formatPerformanceTime(predictedMs) : '0s'
);
let promptTokensPerSecond = $derived(
@@ -67,15 +68,15 @@
: undefined
);
let promptTimeInSeconds = $derived(
promptMs !== undefined ? (promptMs / 1000).toFixed(2) : undefined
let formattedPromptTime = $derived(
promptMs !== undefined ? formatPerformanceTime(promptMs) : undefined
);
let hasPromptStats = $derived(
promptTokens !== undefined &&
promptMs !== undefined &&
promptTokensPerSecond !== undefined &&
promptTimeInSeconds !== undefined
formattedPromptTime !== undefined
);
// In live mode, generation tab is disabled until we have generation stats
@@ -142,7 +143,7 @@
<BadgeChatStatistic
class="bg-transparent"
icon={Clock}
value="{timeInSeconds}s"
value={formattedTime}
tooltipLabel="Generation time"
/>
<BadgeChatStatistic
@@ -161,7 +162,7 @@
<BadgeChatStatistic
class="bg-transparent"
icon={Clock}
value="{promptTimeInSeconds}s"
value={formattedPromptTime ?? '0s'}
tooltipLabel="Prompt processing time"
/>
<BadgeChatStatistic

View File

@@ -0,0 +1,88 @@
<script lang="ts">
import type { Snippet } from 'svelte';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import { cn } from '$lib/components/ui/utils';
import { SearchInput } from '$lib/components/app';
interface Props {
open?: boolean;
onOpenChange?: (open: boolean) => void;
placeholder?: string;
searchValue?: string;
onSearchChange?: (value: string) => void;
onSearchKeyDown?: (event: KeyboardEvent) => void;
align?: 'start' | 'center' | 'end';
contentClass?: string;
emptyMessage?: string;
isEmpty?: boolean;
disabled?: boolean;
trigger: Snippet;
children: Snippet;
footer?: Snippet;
}
let {
open = $bindable(false),
onOpenChange,
placeholder = 'Search...',
searchValue = $bindable(''),
onSearchChange,
onSearchKeyDown,
align = 'start',
contentClass = 'w-72',
emptyMessage = 'No items found',
isEmpty = false,
disabled = false,
trigger,
children,
footer
}: Props = $props();
function handleOpenChange(newOpen: boolean) {
open = newOpen;
if (!newOpen) {
searchValue = '';
onSearchChange?.('');
}
onOpenChange?.(newOpen);
}
</script>
<DropdownMenu.Root bind:open onOpenChange={handleOpenChange}>
<DropdownMenu.Trigger
{disabled}
onclick={(e) => {
e.preventDefault();
e.stopPropagation();
}}
>
{@render trigger()}
</DropdownMenu.Trigger>
<DropdownMenu.Content {align} class={cn(contentClass, 'pt-0')}>
<div class="sticky top-0 z-10 mb-2 bg-popover p-1 pt-2">
<SearchInput
{placeholder}
bind:value={searchValue}
onInput={onSearchChange}
onKeyDown={onSearchKeyDown}
/>
</div>
<div class={cn('overflow-y-auto')}>
{@render children()}
{#if isEmpty}
<div class="px-2 py-3 text-center text-sm text-muted-foreground">{emptyMessage}</div>
{/if}
</div>
{#if footer}
<DropdownMenu.Separator />
{@render footer()}
{/if}
</DropdownMenu.Content>
</DropdownMenu.Root>

View File

@@ -486,6 +486,8 @@
text-decoration: underline;
text-underline-offset: 2px;
transition: color 0.2s ease;
overflow-wrap: anywhere;
word-break: break-all;
}
div :global(a:hover) {

View File

@@ -51,3 +51,75 @@ export function formatNumber(num: number | unknown): string {
return num.toLocaleString();
}
/**
* Format JSON string with pretty printing (2-space indentation)
* Returns original string if parsing fails
*
* @param jsonString - JSON string to format
* @returns Pretty-printed JSON string or original if invalid
*/
export function formatJsonPretty(jsonString: string): string {
try {
const parsed = JSON.parse(jsonString);
return JSON.stringify(parsed, null, 2);
} catch {
return jsonString;
}
}
/**
* Format time as HH:MM:SS in 24-hour format
*
* @param date - Date object to format
* @returns Formatted time string (HH:MM:SS)
*/
export function formatTime(date: Date): string {
return date.toLocaleTimeString('en-US', {
hour12: false,
hour: '2-digit',
minute: '2-digit',
second: '2-digit'
});
}
/**
* Formats milliseconds to a human-readable time string for performance metrics.
* Examples: "4h 12min 54s", "12min 34s", "45s", "0.5s"
*
* @param ms - Time in milliseconds
* @returns Formatted time string
*/
export function formatPerformanceTime(ms: number): string {
if (ms < 0) return '0s';
const totalSeconds = ms / 1000;
if (totalSeconds < 1) {
return `${totalSeconds.toFixed(1)}s`;
}
if (totalSeconds < 10) {
return `${totalSeconds.toFixed(1)}s`;
}
const hours = Math.floor(totalSeconds / 3600);
const minutes = Math.floor((totalSeconds % 3600) / 60);
const seconds = Math.floor(totalSeconds % 60);
const parts: string[] = [];
if (hours > 0) {
parts.push(`${hours}h`);
}
if (minutes > 0) {
parts.push(`${minutes}min`);
}
if (seconds > 0 || parts.length === 0) {
parts.push(`${seconds}s`);
}
return parts.join(' ');
}

View File

@@ -2,7 +2,6 @@
import { defineMeta } from '@storybook/addon-svelte-csf';
import ChatForm from '$lib/components/app/chat/ChatForm/ChatForm.svelte';
import { expect } from 'storybook/test';
import { mockServerProps, mockConfigs } from './fixtures/storybook-mocks';
import jpgAsset from './fixtures/assets/1.jpg?url';
import svgAsset from './fixtures/assets/hf-logo.svg?url';
import pdfAsset from './fixtures/assets/example.pdf?raw';
@@ -46,8 +45,6 @@
name="Default"
args={{ class: 'max-w-[56rem] w-[calc(100vw-2rem)]' }}
play={async ({ canvas, userEvent }) => {
mockServerProps(mockConfigs.noModalities);
const textarea = await canvas.findByRole('textbox');
const submitButton = await canvas.findByRole('button', { name: 'Send' });
@@ -66,73 +63,11 @@
const fileInput = document.querySelector('input[type="file"]');
await expect(fileInput).not.toHaveAttribute('accept');
// Open file attachments dropdown
const fileUploadButton = canvas.getByText('Attach files');
await userEvent.click(fileUploadButton);
// Check dropdown menu items are disabled (no modalities)
const imagesButton = document.querySelector('.images-button');
const audioButton = document.querySelector('.audio-button');
await expect(imagesButton).toHaveAttribute('data-disabled');
await expect(audioButton).toHaveAttribute('data-disabled');
// Close dropdown by pressing Escape
await userEvent.keyboard('{Escape}');
}}
/>
<Story name="Loading" args={{ class: 'max-w-[56rem] w-[calc(100vw-2rem)]', isLoading: true }} />
<Story
name="VisionModality"
args={{ class: 'max-w-[56rem] w-[calc(100vw-2rem)]' }}
play={async ({ canvas, userEvent }) => {
mockServerProps(mockConfigs.visionOnly);
// Open file attachments dropdown and verify it works
const fileUploadButton = canvas.getByText('Attach files');
await userEvent.click(fileUploadButton);
// Verify dropdown menu items exist
const imagesButton = document.querySelector('.images-button');
const audioButton = document.querySelector('.audio-button');
await expect(imagesButton).toBeInTheDocument();
await expect(audioButton).toBeInTheDocument();
// Close dropdown by pressing Escape
await userEvent.keyboard('{Escape}');
console.log(' Vision modality: Dropdown menu verified');
}}
/>
<Story
name="AudioModality"
args={{ class: 'max-w-[56rem] w-[calc(100vw-2rem)]' }}
play={async ({ canvas, userEvent }) => {
mockServerProps(mockConfigs.audioOnly);
// Open file attachments dropdown and verify it works
const fileUploadButton = canvas.getByText('Attach files');
await userEvent.click(fileUploadButton);
// Verify dropdown menu items exist
const imagesButton = document.querySelector('.images-button');
const audioButton = document.querySelector('.audio-button');
await expect(imagesButton).toBeInTheDocument();
await expect(audioButton).toBeInTheDocument();
// Close dropdown by pressing Escape
await userEvent.keyboard('{Escape}');
console.log(' Audio modality: Dropdown menu verified');
}}
/>
<Story
name="FileAttachments"
args={{
@@ -140,8 +75,6 @@
uploadedFiles: fileAttachments
}}
play={async ({ canvas }) => {
mockServerProps(mockConfigs.bothModalities);
const jpgAttachment = canvas.getByAltText('1.jpg');
const svgAttachment = canvas.getByAltText('hf-logo.svg');
const pdfFileExtension = canvas.getByText('PDF');

View File

@@ -39,7 +39,7 @@ if (LLAMA_BUILD_BORINGSSL)
set(FIPS OFF CACHE BOOL "Enable FIPS (BoringSSL)")
set(BORINGSSL_GIT "https://boringssl.googlesource.com/boringssl" CACHE STRING "BoringSSL git repository")
set(BORINGSSL_VERSION "0.20260204.0" CACHE STRING "BoringSSL version")
set(BORINGSSL_VERSION "0.20260211.0" CACHE STRING "BoringSSL version")
message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}")