Compare commits

...

16 Commits

Author SHA1 Message Date
Aleksander Grygier
f486ce9f30 (webui) REFACTOR: UI primitives and polish (#19551)
* webui: UI primitives and polish (non-MCP)

* chore: update webui build output
2026-02-12 12:21:00 +01:00
Aleksander Grygier
38adc7d469 WebUI Architecture Cleanup (#19541)
* webui: architecture foundation (non-MCP core refactors)

* chore: update webui build output
2026-02-12 11:22:27 +01:00
Georgi Gerganov
3b3a948134 metal : update sum_rows kernel to support float4 (#19524) 2026-02-12 11:35:28 +02:00
Mario Limonciello
6845f7f87f Add a workaround for compilation with ROCWMMA_FATTN and gfx9 (#19461)
There is an upstream problem [1] with AMD's LLVM 22 fork and
rocWMMA 2.2.0 causing compilation issues on devices without
native fp16 support (CDNA devices).

The specialized types aren't resolved properly:
```
/opt/rocm/include/rocwmma/internal/mfma_impl.hpp:2549:37: error: ambiguous partial specializations of 'amdgcn_mfma<__half, __half, __half, 16, 16, 16>'
 2549 |             using ARegsT = typename Impl::ARegsT;
```

Add a workaround to explicitly declare the types and cast when
compiling with HIP and ROCWMMA_FATTN [2].  When this is actually
fixed upstream some guards can be used to detect and wrap the
version that has the fix to only apply when necessary.

Link: https://github.com/ROCm/rocm-libraries/issues/4398 [1]
Link: https://github.com/ggml-org/llama.cpp/issues/19269 [2]

Signed-off-by: Mario Limonciello <mario.limonciello@amd.com>
2026-02-12 09:38:35 +01:00
RichardScottOZ
fa16e517a3 server : fix typo in README.md for features list (#19510)
Removes an extra "l" from the word "full".
2026-02-12 08:56:25 +01:00
TriDefender
313493de53 docs : update path in snapdragon README.md (#19533)
The paths changed, so the original example didn't work.
2026-02-12 08:13:51 +01:00
Max Krasnyansky
b1ff83bbb0 hexagon: further optimization and tuning of matmul and dot kernels (#19407)
* ggml-hexagon: implement 2x2 matmul kernel

* hexmm: implement vec_dot_rx2x2 for Q8_0 and MXFP4

* hexagon: fix editor config failures

* hexagon: refactor matmul ops to use context struct and remove wrappers

Also implement vec_dot_f16 2x2

* hexagon: refactor dyn quantizers to use mmctx

* hexagon: remove mm fastdiv from op_ctx

* hexagon: refactor matmul entry point to reduce code duplication

---------

Co-authored-by: Trivikram Reddy <tamarnat@qti.qualcomm.com>
2026-02-11 23:04:27 -08:00
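A 2x2 matmul kernel in the sense described above computes a 2x2 tile of the output per pass over the reduction dimension, so every loaded element of A and B is reused twice. A plain-Python sketch of the tiling idea (illustrative only, not the Hexagon intrinsic code; assumes dimensions divisible by 2):

```python
def matmul_2x2_tiles(A, B):
    """Multiply row-major matrices A (n x k) and B (k x m) using 2x2 output tiles."""
    n, k, m = len(A), len(A[0]), len(B[0])
    C = [[0.0] * m for _ in range(n)]
    for i in range(0, n, 2):
        for j in range(0, m, 2):
            acc = [[0.0, 0.0], [0.0, 0.0]]  # 2x2 accumulator tile
            for p in range(k):
                # load two elements from A and two from B, reuse each twice
                a0, a1 = A[i][p], A[i + 1][p]
                b0, b1 = B[p][j], B[p][j + 1]
                acc[0][0] += a0 * b0; acc[0][1] += a0 * b1
                acc[1][0] += a1 * b0; acc[1][1] += a1 * b1
            C[i][j],     C[i][j + 1]     = acc[0][0], acc[0][1]
            C[i + 1][j], C[i + 1][j + 1] = acc[1][0], acc[1][1]
    return C
```

The same reuse argument is what makes the 2x2 `vec_dot_rx2x2` variants mentioned in the commit pay off on bandwidth-bound hardware.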
Adrien Gallouët
4ae1b7517a common : replace deprecated codecvt using parse_utf8_codepoint (#19517)
Signed-off-by: Adrien Gallouët <adrien@gallouet.fr>
2026-02-12 07:27:52 +01:00
lhez
4d3daf80f8 opencl: add general Q6_K mm and Q4_K mv (#19347)
* opencl: add general q6_k mm

* opencl: refine condition for q6_K mm

* opencl: add general q4_K mv

* opencl: fix whitespace
2026-02-11 10:33:13 -08:00
Georgi Gerganov
914dde72ba ggml : unary ops support non-cont src0 + metal F16 unary ops (#19511)
* ggml : unary ops support non-cont src0

* metal : support F16 unary ops + fix ELU
2026-02-11 18:58:43 +02:00
Daniel Bevenius
3136a849db common : remove unused token util functions (#19506)
This commit removes two unused functions, `common_lcp` and `common_lcs`.
The last usage of these functions was removed in
commit 33eff40240 ("server : vision support
via libmtmd"); they are no longer used anywhere in the codebase.
2026-02-11 17:41:35 +01:00
AesSedai
e463bbdf65 model: Add Kimi-K2.5 support (#19170)
* Move dequant_model to after the text_config merge
Add new kimi-k2.5 keys to mtmd convert
Update V_MMPROJ tensor mapping for new mm_projector.proj keys
Update V_M_IMP_NORM for new mm_projector.pre_norm key

* Fix a couple of oversights

* Add image support for Kimi-K2.5

* Revert changes to KimiVLForConditionalGeneration

* Fix an assert crash

* Fix permute swapping w / h by accident

* Kimi-K2.5: Use merged QKV for vision

* Kimi-K2.5: pre-convert vision QK to use build_rope_2d

* Kimi-K2.5: support non-interleaved rope for vision

* Kimi-K2.5: fix min / max pixel

* Kimi-K2.5: remove v/o permutes, unnecessary

* Kimi-K2.5: update permute name to match

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Kimi-K2.5: replace build_rope_2d ggml_cont with ggml_view_3d pointers

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-11 16:47:30 +01:00
Daniel Bevenius
53de59f67d build : fix case in dSYMs path for build-macos [no ci] (#19515)
This commit fixes an incorrect dSYMs path in which the 's' was uppercase
by mistake.

The motivation for fixing this is that this can cause issues on case
sensitive operating systems.

Refs: https://github.com/ggml-org/whisper.cpp/pull/3630
2026-02-11 14:02:29 +01:00
Georgi Gerganov
9ab072ebbe metal : extend l2_norm support for non-cont src0 (#19502) 2026-02-11 14:53:19 +02:00
Johannes Gäßler
ada90bf2ba docs: ban AI for issues and discussions [no CI] (#19512) 2026-02-11 12:49:40 +01:00
Adrien Gallouët
0c1f39a9ae common : improve download error reporting (#19491)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-11 09:27:55 +01:00
104 changed files with 6714 additions and 1036 deletions

View File

@@ -20,7 +20,7 @@ If AI is used to generate any portion of the code, contributors must adhere to t
1. Explicitly disclose the manner in which AI was employed.
2. Perform a comprehensive manual review prior to submitting the pull request.
3. Be prepared to explain every line of code they submitted when asked about it by a maintainer.
-4. Using AI to write pull request descriptions or to respond to human reviewers is strictly prohibited.
+4. It is strictly prohibited to use AI to write your posts for you (bug reports, feature requests, pull request descriptions, Github discussions, responding to humans, ...).
For more info, please refer to the [AGENTS.md](AGENTS.md) file.

View File

@@ -534,7 +534,7 @@ xcodebuild -create-xcframework \
-framework $(pwd)/build-ios-device/framework/llama.framework \
-debug-symbols $(pwd)/build-ios-device/dSYMs/llama.dSYM \
-framework $(pwd)/build-macos/framework/llama.framework \
- -debug-symbols $(pwd)/build-macos/dSYMS/llama.dSYM \
+ -debug-symbols $(pwd)/build-macos/dSYMs/llama.dSYM \
-framework $(pwd)/build-visionos/framework/llama.framework \
-debug-symbols $(pwd)/build-visionos/dSYMs/llama.dSYM \
-framework $(pwd)/build-visionos-sim/framework/llama.framework \

View File

@@ -1,7 +1,3 @@
#if defined(_MSC_VER)
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
#endif
#include "ggml.h"
#include "gguf.h"
@@ -9,12 +5,12 @@
#include "log.h"
#include "llama.h"
#include "sampling.h"
#include "unicode.h"
#include <algorithm>
#include <cinttypes>
#include <climits>
#include <cmath>
#include <codecvt>
#include <chrono>
#include <cstdarg>
#include <cstring>
@@ -706,45 +702,28 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
return false;
}
-std::u32string filename_utf32;
-try {
-#if defined(__clang__)
-// disable C++17 deprecation warning for std::codecvt_utf8
-# pragma clang diagnostic push
-# pragma clang diagnostic ignored "-Wdeprecated-declarations"
-#elif defined(__GNUC__)
-# pragma GCC diagnostic push
-# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
-#endif
+size_t offset = 0;
+while (offset < filename.size()) {
+utf8_parse_result result = parse_utf8_codepoint(filename, offset);
-std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
-#if defined(__clang__)
-# pragma clang diagnostic pop
-#elif defined(__GNUC__)
-# pragma GCC diagnostic pop
-#endif
-filename_utf32 = converter.from_bytes(filename);
-// If the reverse conversion mismatches, it means overlong UTF-8 sequences were used,
-// or invalid encodings were encountered. Reject such attempts
-std::string filename_reencoded = converter.to_bytes(filename_utf32);
-if (filename_reencoded != filename) {
+if (result.status != utf8_parse_result::SUCCESS) {
 return false;
 }
-} catch (const std::exception &) {
-return false;
-}
+uint32_t c = result.codepoint;
-// Check for forbidden codepoints:
-// - Control characters
-// - Unicode equivalents of illegal characters
-// - UTF-16 surrogate pairs
-// - UTF-8 replacement character
-// - Byte order mark (BOM)
-// - Illegal characters: / \ : * ? " < > |
-for (char32_t c : filename_utf32) {
+if ((result.bytes_consumed == 2 && c < 0x80) ||
+(result.bytes_consumed == 3 && c < 0x800) ||
+(result.bytes_consumed == 4 && c < 0x10000)) {
+return false;
+}
+// Check for forbidden codepoints:
+// - Control characters
+// - Unicode equivalents of illegal characters
+// - UTF-16 surrogate pairs
+// - UTF-8 replacement character
+// - Byte order mark (BOM)
+// - Illegal characters: / \ : * ? " < > |
 if (c <= 0x1F // Control characters (C0)
 || c == 0x7F // Control characters (DEL)
 || (c >= 0x80 && c <= 0x9F) // Control characters (C1)
@@ -752,6 +731,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
 || c == 0x2215 // Division Slash (forward slash equivalent)
 || c == 0x2216 // Set Minus (backslash equivalent)
 || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
+|| c > 0x10FFFF // Max Unicode limit
 || c == 0xFFFD // Replacement Character (UTF-8)
 || c == 0xFEFF // Byte Order Mark (BOM)
 || c == ':' || c == '*' // Illegal characters
@@ -762,6 +742,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
 // Subdirectories not allowed, reject path separators
 return false;
 }
+offset += result.bytes_consumed;
 }
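The replacement logic above parses one codepoint at a time and rejects overlong encodings by checking that the decoded value actually needs as many bytes as were consumed. A rough Python sketch of that check (the real `parse_utf8_codepoint` returns a status struct; the tuple-or-None interface here is a simplification, and the forbidden-codepoint filtering is out of scope):

```python
def parse_utf8_codepoint(data: bytes, offset: int):
    """Return (codepoint, bytes_consumed), or None on invalid/overlong input."""
    b0 = data[offset]
    if b0 < 0x80:                       # 1-byte ASCII
        return b0, 1
    if b0 >> 5 == 0b110:                # 2-byte lead
        n, cp = 2, b0 & 0x1F
    elif b0 >> 4 == 0b1110:             # 3-byte lead
        n, cp = 3, b0 & 0x0F
    elif b0 >> 3 == 0b11110:            # 4-byte lead
        n, cp = 4, b0 & 0x07
    else:                               # stray continuation byte or invalid lead
        return None
    if offset + n > len(data):          # truncated sequence
        return None
    for b in data[offset + 1 : offset + n]:
        if b >> 6 != 0b10:              # continuation bytes must be 10xxxxxx
            return None
        cp = (cp << 6) | (b & 0x3F)
    # overlong check, mirroring the bytes_consumed test in the diff:
    # an N-byte sequence must encode a value that truly requires N bytes
    if (n == 2 and cp < 0x80) or (n == 3 and cp < 0x800) or (n == 4 and cp < 0x10000):
        return None
    return cp, n
```

For example, the overlong two-byte encoding of '/' (`b'\xc0\xaf'`) is rejected, which is exactly the class of path-traversal trick the old reencode-and-compare round trip was guarding against.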
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
@@ -1469,66 +1450,6 @@ void common_batch_add(
batch.n_tokens++;
}
//
// Token utils
//
size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
return i;
}
size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
// check for empty sequences
if (a.empty() || b.empty()) {
return 0;
}
// get the lengths of the input sequences
size_t a_len = a.size();
size_t b_len = b.size();
// initialize the maximum length of the longest common subsequence (LCS)
size_t max_length = 0;
// use two rows instead of a 2D matrix to optimize space
std::vector<size_t> prev_row(b_len + 1, 0);
std::vector<size_t> curr_row(b_len + 1, 0);
// iterate through the elements of a
for (size_t i = 1; i <= a_len; i++) {
// iterate through the elements of b
for (size_t j = 1; j <= b_len; j++) {
// if elements at the current positions match
if (a[i - 1] == b[j - 1]) {
// if it's the first element of either sequences, set LCS length to 1
if (i == 1 || j == 1) {
curr_row[j] = 1;
} else {
// increment LCS length by 1 compared to the previous element
curr_row[j] = prev_row[j - 1] + 1;
}
// update max_length if necessary
if (curr_row[j] > max_length) {
max_length = curr_row[j];
}
} else {
// reset LCS length if elements don't match
curr_row[j] = 0;
}
}
// update the previous row for the next iteration
prev_row = curr_row;
}
// return the maximum length of the LCS
return max_length;
}
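For reference, the removed `common_lcs` uses a classic two-row dynamic program; note that because a mismatch resets the cell to zero, it computes the longest common *contiguous* run (substring) rather than a general subsequence, despite the name. An equivalent Python sketch:

```python
def lcs_length(a, b):
    """Length of the longest common contiguous run of a and b (two-row DP)."""
    if not a or not b:
        return 0
    prev = [0] * (len(b) + 1)   # previous DP row
    best = 0
    for i in range(1, len(a) + 1):
        curr = [0] * (len(b) + 1)
        for j in range(1, len(b) + 1):
            if a[i - 1] == b[j - 1]:
                # extend the run ending at the previous diagonal cell
                curr[j] = prev[j - 1] + 1
                best = max(best, curr[j])
            # on mismatch curr[j] stays 0: runs must be contiguous
        prev = curr
    return best
```

The two-row trick keeps memory at O(len(b)) instead of the O(len(a) * len(b)) full matrix.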
//
// Vocab utils
//

View File

@@ -779,16 +779,6 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids,
bool logits);
//
// Token utils
//
// longest common prefix
size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
// longest common subsequence
size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
//
// Vocab utils
//

View File

@@ -305,7 +305,10 @@ static bool common_pull_file(httplib::Client & cli,
);
if (!res) {
-LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
+LOG_ERR("%s: download failed: %s (status: %d)\n",
+__func__,
+httplib::to_string(res.error()).c_str(),
+res ? res->status : -1);
return false;
}

View File

@@ -160,8 +160,6 @@ class ModelBase:
self.ftype = gguf.LlamaFileType.MOSTLY_F16
logger.info("heuristics unable to detect tensor dtype, defaulting to --outtype f16")
self.dequant_model()
# Configure GGUF Writer
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
@@ -527,6 +525,8 @@ class ModelBase:
return ()
def prepare_tensors(self):
self.dequant_model()
# Handle empty tensor_map for models with block_count=0 (like MobileNetV5)
if self.tensor_map.mapping:
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
@@ -1815,7 +1815,7 @@ class MmprojModel(ModelBase):
preprocessor_config: dict[str, Any]
global_config: dict[str, Any]
-n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]
+n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers", "vt_num_hidden_layers"]
has_vision_encoder: bool = True # by default
has_audio_encoder: bool = False
@@ -1870,7 +1870,15 @@ class MmprojModel(ModelBase):
preprocessor_config_path = self.dir_model / "preprocessor_config.json"
if preprocessor_config_path.is_file():
with open(preprocessor_config_path, "r", encoding="utf-8") as f:
-self.preprocessor_config = json.load(f)
+cfg = json.load(f)
+# move media_proc_cfg to root level for compat
+if "media_proc_cfg" in cfg:
+cfg = {
+**cfg,
+**cfg["media_proc_cfg"],
+}
+# merge configs
+self.preprocessor_config = {**self.preprocessor_config, **cfg}
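The hoist-and-merge above amounts to: lift keys nested under `media_proc_cfg` to the root, then merge into the existing config with the newly loaded file taking precedence. A standalone sketch (the function name is illustrative, not part of the converter):

```python
def merge_preprocessor_config(base: dict, cfg: dict) -> dict:
    """Lift media_proc_cfg keys to the root, then merge cfg over base."""
    if "media_proc_cfg" in cfg:
        # later keys win in a dict literal, so nested keys override root ones
        cfg = {**cfg, **cfg["media_proc_cfg"]}
    return {**base, **cfg}
```

Note that the nested `media_proc_cfg` dict itself is kept alongside its lifted keys, matching the diff.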
# prefer processor_config.json if possible
processor_config_path = self.dir_model / "processor_config.json"
@@ -1919,10 +1927,10 @@ class MmprojModel(ModelBase):
self.image_size = self.find_vparam(["image_size"])
self.gguf_writer.add_vision_image_size(self.image_size)
self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"]))
-self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
-self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
+self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size", "vt_hidden_size"]))
+self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size", "vt_intermediate_size"]))
 self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
-self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads"]))
+self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads", "vt_num_attention_heads"]))
# preprocessor config
image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
@@ -7695,6 +7703,7 @@ class DeepseekModel(TextModel):
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"KimiVLForConditionalGeneration",
"KimiK25ForConditionalGeneration",
"YoutuForCausalLM",
"YoutuVLForConditionalGeneration",
)
@@ -7813,8 +7822,8 @@ class DeepseekV2Model(TextModel):
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-# skip vision tensors and remove "language_model." for Kimi-VL
-if "vision_tower" in name or "multi_modal_projector" in name:
+# skip vision tensors and remove "language_model." for Kimi-VL and Kimi-K2.5
+if "vision_tower" in name or "multi_modal_projector" in name or "mm_projector" in name:
return
if name.startswith("siglip2.") or name.startswith("merger."):
return
@@ -11176,6 +11185,103 @@ class KimiVLModel(MmprojModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("KimiK25ForConditionalGeneration")
class KimiK25Model(MmprojModel):
"""Kimi-K2.5 with MoonViT3d vision encoder"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None, "Kimi-K2.5 requires vision_config in model config"
self.merge_kernel_size = tuple(self.hparams_vision.get("merge_kernel_size", [2, 2]))
self.patch_size = self.hparams_vision.get("patch_size", 14)
# Set image_size for compatibility with base class
# Use position embedding dimensions as image_size reference
pos_emb_h = self.hparams_vision.get("init_pos_emb_height", 64)
self.hparams_vision["image_size"] = pos_emb_h * self.patch_size
def set_gguf_parameters(self):
# Base class MmprojModel.set_gguf_parameters() already writes:
# - vision_block_count, vision_head_count, vision_embedding_length
# - vision_feed_forward_length, vision_patch_size, image_mean, image_std
# via find_vparam() which handles the vt_* prefixed keys in Kimi-K2.5's config
super().set_gguf_parameters()
assert self.hparams_vision is not None
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIK25)
# Position embedding parameters (for interpolation)
self.gguf_writer.add_uint32("vision.pos_emb_height", self.hparams_vision.get("init_pos_emb_height", 64))
self.gguf_writer.add_uint32("vision.pos_emb_width", self.hparams_vision.get("init_pos_emb_width", 64))
self.gguf_writer.add_uint32("vision.pos_emb_time", self.hparams_vision.get("init_pos_emb_time", 4))
# Projector parameters
self.gguf_writer.add_vision_use_gelu(self.hparams_vision.get("projector_hidden_act", "gelu") == "gelu")
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("projector_ln_eps", 1e-5))
self.gguf_writer.add_vision_projector_scale_factor(self.merge_kernel_size[0])
# Image size limits
# Note: in_patch_limit is for images, in_patch_limit_each_frame is for video (not supported yet)
in_patch_limit = self.preprocessor_config.get("in_patch_limit", 16384)
min_patches = 8 # reasonable minimum
pixels_per_patch = self.patch_size ** 2
self.gguf_writer.add_vision_min_pixels(min_patches * pixels_per_patch)
self.gguf_writer.add_vision_max_pixels(in_patch_limit * pixels_per_patch)
@staticmethod
def permute(weights: Tensor, n_head: int) -> Tensor:
out_dim, in_dim = weights.shape
head_dim = out_dim // n_head
w = weights.reshape(n_head, head_dim // 4, 2, 2, in_dim)
w = w.permute(0, 2, 1, 3, 4)
return w.reshape(out_dim, in_dim)
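The `permute` above reorders each head's output rows from the checkpoint's interleaved RoPE layout into the layout `build_rope_2d` expects, via a (head_dim // 4, 2, 2) reshape that swaps the first two axes. A pure-Python equivalent operating on one head's list of rows (head_dim divisible by 4; names are illustrative):

```python
def interleaved_to_split(rows):
    """Reorder one head's rows: reshape (q, 2, 2) -> transpose -> (2, q, 2)."""
    hd = len(rows)
    q = hd // 4
    # old flat index b*4 + a*2 + c maps to new flat index a*(q*2) + b*2 + c,
    # mirroring reshape(n_head, hd // 4, 2, 2, in).permute(0, 2, 1, 3, 4)
    return [rows[b * 4 + a * 2 + c]
            for a in range(2) for b in range(q) for c in range(2)]
```

Baking this reorder into the converted weights is what lets the runtime skip a post-attention permutation.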
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Only process vision and projector tensors
is_vision = any(x in name for x in ["vision_tower", "mm_projector"])
if not is_vision:
return
assert self.hparams_vision is not None
n_head = self.hparams_vision.get("num_attention_heads", 16)
# Permute Q/K weights/biases from interleaved to split RoPE format
# This allows using build_rope_2d at runtime without post-permutation.
if "wqkv" in name:
out_dim = data_torch.shape[0]
qkv_dim = out_dim // 3
head_dim = qkv_dim // n_head
if "weight" in name:
wq, wk, wv = data_torch[:qkv_dim, :], data_torch[qkv_dim:2 * qkv_dim, :], data_torch[2 * qkv_dim:, :]
wq = self.permute(wq, n_head)
wk = self.permute(wk, n_head)
data_torch = torch.cat([wq, wk, wv], dim=0)
elif "bias" in name:
bq, bk, bv = data_torch[:qkv_dim], data_torch[qkv_dim:2 * qkv_dim], data_torch[2 * qkv_dim:]
bq = bq.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1)
bk = bk.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1)
data_torch = torch.cat([bq, bk, bv], dim=0)
# Temporal embeddings: (T, 1, C) → (T, C)
if "pos_emb.time_weight" in name:
T, _, C = data_torch.shape
data_torch = data_torch.reshape(T, C)
# PatchMergerMLP tensor name mapping
# proj.0.weight → proj.linear_1.weight
# proj.2.weight → proj.linear_2.weight
if "mm_projector.proj.0." in name:
name = name.replace(".proj.0.", ".proj.linear_1.")
elif "mm_projector.proj.2." in name:
name = name.replace(".proj.2.", ".proj.linear_2.")
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("CogVLMForCausalLM")
class CogVLMVisionModel(MmprojModel):

View File

@@ -35,7 +35,7 @@ Adapt below build commands accordingly.
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
```
-[d]/workspace> cp docs/backend/hexagon/CMakeUserPresets.json .
+[d]/workspace> cp docs/backend/snapdragon/CMakeUserPresets.json .
[d]/workspace> cmake --preset arm64-android-snapdragon-release -B build-snapdragon
Preset CMake variables:

View File

@@ -2096,10 +2096,14 @@ static void ggml_compute_forward_gelu_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2113,10 +2117,14 @@ static void ggml_compute_forward_gelu_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2135,10 +2143,14 @@ static void ggml_compute_forward_gelu_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2152,10 +2164,14 @@ static void ggml_compute_forward_gelu_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2276,10 +2292,14 @@ static void ggml_compute_forward_gelu_erf_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2293,10 +2313,14 @@ static void ggml_compute_forward_gelu_erf_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_erf_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2315,10 +2339,14 @@ static void ggml_compute_forward_gelu_erf_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2332,10 +2360,14 @@ static void ggml_compute_forward_gelu_erf_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_erf_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2379,10 +2411,14 @@ static void ggml_compute_forward_gelu_quick_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2396,10 +2432,14 @@ static void ggml_compute_forward_gelu_quick_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_quick_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2418,10 +2458,14 @@ static void ggml_compute_forward_gelu_quick_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2435,10 +2479,14 @@ static void ggml_compute_forward_gelu_quick_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_quick_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2482,10 +2530,14 @@ static void ggml_compute_forward_silu_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2499,10 +2551,14 @@ static void ggml_compute_forward_silu_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_silu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2521,10 +2577,14 @@ static void ggml_compute_forward_silu_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2538,10 +2598,14 @@ static void ggml_compute_forward_silu_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_silu_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {

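The rewritten loops above replace the per-`i1` iteration with a single flattened row counter `ir` over all `ne01*ne02*ne03` rows, decomposed back into `(i1, i2, i3)` so that the per-dimension strides `nb1`/`nb2`/`nb3` can be applied to non-contiguous tensors. The decomposition in isolation:

```python
def unflatten(ir, ne01, ne02):
    """Map a flat row index ir back to (i1, i2, i3), as in the updated kernels."""
    i3 = ir // (ne02 * ne01)
    i2 = (ir - i3 * ne02 * ne01) // ne01
    i1 = ir - i3 * ne02 * ne01 - i2 * ne01
    return i1, i2, i3
```

The byte offset of a row is then `i3*nb3 + i2*nb2 + i1*nb1`, which reduces to the old `i1*nb1` form only when the tensor is contiguous.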
View File

@@ -111,7 +111,7 @@ template <float (*op)(float), typename src0_t, typename dst_t>
static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
-GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
+GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst));
GGML_TENSOR_UNARY_OP_LOCALS

View File

@@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16(
constexpr int frag_m = ncols == 8 ? 32 : 16;
constexpr int frag_n = ncols == 8 ? 8 : 16;
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
#if defined(GGML_USE_HIP)
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::row_major> frag_a_K;
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_a_V;
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_b;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, _Float16> frag_c_VKQ;
#else
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
#endif
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16(
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
half2 * VKQ2 = (half2 *) VKQ;
#if defined(GGML_USE_HIP)
const _Float16 * K_h_f16 = reinterpret_cast<const _Float16 *>(K_h);
const _Float16 * V_h_f16 = reinterpret_cast<const _Float16 *>(V_h);
_Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ);
_Float16 * VKQ_f16 = reinterpret_cast<_Float16 *>(VKQ);
#else
const half * K_h_f16 = K_h;
const half * V_h_f16 = V_h;
half * KQ_f16 = KQ;
half * VKQ_f16 = VKQ;
#endif
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
@@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16(
for (int i0 = 0; i0 < D; i0 += 16) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded);
}
}
@@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16(
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
frag_a_K K_a;
-wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
wmma::load_matrix_sync(
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
KQ + j0*(kqar*kqs_padded) + k,
KQ_f16 + j0*(kqar*kqs_padded) + k,
kqar*kqs_padded);
}
}
@@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
frag_a_V v_a;
wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16(
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
wmma::store_matrix_sync(
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
D_padded, wmma::mem_col_major);
}


@@ -64,25 +64,12 @@ struct htp_ops_context {
struct fastdiv_values broadcast_rv2;
struct fastdiv_values broadcast_rv3;
struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
struct fastdiv_values mm_div_ne1; // fastdiv values for ne1
struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02
struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03
struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
struct fastdiv_values cpy_div_ne01; // fastdiv values for ne01
struct fastdiv_values cpy_div_ne02; // fastdiv values for ne02
struct fastdiv_values cpy_div_ne03; // fastdiv values for ne03
struct fastdiv_values cpy_rshp_div_n0; // fastdiv values for ne00
struct fastdiv_values cpy_rshp_div_n1n0; // fastdiv values for ne00*ne01
struct fastdiv_values cpy_rshp_div_n2n1n0; // fastdiv values for ne00*ne01*ne02
uint32_t flags;
};
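The struct above caches `fastdiv_values` for divisors that recur in index math (ne12*ne1, broadcast ratios, copy/reshape strides), replacing per-element integer division with a precomputed multiply. A minimal sketch of the trick, assuming a generic ceil(2^64/d) multiplier — the actual field layout of the Hexagon backend's `fastdiv_values` is not shown in this diff:

```cpp
#include <cassert>
#include <cstdint>

// Cached state for dividing by a fixed 32-bit divisor with a multiply
// instead of a hardware divide. Hypothetical layout; uses the GCC/Clang
// unsigned __int128 extension for the 64x64 -> high-64 multiply.
struct fastdiv_values {
    uint64_t mult; // ceil(2^64 / d) for d >= 2
    uint32_t d;    // original divisor (d == 1 bypasses the multiply)
};

static fastdiv_values fastdiv_init(uint32_t d) {
    fastdiv_values v;
    v.d = d;
    // floor((2^64 - 1)/d) + 1 == ceil(2^64/d); exact for all 32-bit n
    v.mult = d > 1 ? ~uint64_t(0) / d + 1 : 0;
    return v;
}

static uint32_t fastdiv(uint32_t n, const fastdiv_values & v) {
    if (v.d == 1) return n;
    return (uint32_t)(((unsigned __int128)v.mult * n) >> 64);
}
```

The precompute pays off because the same divisor (e.g. ne1) is reused across every element of a matmul dispatch.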

File diff suppressed because it is too large


@@ -328,31 +328,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
char base[256];
char name[256];
const char * op_str = "undefined";
int op_num = -1;
switch (op->op) {
case GGML_OP_SUM_ROWS:
op_str = "sum_rows"; break;
case GGML_OP_MEAN:
op_str = "mean"; break;
case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break;
case GGML_OP_MEAN: op_num = OP_SUM_ROWS_NUM_MEAN; break;
default: GGML_ABORT("fatal error");
};
snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t_str = ggml_type_name(op->type);
snprintf(name, 256, "%s", base);
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s_op=%d", base, op_num);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
res.smem = 32*sizeof(float);
if (is_c4) {
res.smem *= 4;
}
res.c4 = is_c4;
return res;
}
@@ -1480,13 +1495,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_met
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_L2_NORM);
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
char base[256];
char name[256];
snprintf(base, 256, "kernel_l2_norm_f32");
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t_str = ggml_type_name(op->type);
snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
@@ -1494,6 +1511,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
res.c4 = is_c4;
res.smem = 32*sizeof(float);
return res;


@@ -1019,7 +1019,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
return ggml_is_contiguous_rows(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_TANH:
@@ -1039,7 +1039,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_EXPM1:
return ggml_is_contiguous_rows(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
default:
return false;
}
@@ -1086,9 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_COUNT_EQUAL:
return has_simdgroup_reduction &&
op->src[0]->type == GGML_TYPE_I32 &&


@@ -82,6 +82,7 @@
#define FC_COUNT_EQUAL 1100
#define FC_UNARY 1200
#define FC_BIN 1300
#define FC_SUM_ROWS 1400
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8
@@ -118,6 +119,8 @@
#define OP_UNARY_NUM_SOFTPLUS 115
#define OP_UNARY_NUM_EXPM1 116
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
#define OP_SUM_ROWS_NUM_MEAN 11
// kernel argument structs
//
@@ -539,8 +542,21 @@ typedef struct {
typedef struct {
int32_t ne00;
int32_t ne00_4;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
float eps;
} ggml_metal_kargs_l2_norm;


@@ -904,6 +904,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
@@ -925,21 +930,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}
int nth = 32; // SIMD width
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00);
nth = std::min(nth, (int) args.ne00);
const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
@@ -2979,39 +2989,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
float eps;
memcpy(&eps, op->op_params, sizeof(float));
int nth = 32; // SIMD width
ggml_metal_kargs_l2_norm args = {
/*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4,
/*.nb01 =*/ nb01,
/*.eps =*/ eps,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.eps =*/ eps,
};
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}
int nth = 32; // SIMD width
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00/4);
const size_t smem = pipeline.smem;
const int64_t nrows = ggml_nrows(op->src[0]);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
return 1;
}
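The host-side dispatch above sizes the threadgroup by doubling from the SIMD width until it covers the row length or hits the pipeline's thread limit, then clamps to both. The same heuristic extracted as plain C++ (`pick_nth` is a name invented here):

```cpp
#include <algorithm>
#include <cassert>

// Mirror of the Metal host code's threadgroup sizing: start at the SIMD
// width (32), double while below both the row length and the pipeline
// limit, then clamp to each bound.
static int pick_nth(int ne00, int max_threads) {
    int nth = 32;
    while (nth < ne00 && nth < max_threads) {
        nth *= 2;
    }
    nth = std::min(nth, max_threads);
    nth = std::min(nth, ne00);
    return nth;
}
```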


@@ -77,6 +77,14 @@ static inline float dot(float x, float y) {
return x*y;
}
static inline float sum(float x) {
return x;
}
static inline float sum(float4 x) {
return x[0] + x[1] + x[2] + x[3];
}
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -910,7 +918,7 @@ constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;
template<typename T>
T erf_approx(T x) {
inline T erf_approx(T x) {
T sign_x = sign(x);
x = fabs(x);
T t = 1.0f / (1.0f + p_erf * x);
@@ -918,10 +926,27 @@ T erf_approx(T x) {
return sign_x * y;
}
template<typename T> T elu_approx(T x);
template<> inline float elu_approx<float>(float x) {
return (x > 0.f) ? x : (exp(x) - 1);
}
template<> inline float4 elu_approx<float4>(float4 x) {
float4 res;
res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
return res;
}
constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
template <typename T0, typename T>
template <typename T0, typename T, typename TC>
kernel void kernel_unary_impl(
constant ggml_metal_kargs_unary & args,
device const char * src0,
@@ -963,111 +988,111 @@ kernel void kernel_unary_impl(
}
}
device const T0 & x = src0_ptr[i0];
const TC x = (TC) src0_ptr[i0];
if (FC_OP == OP_UNARY_NUM_SCALE) {
dst_ptr[i0] = args.scale * x + args.bias;
dst_ptr[i0] = (T) (args.scale * x + args.bias);
}
if (FC_OP == OP_UNARY_NUM_FILL) {
dst_ptr[i0] = args.val;
dst_ptr[i0] = (T) args.val;
}
if (FC_OP == OP_UNARY_NUM_CLAMP) {
dst_ptr[i0] = clamp(x, args.min, args.max);
dst_ptr[i0] = (T) clamp(x, args.min, args.max);
}
if (FC_OP == OP_UNARY_NUM_SQR) {
dst_ptr[i0] = x * x;
dst_ptr[i0] = (T) (x * x);
}
if (FC_OP == OP_UNARY_NUM_SQRT) {
dst_ptr[i0] = sqrt(x);
dst_ptr[i0] = (T) sqrt(x);
}
if (FC_OP == OP_UNARY_NUM_SIN) {
dst_ptr[i0] = sin(x);
dst_ptr[i0] = (T) sin(x);
}
if (FC_OP == OP_UNARY_NUM_COS) {
dst_ptr[i0] = cos(x);
dst_ptr[i0] = (T) cos(x);
}
if (FC_OP == OP_UNARY_NUM_LOG) {
dst_ptr[i0] = log(x);
dst_ptr[i0] = (T) log(x);
}
if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
dst_ptr[i0] = T(x > 0.0f)*x + T(x <= 0.0f)*(x * args.slope);
dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
}
if (FC_OP == OP_UNARY_NUM_TANH) {
dst_ptr[i0] = precise::tanh(x);
dst_ptr[i0] = (T) precise::tanh(x);
}
if (FC_OP == OP_UNARY_NUM_RELU) {
dst_ptr[i0] = fmax(0.0f, x);
dst_ptr[i0] = (T) fmax(0, x);
}
if (FC_OP == OP_UNARY_NUM_SIGMOID) {
dst_ptr[i0] = 1.0f / (1.0f + exp(-x));
dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
}
if (FC_OP == OP_UNARY_NUM_GELU) {
dst_ptr[i0] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
}
if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
dst_ptr[i0] = 0.5f*x*(1.0f + erf_approx(SQRT_2_INV*x));
dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
}
if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
dst_ptr[i0] = x * (1.0f/(1.0f + exp(GELU_QUICK_COEF*x)));
dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
}
if (FC_OP == OP_UNARY_NUM_SILU) {
dst_ptr[i0] = x / (1.0f + exp(-x));
dst_ptr[i0] = (T) (x / (1 + exp(-x)));
}
if (FC_OP == OP_UNARY_NUM_ELU) {
dst_ptr[i0] = T(x > 0.0f)*x + T(x <= 0.0f)*(exp(x) - 1.0f);
dst_ptr[i0] = (T) elu_approx(x);
}
if (FC_OP == OP_UNARY_NUM_NEG) {
dst_ptr[i0] = -x;
dst_ptr[i0] = (T) -x;
}
if (FC_OP == OP_UNARY_NUM_ABS) {
dst_ptr[i0] = fabs(x);
dst_ptr[i0] = (T) fabs(x);
}
if (FC_OP == OP_UNARY_NUM_SGN) {
dst_ptr[i0] = T(x > 0.0f) - T(x < 0.0f);
dst_ptr[i0] = T(x > 0) - T(x < 0);
}
if (FC_OP == OP_UNARY_NUM_STEP) {
dst_ptr[i0] = T(x > 0.0f);
dst_ptr[i0] = T(x > 0);
}
if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
dst_ptr[i0] = x * fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f));
dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
}
if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
dst_ptr[i0] = fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f));
dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
}
if (FC_OP == OP_UNARY_NUM_EXP) {
dst_ptr[i0] = exp(x);
dst_ptr[i0] = (T) exp(x);
}
if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
dst_ptr[i0] = select(log(1.0f + exp(x)), x, x > 20.0f);
dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
}
if (FC_OP == OP_UNARY_NUM_EXPM1) {
// TODO: precise implementation
dst_ptr[i0] = exp(x) - 1.0f;
dst_ptr[i0] = (T) (exp(x) - 1);
}
}
@@ -1075,11 +1100,12 @@ kernel void kernel_unary_impl(
#undef FC_CNT
}
typedef decltype(kernel_unary_impl<float, float>) kernel_unary_t;
template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float>;
template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4>;
typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>;
template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>;
template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>;
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
@@ -1483,33 +1509,35 @@ kernel void kernel_op_sum_f32(
}
}
template <bool norm>
kernel void kernel_sum_rows(
constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
template <typename T0, typename T>
kernel void kernel_sum_rows_impl(
constant ggml_metal_kargs_sum_rows & args,
device const float * src0,
device float * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
device const char * src0,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
int64_t i3 = tgpig.z;
int64_t i2 = tgpig.y;
int64_t i1 = tgpig.x;
#define FC_OP FC_sum_rows_op
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
shmem_t[tiisg] = 0.0f;
}
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
float sumf = 0;
T0 sumf = T0(0.0f);
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
sumf += src_row[i0];
@@ -1520,23 +1548,33 @@ kernel void kernel_sum_rows(
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
shmem_t[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = shmem_t[tiisg];
sumf = simd_sum(sumf);
if (tpitg.x == 0) {
dst_row[0] = norm ? sumf / args.ne00 : sumf;
if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
if (is_same<float4, T0>::value) {
dst_row[0] = sum(sumf) / (4*args.ne00);
} else {
dst_row[0] = sum(sumf) / args.ne00;
}
} else {
dst_row[0] = sum(sumf);
}
}
#undef FC_OP
}
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
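The `_4` instantiation handles rows whose length is a multiple of 4: the host passes ne00/4, the kernel accumulates a float4 partial sum, and the MEAN path divides by 4*ne00 to recover the true element count. A CPU sketch of that path (hypothetical helper name):

```cpp
#include <array>
#include <cassert>
#include <vector>

// 4-lane accumulation as in the float4 sum_rows kernel: ne00_div4 is the
// pre-divided row length, the lanes are reduced at the end (sum(float4)),
// and MEAN divides by 4*ne00_div4 = the original ne00.
static float sum_rows_c4(const std::vector<float> & row, int ne00_div4, bool mean) {
    std::array<float, 4> acc = {0.0f, 0.0f, 0.0f, 0.0f};
    for (int i = 0; i < ne00_div4; ++i) {
        for (int l = 0; l < 4; ++l) {
            acc[l] += row[4*i + l];
        }
    }
    const float s = acc[0] + acc[1] + acc[2] + acc[3];
    return mean ? s / (4 * ne00_div4) : s;
}
```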
template<typename T>
kernel void kernel_cumsum_blk(
@@ -2417,9 +2455,6 @@ kernel void kernel_solve_tri_f32(
const short K = FC_solve_tri_k;
const short NP = PAD2(N, NW);
const int32_t ne02 = args.ne02;
const int32_t ne03 = args.ne03;
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i01 = tgpig.x*NSG + sgitg;
@@ -2706,26 +2741,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
kernel void kernel_l2_norm_f32(
template <typename T0, typename T>
kernel void kernel_l2_norm_impl(
constant ggml_metal_kargs_l2_norm & args,
device const char * src0,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort ntg[[threads_per_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int i01 = tgpig.x;
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
float sumf = 0.0f;
// parallel sum
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
@@ -2743,12 +2784,16 @@ kernel void kernel_l2_norm_f32(
const float scale = 1.0f/sqrt(max(sumf, args.eps));
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
y[i00] = x[i00] * scale;
}
}
typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
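Row-wise, the kernel computes y = x / sqrt(max(Σ x_i², eps)). A scalar CPU reference of the same math, useful for checking the Metal output against (helper name invented here):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// CPU reference for kernel_l2_norm_impl: one row is scaled by the inverse
// L2 norm, with eps guarding against division by zero.
static std::vector<float> l2_norm_row(const std::vector<float> & x, float eps) {
    float sumf = 0.0f;
    for (float v : x) {
        sumf += v * v;
    }
    const float scale = 1.0f / std::sqrt(std::max(sumf, eps));
    std::vector<float> y(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        y[i] = x[i] * scale;
    }
    return y;
}
```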
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,
@@ -5921,7 +5966,7 @@ kernel void kernel_flash_attn_ext_vec(
static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
const short T = PK + NSG*SH; // shared memory size per query in (half)
//const short T = PK + NSG*SH; // shared memory size per query in (half)
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
@@ -8509,7 +8554,9 @@ kernel void kernel_mul_mm(
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
#ifdef GGML_METAL_HAS_TENSOR
threadgroup float * sc = (threadgroup float *)(shmem);
#endif
constexpr int NR0 = 64;
constexpr int NR1 = 32;
@@ -8632,8 +8679,8 @@ kernel void kernel_mul_mm(
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short dx = sx;
const short dy = sy;
//const short dx = sx;
//const short dy = sy;
const short ly = (tiitg/NL1)%8;
@@ -8882,7 +8929,9 @@ kernel void kernel_mul_mm_id(
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
#ifdef GGML_METAL_HAS_TENSOR
threadgroup float * sc = (threadgroup float *)(shmem);
#endif
constexpr int NR0 = 64;
constexpr int NR1 = 32;
@@ -9017,8 +9066,8 @@ kernel void kernel_mul_mm_id(
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short dx = sx;
const short dy = sy;
//const short dx = sx;
//const short dy = sy;
const short ly = (tiitg/NL1)%8;


@@ -85,6 +85,7 @@ set(GGML_OPENCL_KERNELS
mul_mv_q4_0_f32_8x_flat
mul_mv_q4_0_f32_1d_8x_flat
mul_mv_q4_0_f32_1d_16x_flat
mul_mv_q4_k_f32
mul_mv_q6_k_f32
mul_mv_q6_k_f32_flat
mul_mv_q8_0_f32
@@ -101,6 +102,7 @@ set(GGML_OPENCL_KERNELS
mul_mm_f32_f32_l4_lm
mul_mm_f16_f32_l4_lm
mul_mm_q8_0_f32_l4_lm
mul_mm_q6_k_f32_l4_lm
mul_mm_q8_0_f32_8x4
gemv_noshuffle_general_q8_0_f32
mul


@@ -532,6 +532,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_restore_block_q4_0_noshuffle;
cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K;
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
cl_kernel kernel_mul_mv_q4_K_f32;
cl_kernel kernel_mul_mv_q6_K_f32;
cl_kernel kernel_mul_mv_q6_K_f32_flat;
cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;
@@ -564,6 +565,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_mul_mm_f32_f32_l4_lm;
cl_kernel kernel_mul_mm_f16_f32_l4_lm;
cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
cl_kernel kernel_mul_mm_q6_k_f32_l4_lm;
std::vector<ProfilingInfo> profiling_info;
@@ -1117,6 +1119,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// mul_mv_q4_k_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mv_q4_k_f32.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mv_q4_k_f32.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mv_q4_K_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_K_f32", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// mul_mv_q6_k_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1358,6 +1377,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// mul_mm_q6_k_f32_l4_lm
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mm_q6_k_f32_l4_lm.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mm_q6_k_f32_l4_lm.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q6_k_f32_l4_lm", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// mul_mm_f16_f32_kq_kqv
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3364,6 +3400,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
} else if (op->src[0]->type == GGML_TYPE_F32) {
return op->src[1]->type == GGML_TYPE_F32;
} else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_MXFP4 ||
op->src[0]->type == GGML_TYPE_Q4_K ||
op->src[0]->type == GGML_TYPE_Q6_K) {
return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
} else if (op->src[0]->type == GGML_TYPE_Q8_0) {
@@ -8927,6 +8964,50 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
return;
}
case GGML_TYPE_Q6_K: {
if (ne11 < 32) {
break;
}
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
break;
}
kernel = backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm;
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
int batch_stride_a = ne00*ne01;
int batch_stride_b = ne10*ne11;
int batch_stride_d = ne0*ne1;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3));
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
size_t local_work_size[] = {(size_t)nth0, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
return;
}
default:
break;
}
@@ -9262,7 +9343,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
}
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q4_K: {
kernel = backend_ctx->kernel_mul_mv_q4_K_f32;
if (backend_ctx->gpu_family == INTEL) {
nth0 = 16;
nth1 = 1;
ndst = 4;
} else if (backend_ctx->gpu_family == ADRENO) {
nth0 = 64;
nth1 = 1;
ndst = 4;
} else {
GGML_ASSERT(false && "TODO: Unknown GPU");
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
break;
}
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
#ifdef GGML_OPENCL_SOA_Q
@@ -9424,7 +9540,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
} else if (src0t == GGML_TYPE_Q4_K) {
GGML_ASSERT(false && "not implemented");
size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
} else if (src0t == GGML_TYPE_Q3_K) {
GGML_ASSERT(false && "not implemented");
} else if (src0t == GGML_TYPE_Q5_K) {


@@ -0,0 +1,158 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define LOAD_VEC_A 2
#define LOAD_VEC_B 4
#define BM 64
#define BN 64
#define BK 32
#define TM 4
#define TN 8
kernel void kernel_mul_mm_q6_k_f32_l4_lm(
global uchar * src0_ql,
global uchar * src0_qh,
global char * src0_s,
global half * src0_d,
global float4 * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne11,
int ne12,
int stride_a,
int stride_b,
int stride_d,
int batch_stride_a,
int batch_stride_b,
int batch_stride_d,
int r2,
int r3
) {
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float *)((global char*)dst + offsetd);
local float buf_a[BM * BK];
local float buf_b[BN * BK];
const int batch_idx = get_global_id(2);
const int i13 = batch_idx / ne12;
const int i12 = batch_idx % ne12;
const int i03 = i13 / r3;
const int i02 = i12 / r2;
const int batch_idx_a = i03 * ne02 + i02;
const int ir = get_group_id(0);
const int ic = get_group_id(1);
const int tid = get_local_id(0);
const int th_r = tid % (BM / TM);
const int th_c = tid / (BM / TM);
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
float sums[TM * TN];
float cache_a[TM];
float cache_b[TN];
for (int i = 0; i < TM * TN; i++) {
sums[i] = 0.0f;
}
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (ir*BM + loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 128; // 2 values per idx
int iqs = idx % 128; // 0..127
int n = iqs / 64; // 0,1
int b = (iqs % 64) / 32; // 0,1
int is_b = (iqs % 16) / 8; // 0,1
int qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
int is = 8 * n + qhshift + is_b; // 0..15
int qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
int qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
float dscale = (float)src0_d[ib] * (float)src0_s[ib*16 + is];
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32);
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32);
} else {
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
}
}
for (int l = 0; l < BN; l += loadstride_b) {
if (ic*BN + loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (int i = 0; i < BK; i++) {
for (int j = 0; j < TM; j++) {
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
}
for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
}
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
const int sums_idx = cc*TM + cr;
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
const int dr = ir * BM + th_r * TM;
const int dc = ic * BN + th_c * TN;
const int offsets = batch_idx * batch_stride_d;
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
if (dr + cr < ne01 && dc + cc < ne11) {
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
}
}
}
}
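The load loop above decodes Q6_K super-blocks (256 weights per block: `ql` holds 128 bytes of low nibbles, `qh` holds 64 bytes of 2-bit high parts, plus 16 int8 scales and an fp16 `d`) through flattened pair indices (`qsi`, `qhi`, `is`). The following stdlib-Python sketch of that index arithmetic is cross-checked against a plain packing of the reference layout; the packing code and variable names are mine, and the scale multiply and `- 32` offset from the kernel are omitted:

```python
# Sketch: verify the Q6_K index math used in the kernel above against a
# straightforward reference packing. Per 256-weight super-block:
#   ql[128] low 4 bits, qh[64] high 2 bits (scales/d omitted here).
import random

random.seed(0)
q = [random.randrange(64) for _ in range(256)]  # 6-bit quants, 0..63

# Pack following the reference Q6_K layout (two 128-weight halves).
ql = [0] * 128
qh = [0] * 64
for n in (0, 1):            # half index
    for l in range(32):
        q1 = q[128 * n + l +  0]
        q2 = q[128 * n + l + 32]
        q3 = q[128 * n + l + 64]
        q4 = q[128 * n + l + 96]
        ql[64 * n + l     ] = (q1 & 0xF) | ((q3 & 0xF) << 4)
        ql[64 * n + l + 32] = (q2 & 0xF) | ((q4 & 0xF) << 4)
        qh[32 * n + l] = (q1 >> 4) | ((q2 >> 4) << 2) | ((q3 >> 4) << 4) | ((q4 >> 4) << 6)

# Decode every value with the kernel's pair-index arithmetic (LOAD_VEC_A = 2).
decoded = [0] * 256
for idx in range(128):                    # one idx covers 2 values
    iqs = idx % 128
    n = iqs // 64                         # 0,1  : which 128-weight half
    b = (iqs % 64) // 32                  # 0,1  : low/high nibble of ql
    qhshift = ((iqs % 64) // 16) * 2      # 0,2,4,6
    qsi = n * 64 + (iqs % 32) * 2
    qhi = n * 32 + (iqs % 16) * 2
    for t in (0, 1):
        val = ((ql[qsi + t] >> (b * 4)) & 0xF) | (((qh[qhi + t] >> qhshift) & 3) << 4)
        g = (iqs % 64) // 16              # value group within the half
        pos = 128 * n + 32 * g + 2 * (iqs % 16) + t
        decoded[pos] = val

assert decoded == q
```

Every (`qsi`, `qhi`, shift) combination lands on a distinct weight, which is why the kernel can stream pairs of consecutive `ql`/`qh` bytes straight into local memory.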


@@ -0,0 +1,180 @@
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
//------------------------------------------------------------------------------
// block_q4_K
//------------------------------------------------------------------------------
#define QK_K 256
#define K_SCALE_SIZE 12
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uchar qs[QK_K/2]; // 4-bit quants
} block_q4_K;
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 4 // number of rows each SIMD group works on
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 16 // SIMD group size
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
#undef BLOCK_STRIDE
// number of (super) blocks each subgroup processes
// each thread in a subgroup processes a block (32 weights)
#define BLOCK_STRIDE (N_SIMDWIDTH/8)
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q4_K_f32(
global char * src0,
int offset0,
global char * src1,
int offset1,
global char * dst,
int offsetd,
int ne00,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
ushort kmask1 = 0x3f3f;
ushort kmask2 = 0x0f0f;
ushort kmask3 = 0xc0c0;
int ix = get_sub_group_local_id()/8; // super block index
int it = get_sub_group_local_id()%8; // block index (inside super block)
int iq = it/4; // 0 or 1 - first or second half of the super block
int ir = it%4; // 0...3 - block index in the half super block
int nb = ne00/QK_K;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0);
global float * y = (global float *) (src1 + offset_src1);
float yl[16];
float yh[16];
float sumf[N_DST] = {0.f};
float all_sum;
global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
ushort sc16[4];
uchar * sc8 = (uchar *)sc16;
for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
float4 sumy = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; ++i) {
yl[i+0] = y4[i+0];
sumy.s0 += yl[i+0];
yl[i+8] = y4[i+32];
sumy.s1 += yl[i+8];
yh[i+0] = y4[i+128];
sumy.s2 += yh[i+0];
yh[i+8] = y4[i+160];
sumy.s3 += yh[i+8];
}
global ushort * sc = (global ushort *)x[ib].scales + iq;
global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir;
global half * dh = &x[ib].d;
for (int row = 0; row < N_DST; row++) {
sc16[0] = sc[0] & kmask1;
sc16[1] = sc[2] & kmask1;
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
global ushort * q2 = q1 + 32;
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; i += 2) {
acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F);
acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00);
acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0);
acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000);
acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F);
acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00);
acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0);
acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000);
}
float dall = dh[0];
float dmin = dh[1];
sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] +
(acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f +
(acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] +
(acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) -
dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]);
q1 += nb01/2;
sc += nb01/2;
dh += nb01/2;
}
y4 += BLOCK_STRIDE * QK_K;
}
global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = sub_group_reduce_add(sumf[row]);
if (first_row + row < ne01) {
if (get_sub_group_local_id() == 0) {
dst_f32[first_row + row] = all_sum;
}
}
}
}
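The `kmask` lines above unpack the 6-bit scales and mins from the 12-byte `scales` array two sub-blocks at a time through 16-bit loads. Below is a hedged Python sketch of that unpacking, cross-checked against the straightforward byte-wise extraction ggml uses for Q4_K elsewhere; the function names are mine:

```python
# Sketch: the kmask-based 6-bit scale/min extraction from the kernel above,
# checked against a plain byte-wise reference (get_scale_min_k4-style).
import random

def ref_scale_min(j, sc):                 # byte-wise reference, j in 0..7
    if j < 4:
        return sc[j] & 63, sc[j + 4] & 63
    d = (sc[j + 4] & 0xF) | ((sc[j - 4] >> 6) << 4)
    m = (sc[j + 4] >> 4) | ((sc[j] >> 6) << 4)
    return d, m

def kernel_scale_min(iq, sc):             # ushort view of scales[12], offset iq
    u = [sc[2 * i] | (sc[2 * i + 1] << 8) for i in range(6)][iq:]
    kmask1, kmask2, kmask3 = 0x3F3F, 0x0F0F, 0xC0C0
    sc16 = [u[0] & kmask1,
            u[2] & kmask1,
            ((u[4] >> 0) & kmask2) | ((u[0] & kmask3) >> 2),
            ((u[4] >> 4) & kmask2) | ((u[2] & kmask3) >> 2)]
    sc8 = []
    for v in sc16:                        # reinterpret the ushorts as bytes
        sc8 += [v & 0xFF, v >> 8]
    # sc8[0:2] scales, sc8[2:4] mins for sub-blocks 2*iq, 2*iq+1;
    # sc8[4:6] scales, sc8[6:8] mins for sub-blocks 2*iq+4, 2*iq+5.
    return sc8

random.seed(1)
scales = bytes(random.randrange(256) for _ in range(12))
for iq in (0, 1):                         # iq = first/second half of super block
    sc8 = kernel_scale_min(iq, scales)
    for k in (0, 1):
        assert (sc8[k], sc8[2 + k]) == ref_scale_min(2 * iq + k, scales)
        assert (sc8[4 + k], sc8[6 + k]) == ref_scale_min(2 * iq + k + 4, scales)
```

This matches how the kernel consumes `sc8[0..7]`: scales for the `yl` half, then the `yh` half, with the corresponding mins folded into the `dmin` term.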


@@ -5749,7 +5749,7 @@ static struct ggml_tensor * ggml_unary_impl(
struct ggml_tensor * a,
enum ggml_unary_op op,
bool inplace) {
GGML_ASSERT(ggml_is_contiguous_1(a));
GGML_ASSERT(ggml_is_contiguous_rows(a));
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
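The relaxed assertion admits tensors whose individual rows are contiguous even when the rows themselves are strided or permuted. A rough sketch of the distinction, assuming ggml's convention that `nb[0]` equals the element size for row-contiguous data (helper names are mine, and `is_contiguous` is only an approximation of the real check):

```python
# Sketch: "row-contiguous but non-contiguous", as in the relaxed assert above.
# ggml-style shapes ne[4] and byte strides nb[4], float32 elements.
TS = 4  # element size in bytes

def strides(ne):                 # default (fully contiguous) layout
    nb = [TS, 0, 0, 0]
    for i in range(1, 4):
        nb[i] = nb[i - 1] * ne[i - 1]
    return nb

def permute(ne, nb, axes):       # ggml_permute-style: dim i moves to axes[i]
    ne2, nb2 = [0] * 4, [0] * 4
    for i, a in enumerate(axes):
        ne2[a], nb2[a] = ne[i], nb[i]
    return ne2, nb2

def is_contiguous(ne, nb):       # ggml_is_contiguous, roughly
    return nb == strides(ne)

def is_contiguous_rows(ne, nb):  # each row is contiguous in memory
    return ne[0] == 1 or nb[0] == TS

ne = [11, 5, 6, 3]
nb = strides(ne)
pne, pnb = permute(ne, nb, (0, 2, 1, 3))   # swap dims 1 and 2
assert is_contiguous_rows(pne, pnb)        # rows stay intact
assert not is_contiguous(pne, pnb)         # but the row order is strided
```

This is exactly the shape of the new `test_sum(..., {0, 2, 1, 3})` cases: unary ops and reductions can walk whole rows, so only per-row contiguity needs to hold.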


@@ -3766,6 +3766,7 @@ class VisionProjectorType:
VOXTRAL = "voxtral"
LFM2 = "lfm2"
KIMIVL = "kimivl"
KIMIK25 = "kimik25"
LIGHTONOCR = "lightonocr"
COGVLM = "cogvlm"
JANUS_PRO = "janus_pro"


@@ -1303,6 +1303,7 @@ class TensorNameMap:
MODEL_TENSOR.V_MMPROJ: (
"multi_modal_projector.linear_{bid}",
"mm_projector.proj.linear_{bid}", # Kimi-K2.5
"visual.merger.mlp.{bid}", # qwen2vl
"merger.mlp.{bid}",
),
@@ -1364,6 +1365,7 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_ATTN_QKV: (
"visual.blocks.{bid}.attn.qkv", # qwen3vl
"model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm
"vision_tower.encoder.blocks.{bid}.wqkv" # Kimi-K2.5
),
MODEL_TENSOR.V_ENC_ATTN_Q: (
@@ -1538,6 +1540,7 @@ class TensorNameMap:
"multi_modal_projector.norm",
"multi_modal_projector.layer_norm",
"multi_modal_projector.pre_norm",
"mm_projector.pre_norm", # Kimi-K2.5
"pre_mm_projector_norm",
"model.vision.linear_proj.norm1", # cogvlm
"merger.ln_q",


@@ -1943,7 +1943,11 @@ struct test_unary : public test_case {
ggml_tensor * a;
if (v & 1) {
auto ne = ne_a; ne[0] *= 3;
auto ne = ne_a;
ne[0] *= 3;
ne[1] *= 2;
ne[2] *= 5;
ne[3] *= 4;
a = ggml_new_tensor(ctx, type, 4, ne.data());
if (grad_supported) {
ggml_set_param(a);
@@ -8128,24 +8132,30 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
test_cases.emplace_back(new test_sum());
test_cases.emplace_back(new test_sum_rows());
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mean());
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 1, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 256, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32768, 1, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
test_cases.emplace_back(new test_sum_rows());
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
test_cases.emplace_back(new test_mean());
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, false));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, false, true));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, true));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {64, 64, 320, 1}));


@@ -19,6 +19,7 @@ add_library(mtmd
models/glm4v.cpp
models/internvl.cpp
models/kimivl.cpp
models/kimik25.cpp
models/llama4.cpp
models/llava.cpp
models/minicpmv.cpp


@@ -235,6 +235,7 @@ enum projector_type {
PROJECTOR_TYPE_LFM2A,
PROJECTOR_TYPE_GLM4V,
PROJECTOR_TYPE_YOUTUVL,
PROJECTOR_TYPE_KIMIK25,
PROJECTOR_TYPE_UNKNOWN,
};
@@ -268,6 +269,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_LFM2A, "lfm2a"},
{ PROJECTOR_TYPE_GLM4V, "glm4v"},
{ PROJECTOR_TYPE_YOUTUVL, "youtuvl"},
{ PROJECTOR_TYPE_KIMIK25, "kimik25"},
};
static projector_type clip_projector_type_from_string(const std::string & str) {


@@ -673,8 +673,8 @@ ggml_tensor * clip_graph::build_rope_2d(
{
first = ggml_view_3d(ctx0, cur,
n_dim/2, n_head, n_pos,
ggml_row_size(cur->type, n_dim),
ggml_row_size(cur->type, n_dim*n_head),
cur->nb[1],
cur->nb[2],
0);
first = ggml_rope_ext(
ctx0,
@@ -692,8 +692,8 @@ ggml_tensor * clip_graph::build_rope_2d(
{
second = ggml_view_3d(ctx0, cur,
n_dim/2, n_head, n_pos,
ggml_row_size(cur->type, n_dim),
ggml_row_size(cur->type, n_dim*n_head),
cur->nb[1],
cur->nb[2],
n_dim/2 * ggml_element_size(cur));
second = ggml_rope_ext(
ctx0,
@@ -826,6 +826,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
builder = std::make_unique<clip_graph_kimivl>(ctx, img);
} break;
case PROJECTOR_TYPE_KIMIK25:
{
builder = std::make_unique<clip_graph_kimik25>(ctx, img);
} break;
case PROJECTOR_TYPE_COGVLM:
{
builder = std::make_unique<clip_graph_cogvlm>(ctx, img);
@@ -1139,6 +1143,22 @@ struct clip_model_loader {
hparams.set_limit_image_tokens(8, 1024);
hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
} break;
case PROJECTOR_TYPE_KIMIK25:
{
hparams.rope_theta = 10000.0f;
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
int min_pixels = 0, max_pixels = 0;
get_u32(KEY_IMAGE_MIN_PIXELS, min_pixels, false);
get_u32(KEY_IMAGE_MAX_PIXELS, max_pixels, false);
if (min_pixels > 0 && max_pixels > 0) {
hparams.image_min_pixels = min_pixels;
hparams.image_max_pixels = max_pixels;
hparams.warmup_image_size = static_cast<int>(std::sqrt(max_pixels));
} else {
hparams.set_limit_image_tokens(2, 4096);
}
} break;
case PROJECTOR_TYPE_GEMMA3:
{
// default value (used by all model sizes in gemma 3 family)
@@ -1668,6 +1688,7 @@ struct clip_model_loader {
model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
} break;
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_KIMIK25:
{
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
@@ -3165,6 +3186,23 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->entries.push_back(std::move(res));
} break;
case PROJECTOR_TYPE_KIMIK25:
{
GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
const clip_image_size target_size = img_tool::calc_size_preserved_ratio(
original_size,
params.patch_size * params.n_merge,
params.image_min_pixels,
params.image_max_pixels);
const std::array<uint8_t, 3> pad_color = {0, 0, 0};
clip_image_u8 resized_img;
img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BICUBIC, true, pad_color);
clip_image_f32_ptr res(clip_image_f32_init());
normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std);
res_imgs->entries.push_back(std::move(res));
} break;
case PROJECTOR_TYPE_MLP:
case PROJECTOR_TYPE_MLP_NORM:
case PROJECTOR_TYPE_LDP:
@@ -3373,6 +3411,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
} break;
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_KIMIK25:
{
// dynamic size
int out_patch_size = params.patch_size * ctx->model.hparams.n_merge;
@@ -3714,6 +3753,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
} break;
case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_KIMIK25:
case PROJECTOR_TYPE_LIGHTONOCR:
{
// set the 2D positions
@@ -3850,6 +3890,47 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
}
// Debug: dump final embeddings if MTMD_DEBUG_EMBEDDINGS is set
if (std::getenv("MTMD_DEBUG_EMBEDDINGS") != nullptr) {
const int64_t n_embd = embeddings->ne[0];
const int64_t n_tokens = embeddings->ne[1];
std::vector<float> emb_data(n_embd * n_tokens);
ggml_backend_tensor_get(embeddings, emb_data.data(), 0, ggml_nbytes(embeddings));
LOG_INF("\n=== MTMD_DEBUG_EMBEDDINGS ===\n");
LOG_INF("Shape: [%lld, %lld]\n", (long long)n_embd, (long long)n_tokens);
// Print first few values of first token
LOG_INF("Token 0 (first 16 values): ");
for (int i = 0; i < std::min((int64_t)16, n_embd); i++) {
LOG_INF("%.6f ", emb_data[i]);
}
LOG_INF("\n");
// Print last few values of first token
if (n_embd > 16) {
LOG_INF("Token 0 (last 16 values): ");
for (int64_t i = n_embd - 16; i < n_embd; i++) {
LOG_INF("%.6f ", emb_data[i]);
}
LOG_INF("\n");
}
// Compute and print statistics
float sum = 0.0f, sum_sq = 0.0f, min_val = emb_data[0], max_val = emb_data[0];
for (size_t i = 0; i < emb_data.size(); i++) {
sum += emb_data[i];
sum_sq += emb_data[i] * emb_data[i];
min_val = std::min(min_val, emb_data[i]);
max_val = std::max(max_val, emb_data[i]);
}
float mean = sum / emb_data.size();
float variance = (sum_sq / emb_data.size()) - (mean * mean);
LOG_INF("Stats: mean=%.6f, std=%.6f, min=%.6f, max=%.6f, sum=%.6f\n",
mean, sqrtf(variance), min_val, max_val, sum);
LOG_INF("=== END MTMD_DEBUG_EMBEDDINGS ===\n\n");
}
return true;
}
@@ -3896,6 +3977,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_KIMIK25:
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_COGVLM:
return ctx->model.mm_4h_to_h_w->ne[1];
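The `MTMD_DEBUG_EMBEDDINGS` dump above derives the standard deviation in a single pass via the identity var = E[x^2] - (E[x])^2. A small Python sketch of the same computation; the clamp against tiny negative variance from floating-point round-off is my addition (the C code takes `sqrtf` of the raw value):

```python
# Sketch: one-pass mean/std/min/max/sum, mirroring the debug-embeddings stats.
import math

def embed_stats(values):
    s = sq = 0.0
    mn = mx = values[0]
    for v in values:
        s += v
        sq += v * v
        mn = min(mn, v)
        mx = max(mx, v)
    mean = s / len(values)
    var = sq / len(values) - mean * mean   # E[x^2] - (E[x])^2
    return mean, math.sqrt(max(var, 0.0)), mn, mx, s

mean, std, mn, mx, total = embed_stats([1.0, 2.0, 3.0, 4.0])
assert mean == 2.5 and mn == 1.0 and mx == 4.0 and total == 10.0
assert abs(std - math.sqrt(1.25)) < 1e-12
```

The one-pass form avoids reading the embedding buffer twice, at the cost of some precision when mean and spread differ by many orders of magnitude; for eyeballing embeddings that trade-off is fine.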


@@ -0,0 +1,101 @@
#include "models.h"
#include <cstring>
#include <cmath>
// note: this is similar to clip_graph::resize_position_embeddings; the major difference is that
// w/h are stored in ne[1] and ne[2] instead of being inferred via sqrt. Could try storing the
// tensor in 2D with w*h instead? The permute is also slightly different: (2, 1, 0, 3) instead of (2, 0, 1, 3).
ggml_tensor * clip_graph_kimik25::resize_position_embeddings_3d(uint32_t interpolation_mode) {
ggml_tensor * pos_embd = model.position_embeddings;
const int height = img.ny / patch_size;
const int width = img.nx / patch_size;
const uint32_t mode = interpolation_mode;
GGML_ASSERT(pos_embd);
const int64_t stored_c = pos_embd->ne[0]; // C = 1152
const int64_t orig_w = pos_embd->ne[1]; // W = 64
const int64_t orig_h = pos_embd->ne[2]; // H = 64
GGML_ASSERT(stored_c == n_embd);
if (height == (int)orig_h && width == (int)orig_w) {
// No interpolation needed, just flatten to [C, H*W]
return ggml_cont_2d(ctx0, pos_embd, n_embd, width * height);
}
pos_embd = ggml_permute(ctx0, pos_embd, 2, 1, 0, 3);
pos_embd = ggml_interpolate(ctx0, pos_embd, height, width, n_embd, 1, mode);
pos_embd = ggml_permute(ctx0, pos_embd, 2, 1, 0, 3);
pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height);
return pos_embd;
}
ggml_cgraph * clip_graph_kimik25::build() {
ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
ggml_set_name(pos_h, "pos_h");
ggml_set_input(pos_h);
ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
ggml_set_name(pos_w, "pos_w");
ggml_set_input(pos_w);
ggml_tensor * learned_pos_embd = resize_position_embeddings_3d(GGML_SCALE_MODE_BICUBIC);
// Kimi-K2.5 uses interleaved 2D RoPE pattern natively, but
// Q / K are permuted during conversion to use split format.
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
cur = build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
return cur;
};
ggml_tensor * inp = build_inp();
// I don't know why, but doing this in build_vit led to the ggml_add not occurring?
// Doing it manually here does work.
inp = ggml_add(ctx0, inp, learned_pos_embd);
ggml_tensor * cur = build_vit(
inp, n_patches,
NORM_TYPE_NORMAL,
hparams.ffn_op,
nullptr,
add_pos);
cb(cur, "vit_out", -1);
{
// patch_merger
const int scale_factor = model.hparams.n_merge;
cur = build_patch_merge_permute(cur, scale_factor);
// projection norm
int proj_inp_dim = cur->ne[0];
int n_merged_patches = cur->ne[1];
cur = ggml_view_2d(ctx0, cur,
n_embd, n_merged_patches * scale_factor * scale_factor,
ggml_row_size(cur->type, n_embd), 0);
cur = ggml_norm(ctx0, cur, hparams.eps);
cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
cur = ggml_view_2d(ctx0, cur,
proj_inp_dim, n_merged_patches,
ggml_row_size(cur->type, proj_inp_dim), 0);
cb(cur, "proj_inp_normed", -1);
// projection mlp
cur = build_ffn(cur,
model.mm_1_w, model.mm_1_b,
nullptr, nullptr,
model.mm_2_w, model.mm_2_b,
FFN_GELU,
-1);
cb(cur, "proj_out", -1);
}
// build the graph
ggml_build_forward_expand(gf, cur);
return gf;
}
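The projection-norm block above temporarily views the merged [n_embd*s*s, n] activation as [n_embd, n*s*s] so that `ggml_norm` (which normalizes over ne[0]) runs per n_embd-sized group, then views back before the MLP. A toy stdlib-Python sketch of why the two views cost nothing: the buffer layout never changes, only the grouping (all dims here are made up):

```python
# Sketch: the view_2d -> norm -> view_2d trick from the patch merger above.
# Normalizing the flat buffer in n_embd-sized groups has the same effect,
# because both views share one contiguous buffer.
import math

n_embd, s, n = 4, 2, 3                     # toy dims; the real model is larger
merged_dim = n_embd * s * s                # proj_inp_dim after patch merge
flat = [float(i % 7) - 3.0 for i in range(merged_dim * n)]

def layernorm_groups(x, group, eps=1e-6):
    out = []
    for i in range(0, len(x), group):
        g = x[i:i + group]
        mean = sum(g) / group
        var = sum((v - mean) ** 2 for v in g) / group
        out += [(v - mean) / math.sqrt(var + eps) for v in g]
    return out

normed = layernorm_groups(flat, n_embd)
assert len(normed) == merged_dim * n
# every n_embd-sized group is standardized (mean ~ 0)
assert abs(sum(normed[:n_embd]) / n_embd) < 1e-6
```

The learned `mm_input_norm_w` / `mm_input_norm_b` scale and shift would then apply per n_embd channel, broadcast across the `n*s*s` grouped rows.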


@@ -109,3 +109,10 @@ struct clip_graph_mobilenetv5 : clip_graph {
ggml_tensor * inp,
const mobilenetv5_block & block);
};
struct clip_graph_kimik25 : clip_graph {
clip_graph_kimik25(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;
ggml_tensor * resize_position_embeddings_3d(uint32_t interpolation_mode);
};


@@ -19,7 +19,7 @@ Set of LLM REST APIs and a web UI to interact with llama.cpp.
* Speculative decoding
* Easy-to-use web UI
For the ful list of features, please refer to [server's changelog](https://github.com/ggml-org/llama.cpp/issues/9291)
For the full list of features, please refer to [server's changelog](https://github.com/ggml-org/llama.cpp/issues/9291)
## Usage

Binary file not shown.


@@ -14,11 +14,11 @@
--popover-foreground: oklch(0.145 0 0);
--primary: oklch(0.205 0 0);
--primary-foreground: oklch(0.985 0 0);
--secondary: oklch(0.97 0 0);
--secondary: oklch(0.95 0 0);
--secondary-foreground: oklch(0.205 0 0);
--muted: oklch(0.97 0 0);
--muted-foreground: oklch(0.556 0 0);
--accent: oklch(0.97 0 0);
--accent: oklch(0.95 0 0);
--accent-foreground: oklch(0.205 0 0);
--destructive: oklch(0.577 0.245 27.325);
--border: oklch(0.875 0 0);
@@ -37,7 +37,7 @@
--sidebar-accent-foreground: oklch(0.205 0 0);
--sidebar-border: oklch(0.922 0 0);
--sidebar-ring: oklch(0.708 0 0);
--code-background: oklch(0.975 0 0);
--code-background: oklch(0.985 0 0);
--code-foreground: oklch(0.145 0 0);
--layer-popover: 1000000;
}
@@ -51,7 +51,7 @@
--popover-foreground: oklch(0.985 0 0);
--primary: oklch(0.922 0 0);
--primary-foreground: oklch(0.205 0 0);
--secondary: oklch(0.269 0 0);
--secondary: oklch(0.29 0 0);
--secondary-foreground: oklch(0.985 0 0);
--muted: oklch(0.269 0 0);
--muted-foreground: oklch(0.708 0 0);
@@ -116,12 +116,62 @@
--color-sidebar-ring: var(--sidebar-ring);
}
:root {
--chat-form-area-height: 8rem;
--chat-form-area-offset: 2rem;
--max-message-height: max(24rem, min(80dvh, calc(100dvh - var(--chat-form-area-height) - 12rem)));
}
@media (min-width: 640px) {
:root {
--chat-form-area-height: 24rem;
--chat-form-area-offset: 12rem;
}
}
@layer base {
* {
@apply border-border outline-ring/50;
}
body {
@apply bg-background text-foreground;
scrollbar-width: thin;
scrollbar-gutter: stable;
}
/* Global scrollbar styling - visible only on hover */
* {
scrollbar-width: thin;
scrollbar-color: transparent transparent;
transition: scrollbar-color 0.2s ease;
}
*:hover {
scrollbar-color: hsl(var(--muted-foreground) / 0.3) transparent;
}
*::-webkit-scrollbar {
width: 6px;
height: 6px;
}
*::-webkit-scrollbar-track {
background: transparent;
}
*::-webkit-scrollbar-thumb {
background: transparent;
border-radius: 3px;
transition: background 0.2s ease;
}
*:hover::-webkit-scrollbar-thumb {
background: hsl(var(--muted-foreground) / 0.3);
}
*::-webkit-scrollbar-thumb:hover {
background: hsl(var(--muted-foreground) / 0.5);
}
}


@@ -0,0 +1,48 @@
<script lang="ts">
import { Button } from '$lib/components/ui/button';
import * as Tooltip from '$lib/components/ui/tooltip';
import type { Component } from 'svelte';
interface Props {
icon: Component;
tooltip: string;
variant?: 'default' | 'destructive' | 'outline' | 'secondary' | 'ghost' | 'link';
size?: 'default' | 'sm' | 'lg' | 'icon';
class?: string;
disabled?: boolean;
onclick: () => void;
'aria-label'?: string;
}
let {
icon,
tooltip,
variant = 'ghost',
size = 'sm',
class: className = '',
disabled = false,
onclick,
'aria-label': ariaLabel
}: Props = $props();
</script>
<Tooltip.Root>
<Tooltip.Trigger>
<Button
{variant}
{size}
{disabled}
{onclick}
class="h-6 w-6 p-0 {className} flex"
aria-label={ariaLabel || tooltip}
>
{@const IconComponent = icon}
<IconComponent class="h-3 w-3" />
</Button>
</Tooltip.Trigger>
<Tooltip.Content>
<p>{tooltip}</p>
</Tooltip.Content>
</Tooltip.Root>


@@ -0,0 +1,18 @@
<script lang="ts">
import { Copy } from '@lucide/svelte';
import { copyToClipboard } from '$lib/utils';
interface Props {
ariaLabel?: string;
canCopy?: boolean;
text: string;
}
let { ariaLabel = 'Copy to clipboard', canCopy = true, text }: Props = $props();
</script>
<Copy
class="h-3 w-3 flex-shrink-0 cursor-{canCopy ? 'pointer' : 'not-allowed'}"
aria-label={ariaLabel}
onclick={() => canCopy && copyToClipboard(text)}
/>


@@ -0,0 +1,26 @@
<script lang="ts">
import { X } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
interface Props {
id: string;
onRemove?: (id: string) => void;
class?: string;
}
let { id, onRemove, class: className = '' }: Props = $props();
</script>
<Button
type="button"
variant="ghost"
size="sm"
class="h-6 w-6 bg-white/20 p-0 hover:bg-white/30 {className}"
onclick={(e: MouseEvent) => {
e.stopPropagation();
onRemove?.(id);
}}
aria-label="Remove file"
>
<X class="h-3 w-3" />
</Button>


@@ -0,0 +1,46 @@
<script lang="ts">
import { Eye } from '@lucide/svelte';
import ActionIconCopyToClipboard from '$lib/components/app/actions/ActionIconCopyToClipboard.svelte';
import { FileTypeText } from '$lib/enums';
interface Props {
code: string;
language: string;
disabled?: boolean;
onPreview?: (code: string, language: string) => void;
}
let { code, language, disabled = false, onPreview }: Props = $props();
const showPreview = $derived(language?.toLowerCase() === FileTypeText.HTML);
function handlePreview() {
if (disabled) return;
onPreview?.(code, language);
}
</script>
<div class="code-block-actions">
<div class="copy-code-btn" class:opacity-50={disabled} class:!cursor-not-allowed={disabled}>
<ActionIconCopyToClipboard
text={code}
canCopy={!disabled}
ariaLabel={disabled ? 'Code incomplete' : 'Copy code'}
/>
</div>
{#if showPreview}
<button
class="preview-code-btn"
class:opacity-50={disabled}
class:!cursor-not-allowed={disabled}
title={disabled ? 'Code incomplete' : 'Preview code'}
aria-label="Preview code"
aria-disabled={disabled}
type="button"
onclick={handlePreview}
>
<Eye size={16} />
</button>
{/if}
</div>


@@ -0,0 +1,19 @@
/**
*
* ACTIONS
*
* Small interactive components for user actions.
*
*/
/** Styled icon button for action triggers with tooltip. */
export { default as ActionIcon } from './ActionIcon.svelte';
/** Code block actions component (copy, preview). */
export { default as ActionIconsCodeBlock } from './ActionIconsCodeBlock.svelte';
/** Copy-to-clipboard icon button with click handler. */
export { default as ActionIconCopyToClipboard } from './ActionIconCopyToClipboard.svelte';
/** Remove/delete icon button with X icon. */
export { default as ActionIconRemove } from './ActionIconRemove.svelte';


@@ -0,0 +1,44 @@
<script lang="ts">
import { BadgeInfo } from '$lib/components/app';
import * as Tooltip from '$lib/components/ui/tooltip';
import { copyToClipboard } from '$lib/utils';
import type { Component } from 'svelte';
interface Props {
class?: string;
icon: Component;
value: string | number;
tooltipLabel?: string;
}
let { class: className = '', icon: Icon, value, tooltipLabel }: Props = $props();
function handleClick() {
void copyToClipboard(String(value));
}
</script>
{#if tooltipLabel}
<Tooltip.Root>
<Tooltip.Trigger>
<BadgeInfo class={className} onclick={handleClick}>
{#snippet icon()}
<Icon class="h-3 w-3" />
{/snippet}
{value}
</BadgeInfo>
</Tooltip.Trigger>
<Tooltip.Content>
<p>{tooltipLabel}</p>
</Tooltip.Content>
</Tooltip.Root>
{:else}
<BadgeInfo class={className} onclick={handleClick}>
{#snippet icon()}
<Icon class="h-3 w-3" />
{/snippet}
{value}
</BadgeInfo>
{/if}


@@ -0,0 +1,27 @@
<script lang="ts">
import { cn } from '$lib/components/ui/utils';
import type { Snippet } from 'svelte';
interface Props {
children: Snippet;
class?: string;
icon?: Snippet;
onclick?: () => void;
}
let { children, class: className = '', icon, onclick }: Props = $props();
</script>
<button
class={cn(
'inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75',
className
)}
{onclick}
>
{#if icon}
{@render icon()}
{/if}
{@render children()}
</button>


@@ -0,0 +1,39 @@
<script lang="ts">
import { ModelModality } from '$lib/enums';
import { MODALITY_ICONS, MODALITY_LABELS } from '$lib/constants/icons';
import { cn } from '$lib/components/ui/utils';
type DisplayableModality = ModelModality.VISION | ModelModality.AUDIO;
interface Props {
modalities: ModelModality[];
class?: string;
}
let { modalities, class: className = '' }: Props = $props();
// Filter to only modalities that have icons (VISION, AUDIO)
const displayableModalities = $derived(
modalities.filter(
(m): m is DisplayableModality => m === ModelModality.VISION || m === ModelModality.AUDIO
)
);
</script>
{#each displayableModalities as modality, index (index)}
{@const IconComponent = MODALITY_ICONS[modality]}
{@const label = MODALITY_LABELS[modality]}
<span
class={cn(
'inline-flex items-center gap-1 rounded-md bg-muted px-2 py-1 text-xs font-medium',
className
)}
>
{#if IconComponent}
<IconComponent class="h-3 w-3" />
{/if}
{label}
</span>
{/each}


@@ -0,0 +1,16 @@
/**
*
* BADGES & INDICATORS
*
* Small visual indicators for status and metadata.
*
*/
/** Badge displaying chat statistics (tokens, timing). */
export { default as BadgeChatStatistic } from './BadgeChatStatistic.svelte';
/** Generic info badge with optional tooltip and click handler. */
export { default as BadgeInfo } from './BadgeInfo.svelte';
/** Badge indicating model modality (vision, audio, tools). */
export { default as BadgeModality } from './BadgeModality.svelte';


@@ -0,0 +1,97 @@
<script lang="ts">
import ChevronsUpDownIcon from '@lucide/svelte/icons/chevrons-up-down';
import * as Collapsible from '$lib/components/ui/collapsible/index.js';
import { buttonVariants } from '$lib/components/ui/button/index.js';
import { Card } from '$lib/components/ui/card';
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
import type { Snippet } from 'svelte';
import type { Component } from 'svelte';
interface Props {
open?: boolean;
class?: string;
icon?: Component;
iconClass?: string;
title: string;
subtitle?: string;
isStreaming?: boolean;
onToggle?: () => void;
children: Snippet;
}
let {
open = $bindable(false),
class: className = '',
icon: Icon,
iconClass = 'h-4 w-4',
title,
subtitle,
isStreaming = false,
onToggle,
children
}: Props = $props();
let contentContainer: HTMLDivElement | undefined = $state();
const autoScroll = createAutoScrollController();
$effect(() => {
autoScroll.setContainer(contentContainer);
});
$effect(() => {
// Only auto-scroll when open and streaming
autoScroll.updateInterval(open && isStreaming);
});
function handleScroll() {
autoScroll.handleScroll();
}
</script>
<Collapsible.Root
{open}
onOpenChange={(value) => {
open = value;
onToggle?.();
}}
class={className}
>
<Card class="gap-0 border-muted bg-muted/30 py-0">
<Collapsible.Trigger class="flex w-full cursor-pointer items-center justify-between p-3">
<div class="flex items-center gap-2 text-muted-foreground">
{#if Icon}
<Icon class={iconClass} />
{/if}
<span class="font-mono text-sm font-medium">{title}</span>
{#if subtitle}
<span class="text-xs italic">{subtitle}</span>
{/if}
</div>
<div
class={buttonVariants({
variant: 'ghost',
size: 'sm',
class: 'h-6 w-6 p-0 text-muted-foreground hover:text-foreground'
})}
>
<ChevronsUpDownIcon class="h-4 w-4" />
<span class="sr-only">Toggle content</span>
</div>
</Collapsible.Trigger>
<Collapsible.Content>
<div
bind:this={contentContainer}
class="overflow-y-auto border-t border-muted px-3 pb-3"
onscroll={handleScroll}
style="min-height: var(--min-message-height); max-height: var(--max-message-height);"
>
{@render children()}
</div>
</Collapsible.Content>
</Card>
</Collapsible.Root>

File diff suppressed because it is too large

View File

@@ -0,0 +1,95 @@
<script lang="ts">
import hljs from 'highlight.js';
import { browser } from '$app/environment';
import { mode } from 'mode-watcher';
import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
import githubLightCss from 'highlight.js/styles/github.css?inline';
interface Props {
code: string;
language?: string;
class?: string;
maxHeight?: string;
maxWidth?: string;
}
let {
code,
language = 'text',
class: className = '',
maxHeight = '60vh',
maxWidth = ''
}: Props = $props();
let highlightedHtml = $state('');
function loadHighlightTheme(isDark: boolean) {
if (!browser) return;
const existingThemes = document.querySelectorAll('style[data-highlight-theme-preview]');
existingThemes.forEach((style) => style.remove());
const style = document.createElement('style');
style.setAttribute('data-highlight-theme-preview', 'true');
style.textContent = isDark ? githubDarkCss : githubLightCss;
document.head.appendChild(style);
}
$effect(() => {
const currentMode = mode.current;
const isDark = currentMode === 'dark';
loadHighlightTheme(isDark);
});
$effect(() => {
if (!code) {
highlightedHtml = '';
return;
}
try {
// Check if the language is supported
const lang = language.toLowerCase();
const isSupported = hljs.getLanguage(lang);
if (isSupported) {
const result = hljs.highlight(code, { language: lang });
highlightedHtml = result.value;
} else {
// Try auto-detection or fallback to plain text
const result = hljs.highlightAuto(code);
highlightedHtml = result.value;
}
} catch {
// Fallback to escaped plain text
highlightedHtml = code.replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;');
}
});
</script>
<div
class="code-preview-wrapper rounded-lg border border-border bg-muted {className}"
style="max-height: {maxHeight}; max-width: {maxWidth};"
>
<!-- Needs to be formatted as single line for proper rendering -->
<pre class="m-0"><code class="hljs text-sm leading-relaxed">{@html highlightedHtml}</code></pre>
</div>
<style>
.code-preview-wrapper {
font-family:
ui-monospace, SFMono-Regular, 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas,
'Liberation Mono', Menlo, monospace;
}
.code-preview-wrapper pre {
background: transparent;
}
.code-preview-wrapper code {
background: transparent;
}
</style>

View File

@@ -0,0 +1,79 @@
/**
*
* CONTENT RENDERING
*
* Components for rendering rich content: markdown, code, and previews.
*
*/
/**
* **MarkdownContent** - Rich markdown renderer
*
* Renders markdown content with syntax highlighting, LaTeX math,
* tables, links, and code blocks. Optimized for streaming with
* incremental block-based rendering.
*
* **Features:**
* - GFM (GitHub Flavored Markdown): tables, task lists, strikethrough
* - LaTeX math via KaTeX (`$inline$` and `$$block$$`)
* - Syntax highlighting (highlight.js) with language detection
* - Code copy buttons with click feedback
* - External links open in new tab with security attrs
* - Image attachment resolution from message extras
* - Dark/light theme support (auto-switching)
* - Streaming-optimized incremental rendering
* - Code preview dialog for large blocks
*
* @example
* ```svelte
* <MarkdownContent content={message.content} attachments={message.extra} />
* ```
*/
export { default as MarkdownContent } from './MarkdownContent.svelte';
/**
* **SyntaxHighlightedCode** - Code syntax highlighting
*
* Renders code with syntax highlighting using highlight.js.
* Supports theme switching and scrollable containers.
*
* **Features:**
* - Auto language detection with fallback
* - Dark/light theme auto-switching
* - Scrollable container with configurable max dimensions
* - Monospace font styling
* - Preserves whitespace and formatting
*
* @example
* ```svelte
* <SyntaxHighlightedCode code={jsonString} language="json" />
* ```
*/
export { default as SyntaxHighlightedCode } from './SyntaxHighlightedCode.svelte';
/**
* **CollapsibleContentBlock** - Expandable content card
*
* Reusable collapsible card with header, icon, and auto-scroll.
* Used for tool calls and reasoning blocks in chat messages.
*
* **Features:**
* - Collapsible content with smooth animation
* - Custom icon and title display
* - Optional subtitle/status text
* - Auto-scroll during streaming (pauses on user scroll)
* - Configurable max height with overflow scroll
*
* @example
* ```svelte
* <CollapsibleContentBlock
* bind:open
* icon={BrainIcon}
* title="Thinking..."
* isStreaming={true}
* >
* {reasoningContent}
* </CollapsibleContentBlock>
* ```
*/
export { default as CollapsibleContentBlock } from './CollapsibleContentBlock.svelte';

View File

@@ -17,9 +17,13 @@
let { conversations, messageCountMap = new Map(), mode, onCancel, onConfirm }: Props = $props();
let searchQuery = $state('');
-let selectedIds = $state.raw<SvelteSet<string>>(new SvelteSet(conversations.map((c) => c.id)));
+let selectedIds = $state.raw<SvelteSet<string>>(getInitialSelectedIds());
let lastClickedId = $state<string | null>(null);
+function getInitialSelectedIds(): SvelteSet<string> {
+return new SvelteSet(conversations.map((c) => c.id));
+}
let filteredConversations = $derived(
conversations.filter((conv) => {
const name = conv.name || 'Untitled conversation';
@@ -92,7 +96,7 @@
}
function handleCancel() {
-selectedIds = new SvelteSet(conversations.map((c) => c.id));
+selectedIds = getInitialSelectedIds();
searchQuery = '';
lastClickedId = null;
@@ -100,7 +104,7 @@
}
export function reset() {
-selectedIds = new SvelteSet(conversations.map((c) => c.id));
+selectedIds = getInitialSelectedIds();
searchQuery = '';
lastClickedId = null;
}

View File

@@ -0,0 +1,93 @@
<script lang="ts">
import { ChevronLeft, ChevronRight } from '@lucide/svelte';
interface Props {
class?: string;
children?: import('svelte').Snippet;
gapSize?: string;
onScrollableChange?: (isScrollable: boolean) => void;
}
let { class: className = '', children, gapSize = '3', onScrollableChange }: Props = $props();
let canScrollLeft = $state(false);
let canScrollRight = $state(false);
let scrollContainer: HTMLDivElement | undefined = $state();
function scrollLeft(event?: MouseEvent) {
event?.stopPropagation();
event?.preventDefault();
if (!scrollContainer) return;
scrollContainer.scrollBy({ left: scrollContainer.clientWidth * -0.67, behavior: 'smooth' });
}
function scrollRight(event?: MouseEvent) {
event?.stopPropagation();
event?.preventDefault();
if (!scrollContainer) return;
scrollContainer.scrollBy({ left: scrollContainer.clientWidth * 0.67, behavior: 'smooth' });
}
function updateScrollButtons() {
if (!scrollContainer) return;
const { scrollLeft, scrollWidth, clientWidth } = scrollContainer;
canScrollLeft = scrollLeft > 0;
canScrollRight = scrollLeft < scrollWidth - clientWidth - 1;
const isScrollable = scrollWidth > clientWidth;
onScrollableChange?.(isScrollable);
}
export function resetScroll() {
if (scrollContainer) {
scrollContainer.scrollLeft = 0;
setTimeout(() => {
updateScrollButtons();
}, 0);
}
}
$effect(() => {
if (scrollContainer) {
setTimeout(() => {
updateScrollButtons();
}, 0);
}
});
</script>
<div class="relative {className}">
<button
class="absolute top-1/2 left-4 z-10 flex h-6 w-6 -translate-y-1/2 items-center justify-center rounded-full bg-foreground/15 shadow-md backdrop-blur-xs transition-opacity hover:bg-foreground/35 {canScrollLeft
? 'opacity-100'
: 'pointer-events-none opacity-0'}"
onclick={scrollLeft}
aria-label="Scroll left"
>
<ChevronLeft class="h-4 w-4" />
</button>
<div
class="scrollbar-hide flex items-start gap-{gapSize} overflow-x-auto"
bind:this={scrollContainer}
onscroll={updateScrollButtons}
>
{@render children?.()}
</div>
<button
class="absolute top-1/2 right-4 z-10 flex h-6 w-6 -translate-y-1/2 items-center justify-center rounded-full bg-foreground/15 shadow-md backdrop-blur-xs transition-opacity hover:bg-foreground/35 {canScrollRight
? 'opacity-100'
: 'pointer-events-none opacity-0'}"
onclick={scrollRight}
aria-label="Scroll right"
>
<ChevronRight class="h-4 w-4" />
</button>
</div>

View File

@@ -11,7 +11,9 @@
let baseClasses =
'px-1 pointer-events-none inline-flex select-none items-center gap-0.5 font-sans text-md font-medium opacity-0 transition-opacity -my-1';
-let variantClasses = variant === 'destructive' ? 'text-destructive' : 'text-muted-foreground';
+let variantClasses = $derived(
+variant === 'destructive' ? 'text-destructive' : 'text-muted-foreground'
+);
</script>
<kbd class="{baseClasses} {variantClasses} {className}">

View File

@@ -0,0 +1,48 @@
<script lang="ts">
import * as Tooltip from '$lib/components/ui/tooltip';
interface Props {
text: string;
class?: string;
}
let { text, class: className = '' }: Props = $props();
let textElement: HTMLSpanElement | undefined = $state();
let isTruncated = $state(false);
function checkTruncation() {
if (textElement) {
isTruncated = textElement.scrollWidth > textElement.clientWidth;
}
}
$effect(() => {
if (textElement) {
checkTruncation();
const observer = new ResizeObserver(checkTruncation);
observer.observe(textElement);
return () => observer.disconnect();
}
});
</script>
{#if isTruncated}
<Tooltip.Root>
<Tooltip.Trigger class={className}>
<span bind:this={textElement} class="block truncate">
{text}
</span>
</Tooltip.Trigger>
<Tooltip.Content class="z-[9999]">
<p>{text}</p>
</Tooltip.Content>
</Tooltip.Root>
{:else}
<span bind:this={textElement} class="{className} block truncate">
{text}
</span>
{/if}

View File

@@ -0,0 +1,45 @@
/**
*
* MISC
*
* Miscellaneous utility components.
*
*/
/**
* **ConversationSelection** - Multi-select conversation picker
*
* List of conversations with checkboxes for multi-selection.
* Used in import/export dialogs for selecting conversations.
*
* **Features:**
* - Search/filter conversations by name
* - Select all / deselect all controls
* - Shift-click for range selection
* - Message count display per conversation
* - Mode-specific UI (export vs import)
*/
export { default as ConversationSelection } from './ConversationSelection.svelte';
/**
* Horizontal scrollable carousel with navigation arrows.
* Used for displaying items in a horizontally scrollable container
* with left/right navigation buttons that appear on hover.
*/
export { default as HorizontalScrollCarousel } from './HorizontalScrollCarousel.svelte';
/**
* **TruncatedText** - Text with ellipsis and tooltip
*
* Displays text with automatic truncation and full content in tooltip.
* Useful for long names or paths in constrained spaces.
*/
export { default as TruncatedText } from './TruncatedText.svelte';
/**
* **KeyboardShortcutInfo** - Keyboard shortcut hint display
*
* Displays keyboard shortcut hints (e.g., "⌘ + Enter").
* Supports special keys like shift, cmd, and custom text.
*/
export { default as KeyboardShortcutInfo } from './KeyboardShortcutInfo.svelte';

View File

@@ -0,0 +1,86 @@
<script lang="ts">
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import * as Tooltip from '$lib/components/ui/tooltip';
import { KeyboardShortcutInfo } from '$lib/components/app';
import type { Component } from 'svelte';
interface ActionItem {
icon: Component;
label: string;
onclick: (event: Event) => void;
variant?: 'default' | 'destructive';
disabled?: boolean;
shortcut?: string[];
separator?: boolean;
}
interface Props {
triggerIcon: Component;
triggerTooltip?: string;
triggerClass?: string;
actions: ActionItem[];
align?: 'start' | 'center' | 'end';
open?: boolean;
}
let {
triggerIcon,
triggerTooltip,
triggerClass = '',
actions,
align = 'end',
open = $bindable(false)
}: Props = $props();
</script>
<DropdownMenu.Root bind:open>
<DropdownMenu.Trigger
class="flex h-6 w-6 cursor-pointer items-center justify-center rounded-md p-0 text-sm font-medium transition-colors hover:bg-accent hover:text-accent-foreground focus:bg-accent focus:text-accent-foreground focus:outline-none disabled:pointer-events-none disabled:opacity-50 data-[state=open]:bg-accent data-[state=open]:text-accent-foreground {triggerClass}"
onclick={(e) => e.stopPropagation()}
>
{#if triggerTooltip}
<Tooltip.Root>
<Tooltip.Trigger>
{@render iconComponent(triggerIcon, 'h-3 w-3')}
<span class="sr-only">{triggerTooltip}</span>
</Tooltip.Trigger>
<Tooltip.Content>
<p>{triggerTooltip}</p>
</Tooltip.Content>
</Tooltip.Root>
{:else}
{@render iconComponent(triggerIcon, 'h-3 w-3')}
{/if}
</DropdownMenu.Trigger>
<DropdownMenu.Content {align} class="z-[999999] w-48">
{#each actions as action, index (action.label)}
{#if action.separator && index > 0}
<DropdownMenu.Separator />
{/if}
<DropdownMenu.Item
onclick={action.onclick}
variant={action.variant}
disabled={action.disabled}
class="flex items-center justify-between hover:[&>kbd]:opacity-100"
>
<div class="flex items-center gap-2">
{@render iconComponent(
action.icon,
`h-4 w-4 ${action.variant === 'destructive' ? 'text-destructive' : ''}`
)}
{action.label}
</div>
{#if action.shortcut}
<KeyboardShortcutInfo keys={action.shortcut} variant={action.variant} />
{/if}
</DropdownMenu.Item>
{/each}
</DropdownMenu.Content>
</DropdownMenu.Root>
{#snippet iconComponent(IconComponent: Component, className: string)}
<IconComponent class={className} />
{/snippet}

View File

@@ -0,0 +1,50 @@
<script lang="ts">
import type { Snippet } from 'svelte';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import { SearchInput } from '$lib/components/app';
interface Props {
placeholder?: string;
searchValue?: string;
onSearchChange?: (value: string) => void;
onSearchKeyDown?: (event: KeyboardEvent) => void;
emptyMessage?: string;
isEmpty?: boolean;
children: Snippet;
footer?: Snippet;
}
let {
placeholder = 'Search...',
searchValue = $bindable(''),
onSearchChange,
onSearchKeyDown,
emptyMessage = 'No items found',
isEmpty = false,
children,
footer
}: Props = $props();
</script>
<div class="sticky top-0 z-10 mb-2 bg-popover p-1 pt-2">
<SearchInput
{placeholder}
bind:value={searchValue}
onInput={onSearchChange}
onKeyDown={onSearchKeyDown}
/>
</div>
<div class="overflow-y-auto">
{@render children()}
{#if isEmpty}
<div class="px-2 py-3 text-center text-sm text-muted-foreground">{emptyMessage}</div>
{/if}
</div>
{#if footer}
<DropdownMenu.Separator />
{@render footer()}
{/if}

View File

@@ -0,0 +1,65 @@
/**
*
* NAVIGATION & MENUS
*
* Components for dropdown menus and action selection.
*
*/
/**
* **DropdownMenuSearchable** - Searchable content for dropdown menus
*
* Renders a search input with filtered content area, empty state, and optional footer.
* Designed to be injected into any dropdown container (DropdownMenu.Content,
* DropdownMenu.SubContent, etc.) without providing its own Root.
*
* **Features:**
* - Search/filter input
* - Keyboard navigation support
* - Custom content and footer via snippets
* - Empty state message
*
* @example
* ```svelte
* <DropdownMenu.Root>
* <DropdownMenu.Trigger>...</DropdownMenu.Trigger>
* <DropdownMenu.Content class="pt-0">
* <DropdownMenuSearchable
* bind:searchValue
* placeholder="Search..."
* isEmpty={filteredItems.length === 0}
* >
* {#each items as item}<Item {item} />{/each}
* </DropdownMenuSearchable>
* </DropdownMenu.Content>
* </DropdownMenu.Root>
* ```
*/
export { default as DropdownMenuSearchable } from './DropdownMenuSearchable.svelte';
/**
* **DropdownMenuActions** - Multi-action dropdown menu
*
* Dropdown menu for multiple action options with icons and shortcuts.
* Supports destructive variants and keyboard shortcut hints.
*
* **Features:**
* - Configurable trigger icon with tooltip
* - Action items with icons and labels
* - Destructive variant styling
* - Keyboard shortcut display
* - Separator support between groups
*
* @example
* ```svelte
* <DropdownMenuActions
* triggerIcon={MoreHorizontal}
* triggerTooltip="More actions"
* actions={[
* { icon: Edit, label: 'Edit', onclick: handleEdit },
* { icon: Trash, label: 'Delete', onclick: handleDelete, variant: 'destructive' }
* ]}
* />
* ```
*/
export { default as DropdownMenuActions } from './DropdownMenuActions.svelte';

View File

@@ -8,6 +8,7 @@
import { serverStore, serverLoading } from '$lib/stores/server.svelte';
import { config, settingsStore } from '$lib/stores/settings.svelte';
import { fade, fly, scale } from 'svelte/transition';
+import { KeyboardKey } from '$lib/enums/keyboard';
interface Props {
class?: string;
@@ -117,7 +118,7 @@
}
function handleApiKeyKeydown(event: KeyboardEvent) {
-if (event.key === 'Enter') {
+if (event.key === KeyboardKey.ENTER) {
handleSaveApiKey();
}
}

View File

@@ -48,7 +48,7 @@
{model || 'Unknown Model'}
</Badge>
-{#if serverData.default_generation_settings.n_ctx}
+{#if serverData?.default_generation_settings?.n_ctx}
<Badge variant="secondary" class="text-xs">
ctx: {serverData.default_generation_settings.n_ctx.toLocaleString()}
</Badge>

View File

@@ -0,0 +1,80 @@
/**
*
* SERVER
*
* Components for displaying server connection state and handling
* connection errors. Integrates with serverStore for state management.
*
*/
/**
* **ServerStatus** - Server connection status indicator
*
* Compact status display showing connection state, model name,
* and context size. Used in headers and loading screens.
*
* **Architecture:**
* - Reads state from serverStore (props, loading, error)
* - Displays model name from modelsStore
*
* **Features:**
* - Status dot: green (connected), yellow (connecting), red (error), gray (unknown)
* - Status text label
* - Model name badge with icon
* - Context size badge
* - Optional error action button
*
* @example
* ```svelte
* <ServerStatus showActions />
* ```
*/
export { default as ServerStatus } from './ServerStatus.svelte';
/**
* **ServerErrorSplash** - Full-screen connection error display
*
* Blocking error screen shown when server connection fails.
* Provides retry options and API key input for authentication errors.
*
* **Architecture:**
* - Detects access denied errors for API key flow
* - Validates API key against server before saving
* - Integrates with settingsStore for API key persistence
*
* **Features:**
* - Error message display with icon
* - Retry connection button with loading state
* - API key input for authentication errors
* - API key validation with success/error feedback
* - Troubleshooting section with server start commands
* - Animated transitions for UI elements
*
* @example
* ```svelte
* <ServerErrorSplash
* error={serverError}
* onRetry={handleRetry}
* showTroubleshooting
* />
* ```
*/
export { default as ServerErrorSplash } from './ServerErrorSplash.svelte';
/**
* **ServerLoadingSplash** - Full-screen loading display
*
* Shown during initial server connection. Displays loading animation
* with ServerStatus component for real-time connection state.
*
* **Features:**
* - Animated server icon
* - Customizable loading message
* - Embedded ServerStatus for live updates
*
* @example
* ```svelte
* <ServerLoadingSplash message="Connecting to server..." />
* ```
*/
export { default as ServerLoadingSplash } from './ServerLoadingSplash.svelte';

View File

@@ -42,7 +42,7 @@
bind:this={ref}
data-slot="badge"
{href}
-class={cn(badgeVariants({ variant }), className)}
+class={cn(badgeVariants({ variant }), className, 'backdrop-blur-sm')}
{...restProps}
>
{@render children?.()}

View File

@@ -12,8 +12,9 @@
'bg-destructive shadow-xs hover:bg-destructive/90 focus-visible:ring-destructive/20 dark:focus-visible:ring-destructive/40 dark:bg-destructive/60 text-white',
outline:
'bg-background shadow-xs hover:bg-accent hover:text-accent-foreground dark:bg-input/30 dark:border-input dark:hover:bg-input/50 border',
-secondary: 'bg-secondary text-secondary-foreground shadow-xs hover:bg-secondary/80',
-ghost: 'hover:bg-accent hover:text-accent-foreground dark:hover:bg-accent/50',
+secondary:
+'dark:bg-secondary dark:text-secondary-foreground bg-background shadow-sm text-foreground hover:bg-muted-foreground/20',
+ghost: 'hover:text-accent-foreground hover:bg-muted-foreground/10',
link: 'text-primary underline-offset-4 hover:underline'
},
size: {

View File

@@ -1,6 +1,7 @@
<script lang="ts">
import type { HTMLAttributes } from 'svelte/elements';
import { cn, type WithElementRef } from '$lib/components/ui/utils';
+import { BOX_BORDER } from '$lib/constants/css-classes';
let {
ref = $bindable(null),
@@ -14,7 +15,8 @@
bind:this={ref}
data-slot="card"
class={cn(
-'flex flex-col gap-6 rounded-xl border bg-card py-6 text-card-foreground shadow-sm',
+'flex flex-col gap-6 rounded-xl bg-card py-6 text-card-foreground shadow-sm',
+BOX_BORDER,
className
)}
{...restProps}

View File

@@ -19,7 +19,7 @@
data-slot="dropdown-menu-content"
{sideOffset}
class={cn(
-'z-50 max-h-(--bits-dropdown-menu-content-available-height) min-w-[8rem] origin-(--bits-dropdown-menu-content-transform-origin) overflow-x-hidden overflow-y-auto rounded-md border border-border bg-popover p-1 text-popover-foreground shadow-md outline-none data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95 dark:border-border/20',
+'z-50 max-h-(--bits-dropdown-menu-content-available-height) min-w-[8rem] origin-(--bits-dropdown-menu-content-transform-origin) overflow-x-hidden overflow-y-auto rounded-md border border-border bg-popover p-1.5 text-popover-foreground shadow-md outline-none data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95 dark:border-border/20',
className
)}
{...restProps}

View File

@@ -44,6 +44,7 @@
'aria-invalid:border-destructive aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40',
className
)}
+style="backdrop-filter: blur(0.5rem);"
{type}
bind:value
{...restProps}

View File

@@ -1,6 +1,5 @@
<script lang="ts">
import { Button } from '$lib/components/ui/button/index.js';
-import { cn } from '$lib/components/ui/utils.js';
import PanelLeftIcon from '@lucide/svelte/icons/panel-left';
import type { ComponentProps } from 'svelte';
import { useSidebar } from './context.svelte.js';
@@ -22,7 +21,7 @@
data-slot="sidebar-trigger"
variant="ghost"
size="icon"
-class={cn('size-7', className)}
+class="rounded-full backdrop-blur-lg {className} h-9! w-9!"
type="button"
onclick={(e) => {
onclick?.(e);

View File

@@ -15,7 +15,7 @@
bind:checked
data-slot="switch"
class={cn(
-'peer inline-flex h-[1.15rem] w-8 shrink-0 items-center rounded-full border border-transparent shadow-xs transition-all outline-none focus-visible:border-ring focus-visible:ring-[3px] focus-visible:ring-ring/50 disabled:cursor-not-allowed disabled:opacity-50 data-[state=checked]:bg-primary data-[state=unchecked]:bg-input dark:data-[state=unchecked]:bg-input/80',
+'peer inline-flex h-[1.15rem] w-8 shrink-0 cursor-pointer items-center rounded-full border border-transparent shadow-xs transition-all outline-none focus-visible:border-ring focus-visible:ring-[3px] focus-visible:ring-ring/50 disabled:cursor-not-allowed disabled:opacity-50 data-[state=checked]:bg-primary data-[state=unchecked]:bg-input dark:data-[state=unchecked]:bg-input/80',
className
)}
{...restProps}

View File

@@ -9,22 +9,28 @@
side = 'top',
children,
arrowClasses,
+noPortal = false,
...restProps
}: TooltipPrimitive.ContentProps & {
arrowClasses?: string;
+noPortal?: boolean;
} = $props();
+const contentClass = $derived(
+cn(
+'z-50 w-fit origin-(--bits-tooltip-content-transform-origin) animate-in rounded-md bg-primary px-3 py-1.5 text-xs text-balance text-primary-foreground fade-in-0 zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95',
+className
+)
+);
</script>
-<TooltipPrimitive.Portal>
+{#snippet tooltipContent()}
<TooltipPrimitive.Content
bind:ref
data-slot="tooltip-content"
{sideOffset}
{side}
-class={cn(
-'z-50 w-fit origin-(--bits-tooltip-content-transform-origin) animate-in rounded-md bg-primary px-3 py-1.5 text-xs text-balance text-primary-foreground fade-in-0 zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95',
-className
-)}
+class={contentClass}
{...restProps}
>
{@render children?.()}
@@ -44,4 +50,12 @@
{/snippet}
</TooltipPrimitive.Arrow>
</TooltipPrimitive.Content>
-</TooltipPrimitive.Portal>
+{/snippet}
+{#if noPortal}
+{@render tooltipContent()}
+{:else}
+<TooltipPrimitive.Portal>
+{@render tooltipContent()}
+</TooltipPrimitive.Portal>
+{/if}

View File

@@ -1,9 +1,6 @@
export interface BinaryDetectionOptions {
/** Number of characters to check from the beginning of the file */
prefixLength: number;
-/** Maximum ratio of suspicious characters allowed (0.0 to 1.0) */
-suspiciousCharThresholdRatio: number;
-/** Maximum absolute number of null bytes allowed */
-maxAbsoluteNullBytes: number;
}

View File

@@ -0,0 +1,3 @@
export const INITIAL_FILE_SIZE = 0;
export const PROMPT_CONTENT_SEPARATOR = '\n\n';
export const CLIPBOARD_CONTENT_QUOTE_PREFIX = '"';

View File

@@ -0,0 +1,8 @@
export const CODE_BLOCK_SCROLL_CONTAINER_CLASS = 'code-block-scroll-container';
export const CODE_BLOCK_WRAPPER_CLASS = 'code-block-wrapper';
export const CODE_BLOCK_HEADER_CLASS = 'code-block-header';
export const CODE_BLOCK_ACTIONS_CLASS = 'code-block-actions';
export const CODE_LANGUAGE_CLASS = 'code-language';
export const COPY_CODE_BTN_CLASS = 'copy-code-btn';
export const PREVIEW_CODE_BTN_CLASS = 'preview-code-btn';
export const RELATIVE_CLASS = 'relative';

View File

@@ -0,0 +1,7 @@
export const NEWLINE = '\n';
export const DEFAULT_LANGUAGE = 'text';
export const LANG_PATTERN = /^(\w*)\n?/;
export const AMPERSAND_REGEX = /&/g;
export const LT_REGEX = /</g;
export const GT_REGEX = />/g;
export const FENCE_PATTERN = /^```|\n```/g;

View File

@@ -0,0 +1,10 @@
export const BOX_BORDER =
'border border-border/30 focus-within:border-border dark:border-border/20 dark:focus-within:border-border';
export const INPUT_CLASSES = `
bg-muted/60 dark:bg-muted/75
${BOX_BORDER}
shadow-sm
outline-none
text-foreground
`;

View File

@@ -0,0 +1,8 @@
export const MS_PER_SECOND = 1000;
export const SECONDS_PER_MINUTE = 60;
export const SECONDS_PER_HOUR = 3600;
export const SHORT_DURATION_THRESHOLD = 1;
export const MEDIUM_DURATION_THRESHOLD = 10;
/** Default display value when no performance time is available */
export const DEFAULT_PERFORMANCE_TIME = '0s';

View File

@@ -0,0 +1,4 @@
export const IMAGE_NOT_ERROR_BOUND_SELECTOR = 'img:not([data-error-bound])';
export const DATA_ERROR_BOUND_ATTR = 'errorBound';
export const DATA_ERROR_HANDLED_ATTR = 'errorHandled';
export const BOOL_TRUE_STRING = 'true';

View File

@@ -1 +1,8 @@
export const PROCESSING_INFO_TIMEOUT = 2000;
+/**
+* Statistics units labels
+*/
+export const STATS_UNITS = {
+TOKENS_PER_SECOND: 't/s'
+} as const;

View File

@@ -0,0 +1,33 @@
/**
* List of all numeric fields in settings configuration.
* These fields will be converted from strings to numbers during save.
*/
export const NUMERIC_FIELDS = [
'temperature',
'top_k',
'top_p',
'min_p',
'max_tokens',
'pasteLongTextToFileLen',
'dynatemp_range',
'dynatemp_exponent',
'typ_p',
'xtc_probability',
'xtc_threshold',
'repeat_last_n',
'repeat_penalty',
'presence_penalty',
'frequency_penalty',
'dry_multiplier',
'dry_base',
'dry_allowed_length',
'dry_penalty_last_n',
'agenticMaxTurns',
'agenticMaxToolPreviewLines'
] as const;
/**
* Fields that must be positive integers (>= 1).
* These will be clamped to minimum 1 and rounded during save.
*/
export const POSITIVE_INTEGER_FIELDS = ['agenticMaxTurns', 'agenticMaxToolPreviewLines'] as const;

View File

@@ -1 +1 @@
-export const TOOLTIP_DELAY_DURATION = 100;
+export const TOOLTIP_DELAY_DURATION = 500;

View File

@@ -0,0 +1 @@
export const SYSTEM_MESSAGE_PLACEHOLDER = 'System message';

View File

@@ -0,0 +1,34 @@
import { getContext, setContext } from 'svelte';
export interface ChatActionsContext {
copy: (message: DatabaseMessage) => void;
delete: (message: DatabaseMessage) => void;
navigateToSibling: (siblingId: string) => void;
editWithBranching: (
message: DatabaseMessage,
newContent: string,
newExtras?: DatabaseMessageExtra[]
) => void;
editWithReplacement: (
message: DatabaseMessage,
newContent: string,
shouldBranch: boolean
) => void;
editUserMessagePreserveResponses: (
message: DatabaseMessage,
newContent: string,
newExtras?: DatabaseMessageExtra[]
) => void;
regenerateWithBranching: (message: DatabaseMessage, modelOverride?: string) => void;
continueAssistantMessage: (message: DatabaseMessage) => void;
}
const CHAT_ACTIONS_KEY = Symbol.for('chat-actions');
export function setChatActionsContext(ctx: ChatActionsContext): ChatActionsContext {
return setContext(CHAT_ACTIONS_KEY, ctx);
}
export function getChatActionsContext(): ChatActionsContext {
return getContext(CHAT_ACTIONS_KEY);
}

View File

@@ -0,0 +1,13 @@
export {
getMessageEditContext,
setMessageEditContext,
type MessageEditContext,
type MessageEditState,
type MessageEditActions
} from './message-edit.context';
export {
getChatActionsContext,
setChatActionsContext,
type ChatActionsContext
} from './chat-actions.context';

View File

@@ -0,0 +1,39 @@
import { getContext, setContext } from 'svelte';
export interface MessageEditState {
readonly isEditing: boolean;
readonly editedContent: string;
readonly editedExtras: DatabaseMessageExtra[];
readonly editedUploadedFiles: ChatUploadedFile[];
readonly originalContent: string;
readonly originalExtras: DatabaseMessageExtra[];
readonly showSaveOnlyOption: boolean;
}
export interface MessageEditActions {
setContent: (content: string) => void;
setExtras: (extras: DatabaseMessageExtra[]) => void;
setUploadedFiles: (files: ChatUploadedFile[]) => void;
save: () => void;
saveOnly: () => void;
cancel: () => void;
startEdit: () => void;
}
export type MessageEditContext = MessageEditState & MessageEditActions;
const MESSAGE_EDIT_KEY = Symbol.for('chat-message-edit');
/**
* Sets the message edit context. Call this in the parent component (ChatMessage.svelte).
*/
export function setMessageEditContext(ctx: MessageEditContext): MessageEditContext {
return setContext(MESSAGE_EDIT_KEY, ctx);
}
/**
* Gets the message edit context. Call this in child components.
*/
export function getMessageEditContext(): MessageEditContext {
return getContext(MESSAGE_EDIT_KEY);
}

View File

@@ -1,4 +1,51 @@
export enum ChatMessageStatsView {
GENERATION = 'generation',
-READING = 'reading'
+READING = 'reading',
+TOOLS = 'tools',
+SUMMARY = 'summary'
}
/**
* Reasoning format options for API requests.
*/
export enum ReasoningFormat {
NONE = 'none',
AUTO = 'auto'
}
/**
* Message roles for chat messages.
*/
export enum MessageRole {
USER = 'user',
ASSISTANT = 'assistant',
SYSTEM = 'system',
TOOL = 'tool'
}
/**
* Message types for different content kinds.
*/
export enum MessageType {
ROOT = 'root',
TEXT = 'text',
THINK = 'think',
SYSTEM = 'system'
}
/**
* Content part types for API chat message content.
*/
export enum ContentPartType {
TEXT = 'text',
IMAGE_URL = 'image_url',
INPUT_AUDIO = 'input_audio'
}
/**
* Error dialog types for displaying server/timeout errors.
*/
export enum ErrorDialogType {
TIMEOUT = 'timeout',
SERVER = 'server'
}
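As an illustration of how these enums might be consumed when building an OpenAI-style content array (hypothetical helper, not part of this diff):

```typescript
// Subset of the ContentPartType enum from the file above.
enum ContentPartType {
  TEXT = 'text',
  IMAGE_URL = 'image_url'
}

// Assumed shape of an OpenAI-compatible content part; the real API types
// live elsewhere in the webui.
type ContentPart =
  | { type: ContentPartType.TEXT; text: string }
  | { type: ContentPartType.IMAGE_URL; image_url: { url: string } };

// Builds the multimodal content array for one user message.
function buildContent(text: string, imageUrl?: string): ContentPart[] {
  const parts: ContentPart[] = [{ type: ContentPartType.TEXT, text }];
  if (imageUrl) {
    parts.push({ type: ContentPartType.IMAGE_URL, image_url: { url: imageUrl } });
  }
  return parts;
}
```

Using enum members instead of string literals keeps the `'image_url'` spelling in one place, so a typo becomes a compile error rather than a silently ignored part.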


@@ -0,0 +1,15 @@
/**
* Keyboard key names for event handling
*/
export enum KeyboardKey {
ENTER = 'Enter',
ESCAPE = 'Escape',
ARROW_UP = 'ArrowUp',
ARROW_DOWN = 'ArrowDown',
TAB = 'Tab',
D_LOWER = 'd',
D_UPPER = 'D',
E_UPPER = 'E',
K_LOWER = 'k',
O_UPPER = 'O'
}
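A typical consumer of this enum is a keydown handler. The handler below is a hypothetical sketch (not from this diff) showing why named keys beat raw string comparisons:

```typescript
// Subset of the KeyboardKey enum from the file above.
enum KeyboardKey {
  ENTER = 'Enter',
  ESCAPE = 'Escape'
}

// Hypothetical: decide what a chat input would do for a given key event.
function actionForKey(event: { key: string; shiftKey: boolean }): 'send' | 'newline' | 'cancel' | 'none' {
  switch (event.key) {
    case KeyboardKey.ENTER:
      // Shift+Enter inserts a newline; plain Enter submits.
      return event.shiftKey ? 'newline' : 'send';
    case KeyboardKey.ESCAPE:
      return 'cancel';
    default:
      return 'none';
  }
}
```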


@@ -0,0 +1,26 @@
/**
* Parameter source - indicates whether a parameter uses default or custom value
*/
export enum ParameterSource {
DEFAULT = 'default',
CUSTOM = 'custom'
}
/**
* Syncable parameter type - data types for parameters that can be synced with server
*/
export enum SyncableParameterType {
NUMBER = 'number',
STRING = 'string',
BOOLEAN = 'boolean'
}
/**
* Settings field type - defines the input type for settings fields
*/
export enum SettingsFieldType {
INPUT = 'input',
TEXTAREA = 'textarea',
CHECKBOX = 'checkbox',
SELECT = 'select'
}


@@ -0,0 +1,165 @@
import { AUTO_SCROLL_AT_BOTTOM_THRESHOLD, AUTO_SCROLL_INTERVAL } from '$lib/constants/auto-scroll';
export interface AutoScrollOptions {
/** Whether auto-scroll is disabled globally (e.g., from settings) */
disabled?: boolean;
}
/**
* Creates an auto-scroll controller for a scrollable container.
*
* Features:
* - Auto-scrolls to bottom during streaming/loading
* - Stops auto-scroll when user manually scrolls up
* - Resumes auto-scroll when user scrolls back to bottom
*/
export class AutoScrollController {
private _autoScrollEnabled = $state(true);
private _userScrolledUp = $state(false);
private _lastScrollTop = $state(0);
private _scrollInterval: ReturnType<typeof setInterval> | undefined;
private _scrollTimeout: ReturnType<typeof setTimeout> | undefined;
private _container: HTMLElement | undefined;
private _disabled: boolean;
constructor(options: AutoScrollOptions = {}) {
this._disabled = options.disabled ?? false;
}
get autoScrollEnabled(): boolean {
return this._autoScrollEnabled;
}
get userScrolledUp(): boolean {
return this._userScrolledUp;
}
/**
* Binds the controller to a scrollable container element.
*/
setContainer(container: HTMLElement | undefined): void {
this._container = container;
}
/**
* Updates the disabled state.
*/
setDisabled(disabled: boolean): void {
this._disabled = disabled;
if (disabled) {
this._autoScrollEnabled = false;
this.stopInterval();
}
}
/**
* Handles scroll events to detect user scroll direction and toggle auto-scroll.
*/
handleScroll(): void {
if (this._disabled || !this._container) return;
const { scrollTop, scrollHeight, clientHeight } = this._container;
const distanceFromBottom = scrollHeight - scrollTop - clientHeight;
const isAtBottom = distanceFromBottom < AUTO_SCROLL_AT_BOTTOM_THRESHOLD;
if (scrollTop < this._lastScrollTop && !isAtBottom) {
this._userScrolledUp = true;
this._autoScrollEnabled = false;
} else if (isAtBottom && this._userScrolledUp) {
this._userScrolledUp = false;
this._autoScrollEnabled = true;
}
if (this._scrollTimeout) {
clearTimeout(this._scrollTimeout);
}
this._scrollTimeout = setTimeout(() => {
if (isAtBottom) {
this._userScrolledUp = false;
this._autoScrollEnabled = true;
}
}, AUTO_SCROLL_INTERVAL);
this._lastScrollTop = scrollTop;
}
/**
* Scrolls the container to the bottom.
*/
scrollToBottom(behavior: ScrollBehavior = 'smooth'): void {
if (this._disabled || !this._container) return;
this._container.scrollTo({
top: this._container.scrollHeight,
behavior
});
}
/**
* Enables auto-scroll (e.g., when user sends a message).
*/
enable(): void {
if (this._disabled) return;
this._userScrolledUp = false;
this._autoScrollEnabled = true;
}
/**
* Starts the auto-scroll interval for continuous scrolling during streaming.
*/
startInterval(): void {
if (this._disabled || this._scrollInterval) return;
this._scrollInterval = setInterval(() => {
this.scrollToBottom();
}, AUTO_SCROLL_INTERVAL);
}
/**
* Stops the auto-scroll interval.
*/
stopInterval(): void {
if (this._scrollInterval) {
clearInterval(this._scrollInterval);
this._scrollInterval = undefined;
}
}
/**
* Updates the auto-scroll interval based on streaming state.
* Call this in a $effect to automatically manage the interval.
*/
updateInterval(isStreaming: boolean): void {
if (this._disabled) {
this.stopInterval();
return;
}
if (isStreaming && this._autoScrollEnabled) {
if (!this._scrollInterval) {
this.startInterval();
}
} else {
this.stopInterval();
}
}
/**
* Cleans up resources. Call this in onDestroy or when the component unmounts.
*/
destroy(): void {
this.stopInterval();
if (this._scrollTimeout) {
clearTimeout(this._scrollTimeout);
this._scrollTimeout = undefined;
}
}
}
/**
* Creates a new AutoScrollController instance.
*/
export function createAutoScrollController(options: AutoScrollOptions = {}): AutoScrollController {
return new AutoScrollController(options);
}
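The core of `handleScroll()` is the "at bottom" check: the user counts as at the bottom while the remaining scroll distance is under a threshold, so small jitter near the bottom does not disable auto-scroll. A runnable sketch of that predicate (the threshold value here is an assumption; the real constant lives in `$lib/constants/auto-scroll`):

```typescript
// Assumed threshold in pixels; the real value is AUTO_SCROLL_AT_BOTTOM_THRESHOLD
// from $lib/constants/auto-scroll.
const AUTO_SCROLL_AT_BOTTOM_THRESHOLD = 100;

interface ScrollMetrics {
  scrollTop: number;
  scrollHeight: number;
  clientHeight: number;
}

// Mirrors the check in handleScroll(): distance from the bottom edge of the
// viewport to the end of the scrollable content.
function isAtBottom({ scrollTop, scrollHeight, clientHeight }: ScrollMetrics): boolean {
  return scrollHeight - scrollTop - clientHeight < AUTO_SCROLL_AT_BOTTOM_THRESHOLD;
}
```

Comparing against a threshold rather than exact equality matters because smooth scrolling and fractional pixel values rarely land on an exact `scrollHeight - clientHeight`.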


@@ -1,7 +1,9 @@
import { activeProcessingState } from '$lib/stores/chat.svelte';
import { config } from '$lib/stores/settings.svelte';
import { STATS_UNITS } from '$lib/constants/processing-info';
import type { ApiProcessingState } from '$lib/types';
interface LiveProcessingStats {
tokensProcessed: number;
totalTokens: number;
timeMs: number;
@@ -9,7 +11,7 @@ export interface LiveProcessingStats {
etaSecs?: number;
}
interface LiveGenerationStats {
tokensGenerated: number;
timeMs: number;
tokensPerSecond: number;
@@ -18,6 +20,7 @@ export interface LiveGenerationStats {
export interface UseProcessingStateReturn {
readonly processingState: ApiProcessingState | null;
getProcessingDetails(): string[];
getTechnicalDetails(): string[];
getProcessingMessage(): string;
getPromptProgressText(): string | null;
getLiveProcessingStats(): LiveProcessingStats | null;
@@ -138,8 +141,31 @@ export function useProcessingState(): UseProcessingStateReturn {
const details: string[] = [];
// Show prompt processing progress with ETA during preparation phase
if (stateToUse.promptProgress) {
const { processed, total, time_ms, cache } = stateToUse.promptProgress;
const actualProcessed = processed - cache;
const actualTotal = total - cache;
if (actualProcessed < actualTotal && actualProcessed > 0) {
const percent = Math.round((actualProcessed / actualTotal) * 100);
const eta = getETASecs(actualProcessed, actualTotal, time_ms);
if (eta !== undefined) {
const etaSecs = Math.ceil(eta);
details.push(`Processing ${percent}% (ETA: ${etaSecs}s)`);
} else {
details.push(`Processing ${percent}%`);
}
}
}
// Always show context info when we have valid data
if (
typeof stateToUse.contextTotal === 'number' &&
stateToUse.contextUsed >= 0 &&
stateToUse.contextTotal > 0
) {
const contextPercent = Math.round((stateToUse.contextUsed / stateToUse.contextTotal) * 100);
details.push(
@@ -163,7 +189,57 @@ export function useProcessingState(): UseProcessingStateReturn {
}
if (stateToUse.tokensPerSecond && stateToUse.tokensPerSecond > 0) {
details.push(`${stateToUse.tokensPerSecond.toFixed(1)} ${STATS_UNITS.TOKENS_PER_SECOND}`);
}
if (stateToUse.speculative) {
details.push('Speculative decoding enabled');
}
return details;
}
/**
* Returns technical details without the progress message (for bottom bar)
*/
function getTechnicalDetails(): string[] {
const stateToUse = processingState || lastKnownState;
if (!stateToUse) {
return [];
}
const details: string[] = [];
// Always show context info when we have valid data
if (
typeof stateToUse.contextTotal === 'number' &&
stateToUse.contextUsed >= 0 &&
stateToUse.contextTotal > 0
) {
const contextPercent = Math.round((stateToUse.contextUsed / stateToUse.contextTotal) * 100);
details.push(
`Context: ${stateToUse.contextUsed}/${stateToUse.contextTotal} (${contextPercent}%)`
);
}
if (stateToUse.outputTokensUsed > 0) {
// Handle infinite max_tokens (-1) case
if (stateToUse.outputTokensMax <= 0) {
details.push(`Output: ${stateToUse.outputTokensUsed}/∞`);
} else {
const outputPercent = Math.round(
(stateToUse.outputTokensUsed / stateToUse.outputTokensMax) * 100
);
details.push(
`Output: ${stateToUse.outputTokensUsed}/${stateToUse.outputTokensMax} (${outputPercent}%)`
);
}
}
if (stateToUse.tokensPerSecond && stateToUse.tokensPerSecond > 0) {
details.push(`${stateToUse.tokensPerSecond.toFixed(1)} ${STATS_UNITS.TOKENS_PER_SECOND}`);
}
if (stateToUse.speculative) {
@@ -251,6 +327,7 @@ export function useProcessingState(): UseProcessingStateReturn {
return processingState;
},
getProcessingDetails,
getTechnicalDetails,
getProcessingMessage,
getPromptProgressText,
getLiveProcessingStats,
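The prompt-progress branch above subtracts cached tokens before computing the percentage and ETA, so reused KV-cache prefixes do not inflate apparent progress. A runnable sketch of that math (`getETASecs` is not shown in this diff, so the linear extrapolation here is an assumption):

```typescript
// Assumed linear extrapolation: time per token so far predicts the remainder.
function etaSecs(processed: number, total: number, timeMs: number): number | undefined {
  if (processed <= 0 || timeMs <= 0) return undefined;
  const msPerToken = timeMs / processed;
  return ((total - processed) * msPerToken) / 1000;
}

// Mirrors the promptProgress handling in getProcessingDetails(): cached tokens
// are excluded from both numerator and denominator.
function progressLine(processed: number, total: number, cache: number, timeMs: number): string | null {
  const actualProcessed = processed - cache;
  const actualTotal = total - cache;
  if (actualProcessed <= 0 || actualProcessed >= actualTotal) return null;
  const percent = Math.round((actualProcessed / actualTotal) * 100);
  const eta = etaSecs(actualProcessed, actualTotal, timeMs);
  return eta !== undefined
    ? `Processing ${percent}% (ETA: ${Math.ceil(eta)}s)`
    : `Processing ${percent}%`;
}
```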


@@ -13,6 +13,16 @@
import type { Plugin } from 'unified';
import type { Root, Element, ElementContent } from 'hast';
import { visit } from 'unist-util-visit';
import {
CODE_BLOCK_SCROLL_CONTAINER_CLASS,
CODE_BLOCK_WRAPPER_CLASS,
CODE_BLOCK_HEADER_CLASS,
CODE_BLOCK_ACTIONS_CLASS,
CODE_LANGUAGE_CLASS,
COPY_CODE_BTN_CLASS,
PREVIEW_CODE_BTN_CLASS,
RELATIVE_CLASS
} from '$lib/constants/code-blocks';
declare global {
interface Window {
@@ -42,7 +52,7 @@ function createCopyButton(codeId: string): Element {
type: 'element',
tagName: 'button',
properties: {
className: [COPY_CODE_BTN_CLASS],
'data-code-id': codeId,
title: 'Copy code',
type: 'button'
@@ -56,7 +66,7 @@ function createPreviewButton(codeId: string): Element {
type: 'element',
tagName: 'button',
properties: {
className: [PREVIEW_CODE_BTN_CLASS],
'data-code-id': codeId,
title: 'Preview code',
type: 'button'
@@ -75,30 +85,39 @@ function createHeader(language: string, codeId: string): Element {
return {
type: 'element',
tagName: 'div',
properties: { className: [CODE_BLOCK_HEADER_CLASS] },
children: [
{
type: 'element',
tagName: 'span',
properties: { className: [CODE_LANGUAGE_CLASS] },
children: [{ type: 'text', value: language }]
},
{
type: 'element',
tagName: 'div',
properties: { className: [CODE_BLOCK_ACTIONS_CLASS] },
children: actions
}
]
};
}
function createScrollContainer(preElement: Element): Element {
return {
type: 'element',
tagName: 'div',
properties: { className: [CODE_BLOCK_SCROLL_CONTAINER_CLASS] },
children: [preElement]
};
}
function createWrapper(header: Element, preElement: Element): Element {
return {
type: 'element',
tagName: 'div',
properties: { className: [CODE_BLOCK_WRAPPER_CLASS, RELATIVE_CLASS] },
children: [header, createScrollContainer(preElement)]
};
}
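The change above nests the `<pre>` inside a dedicated scroll container instead of placing it directly under the wrapper, presumably so the header can stay put while only the code scrolls. A plain-object sketch of the resulting nesting (the class-name string values here are assumptions; the real constants live in `$lib/constants/code-blocks`):

```typescript
// Minimal element shape standing in for hast's Element type.
interface El {
  tagName: string;
  className: string[];
  children: El[];
}

// Assumed class-name values; the real ones come from $lib/constants/code-blocks.
const CODE_BLOCK_SCROLL_CONTAINER_CLASS = 'code-block-scroll-container';
const CODE_BLOCK_WRAPPER_CLASS = 'code-block-wrapper';
const RELATIVE_CLASS = 'relative';

// Mirrors createScrollContainer + createWrapper from the file above:
// wrapper > [header, scroll container > pre].
function wrap(header: El, pre: El): El {
  const scrollContainer: El = {
    tagName: 'div',
    className: [CODE_BLOCK_SCROLL_CONTAINER_CLASS],
    children: [pre]
  };
  return {
    tagName: 'div',
    className: [CODE_BLOCK_WRAPPER_CLASS, RELATIVE_CLASS],
    children: [header, scrollContainer]
  };
}
```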


@@ -0,0 +1,368 @@
import Dexie, { type EntityTable } from 'dexie';
import { findDescendantMessages } from '$lib/utils';
class LlamacppDatabase extends Dexie {
conversations!: EntityTable<DatabaseConversation, string>;
messages!: EntityTable<DatabaseMessage, string>;
constructor() {
super('LlamacppWebui');
this.version(1).stores({
conversations: 'id, lastModified, currNode, name',
messages: 'id, convId, type, role, timestamp, parent, children'
});
}
}
import { v4 as uuid } from 'uuid';
import { MessageRole } from '$lib/enums/chat';

const db = new LlamacppDatabase();
export class DatabaseService {
/**
*
*
* Conversations
*
*
*/
/**
* Creates a new conversation.
*
* @param name - Name of the conversation
* @returns The created conversation
*/
static async createConversation(name: string): Promise<DatabaseConversation> {
const conversation: DatabaseConversation = {
id: uuid(),
name,
lastModified: Date.now(),
currNode: ''
};
await db.conversations.add(conversation);
return conversation;
}
/**
*
*
* Messages
*
*
*/
/**
* Creates a new message branch by adding a message and updating parent/child relationships.
* Also updates the conversation's currNode to point to the new message.
*
* @param message - Message to add (without id)
* @param parentId - Parent message ID to attach to
* @returns The created message
*/
static async createMessageBranch(
message: Omit<DatabaseMessage, 'id'>,
parentId: string | null
): Promise<DatabaseMessage> {
return await db.transaction('rw', [db.conversations, db.messages], async () => {
// Handle null parent (root message case)
if (parentId !== null) {
const parentMessage = await db.messages.get(parentId);
if (!parentMessage) {
throw new Error(`Parent message ${parentId} not found`);
}
}
const newMessage: DatabaseMessage = {
...message,
id: uuid(),
parent: parentId,
toolCalls: message.toolCalls ?? '',
children: []
};
await db.messages.add(newMessage);
// Update parent's children array if parent exists
if (parentId !== null) {
const parentMessage = await db.messages.get(parentId);
if (parentMessage) {
await db.messages.update(parentId, {
children: [...parentMessage.children, newMessage.id]
});
}
}
await this.updateConversation(message.convId, {
currNode: newMessage.id
});
return newMessage;
});
}
/**
* Creates a root message for a new conversation.
* Root messages are not displayed but serve as the tree root for branching.
*
* @param convId - Conversation ID
* @returns The created root message
*/
static async createRootMessage(convId: string): Promise<string> {
const rootMessage: DatabaseMessage = {
id: uuid(),
convId,
type: 'root',
timestamp: Date.now(),
role: MessageRole.SYSTEM,
content: '',
parent: null,
toolCalls: '',
children: []
};
await db.messages.add(rootMessage);
return rootMessage.id;
}
/**
* Creates a system prompt message for a conversation.
*
* @param convId - Conversation ID
* @param systemPrompt - The system prompt content (must be non-empty)
* @param parentId - Parent message ID (typically the root message)
* @returns The created system message
* @throws Error if systemPrompt is empty
*/
static async createSystemMessage(
convId: string,
systemPrompt: string,
parentId: string
): Promise<DatabaseMessage> {
const trimmedPrompt = systemPrompt.trim();
if (!trimmedPrompt) {
throw new Error('Cannot create system message with empty content');
}
const systemMessage: DatabaseMessage = {
id: uuid(),
convId,
type: MessageRole.SYSTEM,
timestamp: Date.now(),
role: MessageRole.SYSTEM,
content: trimmedPrompt,
parent: parentId,
toolCalls: '',
children: []
};
await db.messages.add(systemMessage);
const parentMessage = await db.messages.get(parentId);
if (parentMessage) {
await db.messages.update(parentId, {
children: [...parentMessage.children, systemMessage.id]
});
}
return systemMessage;
}
/**
* Deletes a conversation and all its messages.
*
* @param id - Conversation ID
*/
static async deleteConversation(id: string): Promise<void> {
await db.transaction('rw', [db.conversations, db.messages], async () => {
await db.conversations.delete(id);
await db.messages.where('convId').equals(id).delete();
});
}
/**
* Deletes a message and removes it from its parent's children array.
*
* @param messageId - ID of the message to delete
*/
static async deleteMessage(messageId: string): Promise<void> {
await db.transaction('rw', db.messages, async () => {
const message = await db.messages.get(messageId);
if (!message) return;
// Remove this message from its parent's children array
if (message.parent) {
const parent = await db.messages.get(message.parent);
if (parent) {
parent.children = parent.children.filter((childId: string) => childId !== messageId);
await db.messages.put(parent);
}
}
// Delete the message
await db.messages.delete(messageId);
});
}
/**
* Deletes a message and all its descendant messages (cascading deletion).
* This removes the entire branch starting from the specified message.
*
* @param conversationId - ID of the conversation containing the message
* @param messageId - ID of the root message to delete (along with all descendants)
* @returns Array of all deleted message IDs
*/
static async deleteMessageCascading(
conversationId: string,
messageId: string
): Promise<string[]> {
return await db.transaction('rw', db.messages, async () => {
// Get all messages in the conversation to find descendants
const allMessages = await db.messages.where('convId').equals(conversationId).toArray();
// Find all descendant messages
const descendants = findDescendantMessages(allMessages, messageId);
const allToDelete = [messageId, ...descendants];
// Get the message to delete for parent cleanup
const message = await db.messages.get(messageId);
if (message && message.parent) {
const parent = await db.messages.get(message.parent);
if (parent) {
parent.children = parent.children.filter((childId: string) => childId !== messageId);
await db.messages.put(parent);
}
}
// Delete all messages in the branch
await db.messages.bulkDelete(allToDelete);
return allToDelete;
});
}
/**
* Gets all conversations, sorted by last modified time (newest first).
*
* @returns Array of conversations
*/
static async getAllConversations(): Promise<DatabaseConversation[]> {
return await db.conversations.orderBy('lastModified').reverse().toArray();
}
/**
* Gets a conversation by ID.
*
* @param id - Conversation ID
* @returns The conversation if found, otherwise undefined
*/
static async getConversation(id: string): Promise<DatabaseConversation | undefined> {
return await db.conversations.get(id);
}
/**
* Gets all messages in a conversation, sorted by timestamp (oldest first).
*
* @param convId - Conversation ID
* @returns Array of messages in the conversation
*/
static async getConversationMessages(convId: string): Promise<DatabaseMessage[]> {
return await db.messages.where('convId').equals(convId).sortBy('timestamp');
}
/**
* Updates a conversation.
*
* @param id - Conversation ID
* @param updates - Partial updates to apply
* @returns Promise that resolves when the conversation is updated
*/
static async updateConversation(
id: string,
updates: Partial<Omit<DatabaseConversation, 'id'>>
): Promise<void> {
await db.conversations.update(id, {
...updates,
lastModified: Date.now()
});
}
/**
*
*
* Navigation
*
*
*/
/**
* Updates the conversation's current node (active branch).
* This determines which conversation path is currently being viewed.
*
* @param convId - Conversation ID
* @param nodeId - Message ID to set as current node
*/
static async updateCurrentNode(convId: string, nodeId: string): Promise<void> {
await this.updateConversation(convId, {
currNode: nodeId
});
}
/**
* Updates a message.
*
* @param id - Message ID
* @param updates - Partial updates to apply
* @returns Promise that resolves when the message is updated
*/
static async updateMessage(
id: string,
updates: Partial<Omit<DatabaseMessage, 'id'>>
): Promise<void> {
await db.messages.update(id, updates);
}
/**
*
*
* Import
*
*
*/
/**
* Imports multiple conversations and their messages.
* Skips conversations that already exist.
*
* @param data - Array of { conv, messages } objects
*/
static async importConversations(
data: { conv: DatabaseConversation; messages: DatabaseMessage[] }[]
): Promise<{ imported: number; skipped: number }> {
let importedCount = 0;
let skippedCount = 0;
return await db.transaction('rw', [db.conversations, db.messages], async () => {
for (const item of data) {
const { conv, messages } = item;
const existing = await db.conversations.get(conv.id);
if (existing) {
console.warn(`Conversation "${conv.name}" already exists, skipping...`);
skippedCount++;
continue;
}
await db.conversations.add(conv);
for (const msg of messages) {
await db.messages.put(msg);
}
importedCount++;
}
return { imported: importedCount, skipped: skippedCount };
});
}
}
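`deleteMessageCascading` depends on `findDescendantMessages` (imported from `$lib/utils` and not shown in this diff) to collect the whole branch. A sketch of that traversal under the assumption that it walks the `children` arrays transitively:

```typescript
// Minimal message shape: each node lists its child ids.
interface Msg {
  id: string;
  children: string[];
}

// Breadth-first walk over the children arrays, collecting every id reachable
// from rootId (excluding rootId itself, which the caller adds separately).
function findDescendants(all: Msg[], rootId: string): string[] {
  const byId = new Map(all.map((m) => [m.id, m]));
  const result: string[] = [];
  const queue = [...(byId.get(rootId)?.children ?? [])];
  while (queue.length > 0) {
    const id = queue.shift()!;
    result.push(id);
    queue.push(...(byId.get(id)?.children ?? []));
  }
  return result;
}
```

Note the caller prepends the root itself (`[messageId, ...descendants]`) before `bulkDelete`, matching the contract assumed here.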


@@ -0,0 +1,99 @@
import { ServerModelStatus } from '$lib/enums';
import { apiFetch, apiPost } from '$lib/utils/api-fetch';
export class ModelsService {
/**
*
*
* Listing
*
*
*/
/**
* Fetch list of models from OpenAI-compatible endpoint.
* Works in both MODEL and ROUTER modes.
*
* @returns List of available models with basic metadata
*/
static async list(): Promise<ApiModelListResponse> {
return apiFetch<ApiModelListResponse>('/v1/models');
}
/**
* Fetch list of all models with detailed metadata (ROUTER mode).
* Returns models with load status, paths, and other metadata
* beyond what the OpenAI-compatible endpoint provides.
*
* @returns List of models with detailed status and configuration info
*/
static async listRouter(): Promise<ApiRouterModelsListResponse> {
return apiFetch<ApiRouterModelsListResponse>('/v1/models');
}
/**
*
*
* Load/Unload
*
*
*/
/**
* Load a model (ROUTER mode only).
* Sends POST request to `/models/load`. Note: the endpoint returns success
* before loading completes — use polling to await actual load status.
*
* @param modelId - Model identifier to load
* @param extraArgs - Optional additional arguments to pass to the model instance
* @returns Load response from the server
*/
static async load(modelId: string, extraArgs?: string[]): Promise<ApiRouterModelsLoadResponse> {
const payload: { model: string; extra_args?: string[] } = { model: modelId };
if (extraArgs && extraArgs.length > 0) {
payload.extra_args = extraArgs;
}
return apiPost<ApiRouterModelsLoadResponse>('/models/load', payload);
}
/**
* Unload a model (ROUTER mode only).
* Sends POST request to `/models/unload`. Note: the endpoint returns success
* before unloading completes — use polling to await actual unload status.
*
* @param modelId - Model identifier to unload
* @returns Unload response from the server
*/
static async unload(modelId: string): Promise<ApiRouterModelsUnloadResponse> {
return apiPost<ApiRouterModelsUnloadResponse>('/models/unload', { model: modelId });
}
/**
*
*
* Status
*
*
*/
/**
* Check if a model is loaded based on its metadata.
*
* @param model - Model data entry from the API response
* @returns True if the model status is LOADED
*/
static isModelLoaded(model: ApiModelDataEntry): boolean {
return model.status.value === ServerModelStatus.LOADED;
}
/**
* Check if a model is currently loading.
*
* @param model - Model data entry from the API response
* @returns True if the model status is LOADING
*/
static isModelLoading(model: ApiModelDataEntry): boolean {
return model.status.value === ServerModelStatus.LOADING;
}
}
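Since `/models/load` and `/models/unload` return before the operation finishes, callers are expected to poll until the status flips. A hypothetical polling helper (not part of `ModelsService`; names and defaults are assumptions):

```typescript
// Polls an injected status getter until it reports the target status or the
// attempt budget runs out. Returns true on success, false on timeout.
async function waitForStatus(
  getStatus: () => Promise<string>,
  target: string,
  { intervalMs = 50, maxAttempts = 20 } = {}
): Promise<boolean> {
  for (let i = 0; i < maxAttempts; i++) {
    if ((await getStatus()) === target) return true;
    await new Promise((resolve) => setTimeout(resolve, intervalMs));
  }
  return false; // timed out; the model may still be loading
}
```

In practice `getStatus` would call `ModelsService.listRouter()` and check the target model's entry with `isModelLoaded`.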


@@ -0,0 +1,148 @@
import { describe, it, expect } from 'vitest';
import { ParameterSyncService } from './parameter-sync.service';
describe('ParameterSyncService', () => {
describe('roundFloatingPoint', () => {
it('should fix JavaScript floating-point precision issues', () => {
// Test the specific values from the screenshot
const mockServerParams = {
top_p: 0.949999988079071,
min_p: 0.009999999776482582,
temperature: 0.800000011920929,
top_k: 40,
samplers: ['top_k', 'typ_p', 'top_p', 'min_p', 'temperature']
};
const result = ParameterSyncService.extractServerDefaults({
...mockServerParams,
// Add other required fields to match the API type
n_predict: 512,
seed: -1,
dynatemp_range: 0.0,
dynatemp_exponent: 1.0,
xtc_probability: 0.0,
xtc_threshold: 0.1,
typ_p: 1.0,
repeat_last_n: 64,
repeat_penalty: 1.0,
presence_penalty: 0.0,
frequency_penalty: 0.0,
dry_multiplier: 0.0,
dry_base: 1.75,
dry_allowed_length: 2,
dry_penalty_last_n: -1,
mirostat: 0,
mirostat_tau: 5.0,
mirostat_eta: 0.1,
stop: [],
max_tokens: -1,
n_keep: 0,
n_discard: 0,
ignore_eos: false,
stream: true,
logit_bias: [],
n_probs: 0,
min_keep: 0,
grammar: '',
grammar_lazy: false,
grammar_triggers: [],
preserved_tokens: [],
chat_format: '',
reasoning_format: '',
reasoning_in_content: false,
thinking_forced_open: false,
'speculative.n_max': 0,
'speculative.n_min': 0,
'speculative.p_min': 0.0,
timings_per_token: false,
post_sampling_probs: false,
lora: [],
top_n_sigma: 0.0,
dry_sequence_breakers: []
} as ApiLlamaCppServerProps['default_generation_settings']['params']);
// Check that the problematic floating-point values are rounded correctly
expect(result.top_p).toBe(0.95);
expect(result.min_p).toBe(0.01);
expect(result.temperature).toBe(0.8);
expect(result.top_k).toBe(40); // Integer should remain unchanged
expect(result.samplers).toBe('top_k;typ_p;top_p;min_p;temperature');
});
it('should preserve non-numeric values', () => {
const mockServerParams = {
samplers: ['top_k', 'temperature'],
max_tokens: -1,
temperature: 0.7
};
const result = ParameterSyncService.extractServerDefaults({
...mockServerParams,
// Minimal required fields
n_predict: 512,
seed: -1,
dynatemp_range: 0.0,
dynatemp_exponent: 1.0,
top_k: 40,
top_p: 0.95,
min_p: 0.05,
xtc_probability: 0.0,
xtc_threshold: 0.1,
typ_p: 1.0,
repeat_last_n: 64,
repeat_penalty: 1.0,
presence_penalty: 0.0,
frequency_penalty: 0.0,
dry_multiplier: 0.0,
dry_base: 1.75,
dry_allowed_length: 2,
dry_penalty_last_n: -1,
mirostat: 0,
mirostat_tau: 5.0,
mirostat_eta: 0.1,
stop: [],
n_keep: 0,
n_discard: 0,
ignore_eos: false,
stream: true,
logit_bias: [],
n_probs: 0,
min_keep: 0,
grammar: '',
grammar_lazy: false,
grammar_triggers: [],
preserved_tokens: [],
chat_format: '',
reasoning_format: '',
reasoning_in_content: false,
thinking_forced_open: false,
'speculative.n_max': 0,
'speculative.n_min': 0,
'speculative.p_min': 0.0,
timings_per_token: false,
post_sampling_probs: false,
lora: [],
top_n_sigma: 0.0,
dry_sequence_breakers: []
} as ApiLlamaCppServerProps['default_generation_settings']['params']);
expect(result.samplers).toBe('top_k;temperature');
expect(result.max_tokens).toBe(-1);
expect(result.temperature).toBe(0.7);
});
it('should merge webui settings from props when provided', () => {
const result = ParameterSyncService.extractServerDefaults(null, {
pasteLongTextToFileLen: 0,
pdfAsImage: true,
renderUserContentAsMarkdown: false,
theme: 'dark'
});
expect(result.pasteLongTextToFileLen).toBe(0);
expect(result.pdfAsImage).toBe(true);
expect(result.renderUserContentAsMarkdown).toBe(false);
expect(result.theme).toBeUndefined();
});
});
});
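The tests above check that `0.949999988079071` comes back as `0.95`. A sketch of the precision cleanup they exercise (`normalizeFloatingPoint` itself lives in `$lib/utils`; the round-to-significant-digits approach and the digit count are assumptions):

```typescript
// Assumed approach: round non-integers to 7 significant digits, which is
// enough to strip float32-to-float64 conversion noise like 0.800000011920929.
function normalizeFloat(value: number): number {
  if (!Number.isFinite(value) || Number.isInteger(value)) return value;
  return Number.parseFloat(value.toPrecision(7));
}
```

Values such as `0.949999988079071` arise because the server stores parameters as 32-bit floats; widening them to JavaScript's 64-bit doubles exposes the representation error.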


@@ -0,0 +1,400 @@
import { normalizeFloatingPoint } from '$lib/utils';
import { SyncableParameterType, ParameterSource } from '$lib/enums/settings';
type ParameterValue = string | number | boolean;
type ParameterRecord = Record<string, ParameterValue>;
interface ParameterInfo {
value: string | number | boolean;
source: ParameterSource;
serverDefault?: string | number | boolean;
userOverride?: string | number | boolean;
}
interface SyncableParameter {
key: string;
serverKey: string;
type: SyncableParameterType;
canSync: boolean;
}
/**
* Mapping of webui setting keys to server parameter keys.
* Only parameters listed here can be synced from the server `/props` endpoint.
* Each entry defines the webui key, corresponding server key, value type,
* and whether sync is enabled.
*/
export const SYNCABLE_PARAMETERS: SyncableParameter[] = [
{
key: 'temperature',
serverKey: 'temperature',
type: SyncableParameterType.NUMBER,
canSync: true
},
{ key: 'top_k', serverKey: 'top_k', type: SyncableParameterType.NUMBER, canSync: true },
{ key: 'top_p', serverKey: 'top_p', type: SyncableParameterType.NUMBER, canSync: true },
{ key: 'min_p', serverKey: 'min_p', type: SyncableParameterType.NUMBER, canSync: true },
{
key: 'dynatemp_range',
serverKey: 'dynatemp_range',
type: SyncableParameterType.NUMBER,
canSync: true
},
{
key: 'dynatemp_exponent',
serverKey: 'dynatemp_exponent',
type: SyncableParameterType.NUMBER,
canSync: true
},
{
key: 'xtc_probability',
serverKey: 'xtc_probability',
type: SyncableParameterType.NUMBER,
canSync: true
},
{
key: 'xtc_threshold',
serverKey: 'xtc_threshold',
type: SyncableParameterType.NUMBER,
canSync: true
},
{ key: 'typ_p', serverKey: 'typ_p', type: SyncableParameterType.NUMBER, canSync: true },
{
key: 'repeat_last_n',
serverKey: 'repeat_last_n',
type: SyncableParameterType.NUMBER,
canSync: true
},
{
key: 'repeat_penalty',
serverKey: 'repeat_penalty',
type: SyncableParameterType.NUMBER,
canSync: true
},
{
key: 'presence_penalty',
serverKey: 'presence_penalty',
type: SyncableParameterType.NUMBER,
canSync: true
},
{
key: 'frequency_penalty',
serverKey: 'frequency_penalty',
type: SyncableParameterType.NUMBER,
canSync: true
},
{
key: 'dry_multiplier',
serverKey: 'dry_multiplier',
type: SyncableParameterType.NUMBER,
canSync: true
},
{ key: 'dry_base', serverKey: 'dry_base', type: SyncableParameterType.NUMBER, canSync: true },
{
key: 'dry_allowed_length',
serverKey: 'dry_allowed_length',
type: SyncableParameterType.NUMBER,
canSync: true
},
{
key: 'dry_penalty_last_n',
serverKey: 'dry_penalty_last_n',
type: SyncableParameterType.NUMBER,
canSync: true
},
{ key: 'max_tokens', serverKey: 'max_tokens', type: SyncableParameterType.NUMBER, canSync: true },
{ key: 'samplers', serverKey: 'samplers', type: SyncableParameterType.STRING, canSync: true },
{
key: 'pasteLongTextToFileLen',
serverKey: 'pasteLongTextToFileLen',
type: SyncableParameterType.NUMBER,
canSync: true
},
{
key: 'pdfAsImage',
serverKey: 'pdfAsImage',
type: SyncableParameterType.BOOLEAN,
canSync: true
},
{
key: 'showThoughtInProgress',
serverKey: 'showThoughtInProgress',
type: SyncableParameterType.BOOLEAN,
canSync: true
},
{
key: 'keepStatsVisible',
serverKey: 'keepStatsVisible',
type: SyncableParameterType.BOOLEAN,
canSync: true
},
{
key: 'showMessageStats',
serverKey: 'showMessageStats',
type: SyncableParameterType.BOOLEAN,
canSync: true
},
{
key: 'askForTitleConfirmation',
serverKey: 'askForTitleConfirmation',
type: SyncableParameterType.BOOLEAN,
canSync: true
},
{
key: 'disableAutoScroll',
serverKey: 'disableAutoScroll',
type: SyncableParameterType.BOOLEAN,
canSync: true
},
{
key: 'renderUserContentAsMarkdown',
serverKey: 'renderUserContentAsMarkdown',
type: SyncableParameterType.BOOLEAN,
canSync: true
},
{
key: 'autoMicOnEmpty',
serverKey: 'autoMicOnEmpty',
type: SyncableParameterType.BOOLEAN,
canSync: true
},
{
key: 'pyInterpreterEnabled',
serverKey: 'pyInterpreterEnabled',
type: SyncableParameterType.BOOLEAN,
canSync: true
},
{
key: 'enableContinueGeneration',
serverKey: 'enableContinueGeneration',
type: SyncableParameterType.BOOLEAN,
canSync: true
}
];
export class ParameterSyncService {
/**
*
*
* Extraction
*
*
*/
/**
* Round floating-point numbers to avoid JavaScript precision issues.
* E.g., 0.1 + 0.2 = 0.30000000000000004 → 0.3
*
* @param value - Parameter value to normalize
* @returns Precision-normalized value
*/
private static roundFloatingPoint(value: ParameterValue): ParameterValue {
return normalizeFloatingPoint(value) as ParameterValue;
}
/**
* Extract server default parameters that can be synced from `/props` response.
* Handles both generation settings parameters and webui-specific settings.
* Converts samplers array to semicolon-delimited string for UI display.
*
* @param serverParams - Raw generation settings from server `/props` endpoint
* @param webuiSettings - Optional webui-specific settings from server
* @returns Record of extracted parameter key-value pairs with normalized precision
*/
static extractServerDefaults(
serverParams: ApiLlamaCppServerProps['default_generation_settings']['params'] | null,
webuiSettings?: Record<string, string | number | boolean>
): ParameterRecord {
const extracted: ParameterRecord = {};
if (serverParams) {
for (const param of SYNCABLE_PARAMETERS) {
if (param.canSync && param.serverKey in serverParams) {
const value = (serverParams as unknown as Record<string, ParameterValue>)[
param.serverKey
];
if (value !== undefined) {
// Apply precision rounding to avoid JavaScript floating-point issues
extracted[param.key] = this.roundFloatingPoint(value);
}
}
}
// Handle samplers array conversion to string
if (serverParams.samplers && Array.isArray(serverParams.samplers)) {
extracted.samplers = serverParams.samplers.join(';');
}
}
if (webuiSettings) {
for (const param of SYNCABLE_PARAMETERS) {
if (param.canSync && param.serverKey in webuiSettings) {
const value = webuiSettings[param.serverKey];
if (value !== undefined) {
extracted[param.key] = this.roundFloatingPoint(value);
}
}
}
}
return extracted;
}
/**
*
*
* Merging
*
*
*/
/**
* Merge server defaults with current user settings.
 * User overrides always take priority; only parameters not in the `userOverrides`
 * set are updated from server defaults.
*
* @param currentSettings - Current parameter values in the settings store
* @param serverDefaults - Default values extracted from server props
* @param userOverrides - Set of parameter keys explicitly overridden by the user
* @returns Merged parameter record with user overrides preserved
*/
static mergeWithServerDefaults(
currentSettings: ParameterRecord,
serverDefaults: ParameterRecord,
userOverrides: Set<string> = new Set()
): ParameterRecord {
const merged = { ...currentSettings };
for (const [key, serverValue] of Object.entries(serverDefaults)) {
// Only update if user hasn't explicitly overridden this parameter
if (!userOverrides.has(key)) {
merged[key] = this.roundFloatingPoint(serverValue);
}
}
return merged;
}
/**
*
*
* Info
*
*
*/
/**
* Get parameter information including source and values.
* Used by ChatSettingsParameterSourceIndicator to display the correct badge
* (Custom vs Default) for each parameter in the settings UI.
*
* @param key - The parameter key to get info for
* @param currentValue - The current value of the parameter
* @param propsDefaults - Server default values from `/props`
* @param userOverrides - Set of parameter keys explicitly overridden by the user
* @returns Parameter info with source, server default, and user override values
*/
static getParameterInfo(
key: string,
currentValue: ParameterValue,
propsDefaults: ParameterRecord,
userOverrides: Set<string>
): ParameterInfo {
const hasPropsDefault = propsDefaults[key] !== undefined;
const isUserOverride = userOverrides.has(key);
// Simple logic: either using default (from props) or custom (user override)
const source = isUserOverride ? ParameterSource.CUSTOM : ParameterSource.DEFAULT;
return {
value: currentValue,
source,
serverDefault: hasPropsDefault ? propsDefaults[key] : undefined, // Keep same field name for compatibility
userOverride: isUserOverride ? currentValue : undefined
};
}
/**
* Check if a parameter can be synced from server.
*
* @param key - The parameter key to check
* @returns True if the parameter is in the syncable parameters list
*/
static canSyncParameter(key: string): boolean {
return SYNCABLE_PARAMETERS.some((param) => param.key === key && param.canSync);
}
/**
* Get all syncable parameter keys.
*
* @returns Array of parameter keys that can be synced from server
*/
static getSyncableParameterKeys(): string[] {
return SYNCABLE_PARAMETERS.filter((param) => param.canSync).map((param) => param.key);
}
/**
* Validate a server parameter value against its expected type.
*
* @param key - The parameter key to validate
* @param value - The value to validate
* @returns True if value matches the expected type for this parameter
*/
static validateServerParameter(key: string, value: ParameterValue): boolean {
const param = SYNCABLE_PARAMETERS.find((p) => p.key === key);
if (!param) return false;
switch (param.type) {
case SyncableParameterType.NUMBER:
return typeof value === 'number' && !Number.isNaN(value);
case SyncableParameterType.STRING:
return typeof value === 'string';
case SyncableParameterType.BOOLEAN:
return typeof value === 'boolean';
default:
return false;
}
}
/**
*
*
* Diff
*
*
*/
/**
* Create a diff between current settings and server defaults.
* Shows which parameters differ from server values, useful for debugging
* and for the "Reset to defaults" functionality.
*
* @param currentSettings - Current parameter values in the settings store
* @param serverDefaults - Default values extracted from server props
* @returns Record of parameter diffs with current value, server value, and whether they differ
*/
static createParameterDiff(
currentSettings: ParameterRecord,
serverDefaults: ParameterRecord
): Record<string, { current: ParameterValue; server: ParameterValue; differs: boolean }> {
const diff: Record<
string,
{ current: ParameterValue; server: ParameterValue; differs: boolean }
> = {};
for (const key of this.getSyncableParameterKeys()) {
const currentValue = currentSettings[key];
const serverValue = serverDefaults[key];
if (serverValue !== undefined) {
diff[key] = {
current: currentValue,
server: serverValue,
differs: currentValue !== serverValue
};
}
}
return diff;
}
}
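The merge rule above can be exercised in isolation. This is a standalone sketch that re-implements `mergeWithServerDefaults` (with `ParameterRecord` approximated as a plain record, so it runs without the surrounding module):

```typescript
// Standalone re-implementation of the merge rule used by
// ParameterSyncService.mergeWithServerDefaults: server defaults apply
// to every key the user has not explicitly overridden.
type ParamValue = string | number | boolean;
type ParamRecord = Record<string, ParamValue>;

function mergeWithServerDefaults(
  current: ParamRecord,
  serverDefaults: ParamRecord,
  userOverrides: Set<string> = new Set()
): ParamRecord {
  const merged = { ...current };
  for (const [key, serverValue] of Object.entries(serverDefaults)) {
    // User overrides always win; everything else follows the server.
    if (!userOverrides.has(key)) {
      merged[key] = serverValue;
    }
  }
  return merged;
}

const merged = mergeWithServerDefaults(
  { temperature: 0.9, top_k: 50 },
  { temperature: 0.8, top_k: 40, top_p: 0.95 },
  new Set(['temperature'])
);
console.log(merged); // temperature stays 0.9; top_k and top_p follow the server
```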


@@ -0,0 +1,47 @@
import { apiFetchWithParams } from '$lib/utils/api-fetch';
export class PropsService {
/**
*
*
* Fetching
*
*
*/
/**
* Fetches global server properties from the `/props` endpoint.
* In MODEL mode, returns modalities for the single loaded model.
* In ROUTER mode, returns server-wide settings without model-specific modalities.
*
* @param autoload - If false, prevents automatic model loading (default: false)
* @returns Server properties including default generation settings and capabilities
* @throws {Error} If the request fails or returns invalid data
*/
static async fetch(autoload = false): Promise<ApiLlamaCppServerProps> {
const params: Record<string, string> = {};
if (!autoload) {
params.autoload = 'false';
}
return apiFetchWithParams<ApiLlamaCppServerProps>('./props', params, { authOnly: true });
}
/**
* Fetches server properties for a specific model (ROUTER mode only).
* Required in ROUTER mode because global `/props` does not include per-model modalities.
*
* @param modelId - The model ID to fetch properties for
* @param autoload - If false, prevents automatic model loading (default: false)
* @returns Server properties specific to the requested model
 * @throws {Error} If the request fails, the model is not found, or the model is not loaded
*/
static async fetchForModel(modelId: string, autoload = false): Promise<ApiLlamaCppServerProps> {
const params: Record<string, string> = { model: modelId };
if (!autoload) {
params.autoload = 'false';
}
return apiFetchWithParams<ApiLlamaCppServerProps>('./props', params, { authOnly: true });
}
}


@@ -1,8 +1,19 @@
import type { ServerModelStatus, ServerRole } from '$lib/enums';
import type { ChatMessagePromptProgress } from './chat';
import type { ContentPartType, ServerModelStatus, ServerRole } from '$lib/enums';
import type { ChatMessagePromptProgress, ChatRole } from './chat';
export interface ApiChatCompletionToolFunction {
name: string;
description?: string;
parameters: Record<string, unknown>;
}
export interface ApiChatCompletionTool {
type: 'function';
function: ApiChatCompletionToolFunction;
}
export interface ApiChatMessageContentPart {
type: 'text' | 'image_url' | 'input_audio';
type: ContentPartType;
text?: string;
image_url?: {
url: string;
@@ -34,6 +45,8 @@ export interface ApiErrorResponse {
export interface ApiChatMessageData {
role: ChatRole;
content: string | ApiChatMessageContentPart[];
tool_calls?: ApiChatCompletionToolCall[];
tool_call_id?: string;
timestamp?: number;
}
@@ -188,6 +201,7 @@ export interface ApiChatCompletionRequest {
stream?: boolean;
model?: string;
return_progress?: boolean;
tools?: ApiChatCompletionTool[];
// Reasoning parameters
reasoning_format?: string;
// Generation parameters
@@ -247,6 +261,7 @@ export interface ApiChatCompletionStreamChunk {
model?: string;
tool_calls?: ApiChatCompletionToolCallDelta[];
};
finish_reason?: string | null;
}>;
timings?: {
prompt_n?: number;
@@ -267,8 +282,9 @@ export interface ApiChatCompletionResponse {
content: string;
reasoning_content?: string;
model?: string;
tool_calls?: ApiChatCompletionToolCallDelta[];
tool_calls?: ApiChatCompletionToolCall[];
};
finish_reason?: string | null;
}>;
}
@@ -335,7 +351,7 @@ export interface ApiProcessingState {
tokensDecoded: number;
tokensRemaining: number;
contextUsed: number;
contextTotal: number;
contextTotal: number | null;
outputTokensUsed: number; // Total output tokens (thinking + regular content)
outputTokensMax: number; // Max output tokens allowed
temperature: number;


@@ -1,8 +1,5 @@
import type { ApiModelDataEntry, ApiModelDetails } from '$lib/types/api';
/**
* Model modalities - vision and audio capabilities
*/
export interface ModelModalities {
vision: boolean;
audio: boolean;
@@ -14,8 +11,15 @@ export interface ModelOption {
model: string;
description?: string;
capabilities: string[];
/** Model modalities from /props endpoint */
modalities?: ModelModalities;
details?: ApiModelDetails['details'];
meta?: ApiModelDataEntry['meta'];
}
/**
* Modality capabilities for file validation
*/
export interface ModalityCapabilities {
hasVision: boolean;
hasAudio: boolean;
}


@@ -0,0 +1,151 @@
/**
* Abort Signal Utilities
*
* Provides utilities for consistent AbortSignal propagation across the application.
* These utilities help ensure that async operations can be properly cancelled
* when needed (e.g., user stops generation, navigates away, etc.).
*/
/**
* Throws an AbortError if the signal is aborted.
* Use this at the start of async operations to fail fast.
*
* @param signal - Optional AbortSignal to check
* @throws DOMException with name 'AbortError' if signal is aborted
*
* @example
* ```ts
* async function fetchData(signal?: AbortSignal) {
* throwIfAborted(signal);
* // ... proceed with operation
* }
* ```
*/
export function throwIfAborted(signal?: AbortSignal): void {
if (signal?.aborted) {
throw new DOMException('Operation was aborted', 'AbortError');
}
}
/**
* Checks if an error is an AbortError.
* Use this to distinguish between user-initiated cancellation and actual errors.
*
* @param error - Error to check
* @returns true if the error is an AbortError
*
* @example
* ```ts
* try {
* await fetchData(signal);
* } catch (error) {
* if (isAbortError(error)) {
* // User cancelled - no error dialog needed
* return;
* }
* // Handle actual error
* }
* ```
*/
export function isAbortError(error: unknown): boolean {
if (error instanceof DOMException && error.name === 'AbortError') {
return true;
}
if (error instanceof Error && error.name === 'AbortError') {
return true;
}
return false;
}
/**
* Creates a new AbortController that is linked to one or more parent signals.
* When any parent signal aborts, the returned controller also aborts.
*
* Useful for creating child operations that should be cancelled when
* either the parent operation or their own timeout/condition triggers.
*
* @param signals - Parent signals to link to (undefined signals are ignored)
* @returns A new AbortController linked to all provided signals
*
* @example
* ```ts
* // Link to user's abort signal and add a timeout
* const linked = createLinkedController(userSignal, timeoutSignal);
* await fetch(url, { signal: linked.signal });
* ```
*/
export function createLinkedController(...signals: (AbortSignal | undefined)[]): AbortController {
const controller = new AbortController();
for (const signal of signals) {
if (!signal) continue;
// If already aborted, abort immediately
if (signal.aborted) {
controller.abort(signal.reason);
return controller;
}
// Link to parent signal
signal.addEventListener('abort', () => controller.abort(signal.reason), { once: true });
}
return controller;
}
/**
* Creates an AbortSignal that times out after the specified duration.
*
* @param ms - Timeout duration in milliseconds
* @returns AbortSignal that will abort after the timeout
*
* @example
* ```ts
* const signal = createTimeoutSignal(5000); // 5 second timeout
* await fetch(url, { signal });
* ```
*/
export function createTimeoutSignal(ms: number): AbortSignal {
return AbortSignal.timeout(ms);
}
/**
* Wraps a promise to reject if the signal is aborted.
* Useful for making non-abortable promises respect an AbortSignal.
*
* @param promise - Promise to wrap
* @param signal - AbortSignal to respect
* @returns Promise that rejects with AbortError if signal aborts
*
* @example
* ```ts
* // Make a non-abortable operation respect abort signal
* const result = await withAbortSignal(
* someNonAbortableOperation(),
* signal
* );
* ```
*/
export async function withAbortSignal<T>(promise: Promise<T>, signal?: AbortSignal): Promise<T> {
if (!signal) return promise;
throwIfAborted(signal);
return new Promise<T>((resolve, reject) => {
const abortHandler = () => {
reject(new DOMException('Operation was aborted', 'AbortError'));
};
signal.addEventListener('abort', abortHandler, { once: true });
promise
.then((value) => {
signal.removeEventListener('abort', abortHandler);
resolve(value);
})
.catch((error) => {
signal.removeEventListener('abort', abortHandler);
reject(error);
});
});
}
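The linking behaviour can be demonstrated without any network calls: `AbortController.abort()` notifies listeners synchronously, so a linked controller is aborted immediately when its parent aborts. A standalone copy of `createLinkedController`:

```typescript
// Standalone copy of createLinkedController: a child controller that
// aborts as soon as any parent signal aborts, propagating the reason.
function createLinkedController(...signals: (AbortSignal | undefined)[]): AbortController {
  const controller = new AbortController();
  for (const signal of signals) {
    if (!signal) continue;
    // If already aborted, abort immediately.
    if (signal.aborted) {
      controller.abort(signal.reason);
      return controller;
    }
    signal.addEventListener('abort', () => controller.abort(signal.reason), { once: true });
  }
  return controller;
}

const parent = new AbortController();
const linked = createLinkedController(parent.signal, undefined);
parent.abort('user cancelled');
console.log(linked.signal.aborted, linked.signal.reason); // true 'user cancelled'
```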


@@ -0,0 +1,154 @@
import { base } from '$app/paths';
import { getJsonHeaders, getAuthHeaders } from './api-headers';
/**
* API Fetch Utilities
*
* Provides common fetch patterns used across services:
* - Automatic JSON headers
* - Error handling with proper error messages
* - Base path resolution
*/
export interface ApiFetchOptions extends Omit<RequestInit, 'headers'> {
/**
* Use auth-only headers (no Content-Type).
* Default: false (uses JSON headers with Content-Type: application/json)
*/
authOnly?: boolean;
/**
* Additional headers to merge with default headers.
*/
headers?: Record<string, string>;
}
/**
* Fetch JSON data from an API endpoint with standard headers and error handling.
*
* @param path - API path (will be prefixed with base path)
* @param options - Fetch options with additional authOnly flag
* @returns Parsed JSON response
* @throws Error with formatted message on failure
*
* @example
* ```typescript
* // GET request
* const models = await apiFetch<ApiModelListResponse>('/v1/models');
*
* // POST request
* const result = await apiFetch<ApiResponse>('/models/load', {
* method: 'POST',
* body: JSON.stringify({ model: 'gpt-4' })
* });
* ```
*/
export async function apiFetch<T>(path: string, options: ApiFetchOptions = {}): Promise<T> {
const { authOnly = false, headers: customHeaders, ...fetchOptions } = options;
const baseHeaders = authOnly ? getAuthHeaders() : getJsonHeaders();
const headers = { ...baseHeaders, ...customHeaders };
const url = path.startsWith('http://') || path.startsWith('https://') ? path : `${base}${path}`;
const response = await fetch(url, {
...fetchOptions,
headers
});
if (!response.ok) {
const errorMessage = await parseErrorMessage(response);
throw new Error(errorMessage);
}
return response.json() as Promise<T>;
}
/**
* Fetch with URL constructed from base URL and query parameters.
*
* @param basePath - Base API path
* @param params - Query parameters to append
* @param options - Fetch options
* @returns Parsed JSON response
*
* @example
* ```typescript
* const props = await apiFetchWithParams<ApiProps>('./props', {
* model: 'gpt-4',
* autoload: 'false'
* });
* ```
*/
export async function apiFetchWithParams<T>(
basePath: string,
params: Record<string, string>,
options: ApiFetchOptions = {}
): Promise<T> {
const url = new URL(basePath, window.location.href);
for (const [key, value] of Object.entries(params)) {
if (value !== undefined && value !== null) {
url.searchParams.set(key, value);
}
}
const { authOnly = false, headers: customHeaders, ...fetchOptions } = options;
const baseHeaders = authOnly ? getAuthHeaders() : getJsonHeaders();
const headers = { ...baseHeaders, ...customHeaders };
const response = await fetch(url.toString(), {
...fetchOptions,
headers
});
if (!response.ok) {
const errorMessage = await parseErrorMessage(response);
throw new Error(errorMessage);
}
return response.json() as Promise<T>;
}
/**
* POST JSON data to an API endpoint.
*
* @param path - API path
* @param body - Request body (will be JSON stringified)
* @param options - Additional fetch options
* @returns Parsed JSON response
*/
export async function apiPost<T, B = unknown>(
path: string,
body: B,
options: ApiFetchOptions = {}
): Promise<T> {
return apiFetch<T>(path, {
method: 'POST',
body: JSON.stringify(body),
...options
});
}
/**
* Parse error message from a failed response.
* Tries to extract error message from JSON body, falls back to status text.
*/
async function parseErrorMessage(response: Response): Promise<string> {
try {
const errorData = await response.json();
if (errorData?.error?.message) {
return errorData.error.message;
}
if (errorData?.error && typeof errorData.error === 'string') {
return errorData.error;
}
if (errorData?.message) {
return errorData.message;
}
} catch {
// JSON parsing failed, use status text
}
return `Request failed: ${response.status} ${response.statusText}`;
}
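The query-string handling in `apiFetchWithParams` can be sketched in isolation. The sketch below uses an explicit absolute base instead of `window.location.href`, since there is no browser environment here:

```typescript
// Mirrors the URL construction in apiFetchWithParams: a relative path
// resolved against a base href, with each defined param appended.
function buildUrl(basePath: string, params: Record<string, string>, baseHref: string): string {
  const url = new URL(basePath, baseHref);
  for (const [key, value] of Object.entries(params)) {
    if (value !== undefined && value !== null) {
      url.searchParams.set(key, value);
    }
  }
  return url.toString();
}

const href = buildUrl('./props', { model: 'gpt-4', autoload: 'false' }, 'http://localhost:8080/webui/');
console.log(href); // http://localhost:8080/webui/props?model=gpt-4&autoload=false
```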


@@ -15,6 +15,8 @@
* └── message 5 (assistant)
*/
import { MessageRole } from '$lib/enums/chat';
/**
* Filters messages to get the conversation path from root to a specific leaf node.
* If the leafNodeId doesn't exist, returns the path with the latest timestamp.
@@ -65,8 +67,13 @@ export function filterByLeafNodeId(
currentNode = nodeMap.get(currentNode.parent);
}
// Sort by timestamp to get chronological order (root to leaf)
result.sort((a, b) => a.timestamp - b.timestamp);
// Sort: system messages first, then by timestamp
result.sort((a, b) => {
if (a.role === MessageRole.SYSTEM && b.role !== MessageRole.SYSTEM) return -1;
if (a.role !== MessageRole.SYSTEM && b.role === MessageRole.SYSTEM) return 1;
return a.timestamp - b.timestamp;
});
return result;
}
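The new comparator can be checked on its own. This sketch uses a string stand-in for the imported `MessageRole` enum:

```typescript
// Stand-in for the imported MessageRole enum value.
const SYSTEM = 'system';

interface Msg { id: string; role: string; timestamp: number }

// Same ordering rule as the diff above: system messages first,
// everything else in chronological order.
function sortConversation(messages: Msg[]): Msg[] {
  return [...messages].sort((a, b) => {
    if (a.role === SYSTEM && b.role !== SYSTEM) return -1;
    if (a.role !== SYSTEM && b.role === SYSTEM) return 1;
    return a.timestamp - b.timestamp;
  });
}

const sorted = sortConversation([
  { id: 'u1', role: 'user', timestamp: 1 },
  { id: 's', role: 'system', timestamp: 5 },
  { id: 'a1', role: 'assistant', timestamp: 2 }
]);
console.log(sorted.map((m) => m.id)); // [ 's', 'u1', 'a1' ]
```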


@@ -23,7 +23,7 @@ export {
} from './pdf-processing';
// File conversion utilities (depends on pdf-processing)
export { parseFilesToMessageExtras, type FileProcessingResult } from './convert-files-to-extra';
export { parseFilesToMessageExtras } from './convert-files-to-extra';
// File upload processing utilities (depends on pdf-processing, svg-to-png, webp-to-png)
export { processFilesToChatUploaded } from './process-uploaded-files';


@@ -0,0 +1,293 @@
const DEFAULT_CACHE_TTL_MS = 5 * 60 * 1000;
const DEFAULT_CACHE_MAX_ENTRIES = 100;
/**
* TTL Cache - Time-To-Live cache implementation for memory optimization
*
* Provides automatic expiration of cached entries to prevent memory bloat
* in long-running sessions.
*
* @example
* ```ts
* const cache = new TTLCache<string, ApiData>({ ttlMs: 5 * 60 * 1000 }); // 5 minutes
* cache.set('key', data);
* const value = cache.get('key'); // null if expired
* ```
*/
export interface TTLCacheOptions {
/** Time-to-live in milliseconds. Default: 5 minutes */
ttlMs?: number;
/** Maximum number of entries. Oldest entries are evicted when exceeded. Default: 100 */
maxEntries?: number;
/** Callback when an entry expires or is evicted */
onEvict?: (key: string, value: unknown) => void;
}
interface CacheEntry<T> {
value: T;
expiresAt: number;
lastAccessed: number;
}
export class TTLCache<K extends string, V> {
private cache = new Map<K, CacheEntry<V>>();
private readonly ttlMs: number;
private readonly maxEntries: number;
private readonly onEvict?: (key: string, value: unknown) => void;
constructor(options: TTLCacheOptions = {}) {
this.ttlMs = options.ttlMs ?? DEFAULT_CACHE_TTL_MS;
this.maxEntries = options.maxEntries ?? DEFAULT_CACHE_MAX_ENTRIES;
this.onEvict = options.onEvict;
}
/**
* Get a value from cache. Returns null if expired or not found.
*/
get(key: K): V | null {
const entry = this.cache.get(key);
if (!entry) return null;
if (Date.now() > entry.expiresAt) {
this.delete(key);
return null;
}
// Update last accessed time for LRU-like behavior
entry.lastAccessed = Date.now();
return entry.value;
}
/**
* Set a value in cache with TTL.
*/
set(key: K, value: V, customTtlMs?: number): void {
// Evict oldest entries if at capacity
if (this.cache.size >= this.maxEntries && !this.cache.has(key)) {
this.evictOldest();
}
const ttl = customTtlMs ?? this.ttlMs;
const now = Date.now();
this.cache.set(key, {
value,
expiresAt: now + ttl,
lastAccessed: now
});
}
/**
* Check if key exists and is not expired.
*/
has(key: K): boolean {
const entry = this.cache.get(key);
if (!entry) return false;
if (Date.now() > entry.expiresAt) {
this.delete(key);
return false;
}
return true;
}
/**
* Delete a specific key from cache.
*/
delete(key: K): boolean {
const entry = this.cache.get(key);
if (entry && this.onEvict) {
this.onEvict(key, entry.value);
}
return this.cache.delete(key);
}
/**
* Clear all entries from cache.
*/
clear(): void {
if (this.onEvict) {
for (const [key, entry] of this.cache) {
this.onEvict(key, entry.value);
}
}
this.cache.clear();
}
/**
* Get the number of entries (including potentially expired ones).
*/
get size(): number {
return this.cache.size;
}
/**
* Remove all expired entries from cache.
* Call periodically for proactive cleanup.
*/
prune(): number {
const now = Date.now();
let pruned = 0;
for (const [key, entry] of this.cache) {
if (now > entry.expiresAt) {
this.delete(key);
pruned++;
}
}
return pruned;
}
/**
* Get all valid (non-expired) keys.
*/
keys(): K[] {
const now = Date.now();
const validKeys: K[] = [];
for (const [key, entry] of this.cache) {
if (now <= entry.expiresAt) {
validKeys.push(key);
}
}
return validKeys;
}
/**
* Evict the oldest (least recently accessed) entry.
*/
private evictOldest(): void {
let oldestKey: K | null = null;
let oldestTime = Infinity;
for (const [key, entry] of this.cache) {
if (entry.lastAccessed < oldestTime) {
oldestTime = entry.lastAccessed;
oldestKey = key;
}
}
if (oldestKey !== null) {
this.delete(oldestKey);
}
}
/**
* Refresh TTL for an existing entry without changing the value.
*/
touch(key: K): boolean {
const entry = this.cache.get(key);
if (!entry) return false;
const now = Date.now();
if (now > entry.expiresAt) {
this.delete(key);
return false;
}
entry.expiresAt = now + this.ttlMs;
entry.lastAccessed = now;
return true;
}
}
/**
* Reactive TTL Map for Svelte stores
 * Wraps a reactive `$state` Map with TTL functionality
*/
export class ReactiveTTLMap<K extends string, V> {
private entries = $state<Map<K, CacheEntry<V>>>(new Map());
private readonly ttlMs: number;
private readonly maxEntries: number;
constructor(options: TTLCacheOptions = {}) {
this.ttlMs = options.ttlMs ?? DEFAULT_CACHE_TTL_MS;
this.maxEntries = options.maxEntries ?? DEFAULT_CACHE_MAX_ENTRIES;
}
get(key: K): V | null {
const entry = this.entries.get(key);
if (!entry) return null;
if (Date.now() > entry.expiresAt) {
this.entries.delete(key);
return null;
}
entry.lastAccessed = Date.now();
return entry.value;
}
set(key: K, value: V, customTtlMs?: number): void {
if (this.entries.size >= this.maxEntries && !this.entries.has(key)) {
this.evictOldest();
}
const ttl = customTtlMs ?? this.ttlMs;
const now = Date.now();
this.entries.set(key, {
value,
expiresAt: now + ttl,
lastAccessed: now
});
}
has(key: K): boolean {
const entry = this.entries.get(key);
if (!entry) return false;
if (Date.now() > entry.expiresAt) {
this.entries.delete(key);
return false;
}
return true;
}
delete(key: K): boolean {
return this.entries.delete(key);
}
clear(): void {
this.entries.clear();
}
get size(): number {
return this.entries.size;
}
prune(): number {
const now = Date.now();
let pruned = 0;
for (const [key, entry] of this.entries) {
if (now > entry.expiresAt) {
this.entries.delete(key);
pruned++;
}
}
return pruned;
}
private evictOldest(): void {
let oldestKey: K | null = null;
let oldestTime = Infinity;
for (const [key, entry] of this.entries) {
if (entry.lastAccessed < oldestTime) {
oldestTime = entry.lastAccessed;
oldestKey = key;
}
}
if (oldestKey !== null) {
this.entries.delete(oldestKey);
}
}
}
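The expiry contract shared by both classes can be demonstrated with a pared-down cache. The injectable clock is an addition for deterministic testing; the real classes call `Date.now()` directly:

```typescript
// Minimal TTL map with an injectable clock so expiry is deterministic.
class MiniTTLCache<V> {
  private entries = new Map<string, { value: V; expiresAt: number }>();

  constructor(private ttlMs: number, private now: () => number) {}

  set(key: string, value: V): void {
    this.entries.set(key, { value, expiresAt: this.now() + this.ttlMs });
  }

  // Same contract as TTLCache.get: null once the entry has expired,
  // with the stale entry removed on access.
  get(key: string): V | null {
    const entry = this.entries.get(key);
    if (!entry) return null;
    if (this.now() > entry.expiresAt) {
      this.entries.delete(key);
      return null;
    }
    return entry.value;
  }
}

let clock = 0;
const cache = new MiniTTLCache<string>(1000, () => clock);
cache.set('k', 'v');
console.log(cache.get('k')); // 'v'
clock = 1001; // advance past the 1000 ms TTL
console.log(cache.get('k')); // null
```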


@@ -0,0 +1,85 @@
import hljs from 'highlight.js';
import {
NEWLINE,
DEFAULT_LANGUAGE,
LANG_PATTERN,
AMPERSAND_REGEX,
LT_REGEX,
GT_REGEX,
FENCE_PATTERN
} from '$lib/constants/code';
export interface IncompleteCodeBlock {
language: string;
code: string;
openingIndex: number;
}
/**
* Highlights code using highlight.js
* @param code - The code to highlight
* @param language - The programming language
* @returns HTML string with syntax highlighting
*/
export function highlightCode(code: string, language: string): string {
if (!code) return '';
try {
const lang = language.toLowerCase();
const isSupported = hljs.getLanguage(lang);
if (isSupported) {
return hljs.highlight(code, { language: lang }).value;
} else {
return hljs.highlightAuto(code).value;
}
} catch {
// Fallback to escaped plain text
return code
.replace(AMPERSAND_REGEX, '&amp;')
.replace(LT_REGEX, '&lt;')
.replace(GT_REGEX, '&gt;');
}
}
/**
* Detects if markdown ends with an incomplete code block (opened but not closed).
* Returns the code block info if found, null otherwise.
* @param markdown - The raw markdown string to check
* @returns IncompleteCodeBlock info or null
*/
export function detectIncompleteCodeBlock(markdown: string): IncompleteCodeBlock | null {
// Count all code fences in the markdown
// A code block is incomplete if there's an odd number of ``` fences
const fencePattern = new RegExp(FENCE_PATTERN.source, FENCE_PATTERN.flags);
const fences: number[] = [];
let fenceMatch;
while ((fenceMatch = fencePattern.exec(markdown)) !== null) {
// Store the position after the ```
const pos = fenceMatch[0].startsWith(NEWLINE) ? fenceMatch.index + 1 : fenceMatch.index;
fences.push(pos);
}
// If even number of fences (including 0), all code blocks are closed
if (fences.length % 2 === 0) {
return null;
}
// Odd number means last code block is incomplete
// The last fence is the opening of the incomplete block
const openingIndex = fences[fences.length - 1];
const afterOpening = markdown.slice(openingIndex + 3);
// Extract language and code content
const langMatch = afterOpening.match(LANG_PATTERN);
const language = langMatch?.[1] || DEFAULT_LANGUAGE;
const codeStartIndex = openingIndex + 3 + (langMatch?.[0]?.length ?? 0);
const code = markdown.slice(codeStartIndex);
return {
language,
code,
openingIndex
};
}
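The odd/even fence rule can be sketched with a simplified pattern. This uses a bare backtick-fence matcher; the real `FENCE_PATTERN` constant also handles newline-prefixed fences and language tags:

```typescript
// Simplified incomplete-fence detection: count fence markers; an odd
// count means the last code block was opened but never closed.
const FENCE = /```/g;

function hasIncompleteCodeBlock(markdown: string): boolean {
  const matches = markdown.match(FENCE);
  return ((matches?.length ?? 0) % 2) === 1;
}

console.log(hasIncompleteCodeBlock('opened: ``` let x = 1;')); // true (1 fence)
console.log(hasIncompleteCodeBlock('closed: ``` let x = 1; ``` after')); // false (2 fences)
```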


@@ -0,0 +1,10 @@
/**
* Creates a base64 data URL from MIME type and base64-encoded data.
*
* @param mimeType - The MIME type (e.g., 'image/png', 'audio/mp3')
* @param base64Data - The base64-encoded data
* @returns A data URL string in format 'data:{mimeType};base64,{data}'
*/
export function createBase64DataUrl(mimeType: string, base64Data: string): string {
return `data:${mimeType};base64,${base64Data}`;
}


@@ -0,0 +1,22 @@
/**
* @param fn - The function to debounce
* @param delay - The delay in milliseconds
* @returns A debounced version of the function
*/
export function debounce<T extends (...args: Parameters<T>) => void>(
fn: T,
delay: number
): (...args: Parameters<T>) => void {
let timeoutId: ReturnType<typeof setTimeout> | null = null;
return (...args: Parameters<T>) => {
if (timeoutId) {
clearTimeout(timeoutId);
}
timeoutId = setTimeout(() => {
fn(...args);
timeoutId = null;
}, delay);
};
}

Some files were not shown because too many files have changed in this diff.