Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2026-02-19 14:13:22 +02:00)

Compare commits: b8049...gg/scripts (64 commits)
Commits (SHA1):

c0c3e428dd, 7f049860b4, 2ffa45edfc, 9c29be1177, 013963cfd5, e2e998a2d6,
6c41664b8b, 7b84af8051, 60a501e138, e6e777cfb3, ad3a54eb68, c6d70b9bea,
de956a6ca8, 350e7c1409, db10dda1f3, 52759bf078, 99e3c3d02c, c6315655b7,
f762a71d56, 73e61d5b75, cffd268bb3, e8a807519a, 1db8428f00, 7751ae2796,
d2b10302ce, 68dde884d6, fd90796da2, 8156d549f6, 9695e6feb4, fb1481d60d,
812ae13ec1, e79e8d02d5, a939f4c47e, 62b04cef54, 37b26cafee, 04f6872116,
c2619c18bf, 87f8930968, 9453f9de12, 5a1be6ce37, a80814e97b, 5cc2258e82,
c87af1d527, 23d4e21a81, 07d5e1e0ea, 8839037528, 89cab3dbc5, c2d83ca048,
c05df17ce3, 27b93cbd15, 6e67fd2144, 9e118b97c4, 57088276d4, 341bc7d23c,
08e6d914b8, 184c694f45, 684b36101c, 3a00c98584, 079feab9e3, 01d8eaa28d,
1725e316c1, b7742cf321, badba89320, baa12f3831
@@ -112,7 +112,6 @@ option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_
 option(LLAMA_TESTS_INSTALL "llama: install tests" ON)

 # 3rd party libs
-option(LLAMA_HTTPLIB    "llama: httplib for downloading functionality" ON)
 option(LLAMA_OPENSSL    "llama: use openssl to support HTTPS" ON)
 option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
@@ -197,9 +196,7 @@ add_subdirectory(src)

 if (LLAMA_BUILD_COMMON)
     add_subdirectory(common)
-    if (LLAMA_HTTPLIB)
-        add_subdirectory(vendor/cpp-httplib)
-    endif()
+    add_subdirectory(vendor/cpp-httplib)
 endif()

 if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
@@ -449,10 +449,9 @@ cmake -B build-visionos -G Xcode \
     -DCMAKE_SYSTEM_NAME=visionOS \
     -DCMAKE_OSX_SYSROOT=xros \
     -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xros \
-    -DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
-    -DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
+    -DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
+    -DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
     -DLLAMA_OPENSSL=OFF \
-    -DLLAMA_HTTPLIB=OFF \
     -DLLAMA_BUILD_SERVER=OFF \
     -S .
 cmake --build build-visionos --config Release -- -quiet
@@ -465,10 +464,9 @@ cmake -B build-visionos-sim -G Xcode \
     -DCMAKE_SYSTEM_NAME=visionOS \
     -DCMAKE_OSX_SYSROOT=xrsimulator \
     -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xrsimulator \
-    -DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
-    -DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
+    -DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
+    -DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
     -DLLAMA_OPENSSL=OFF \
-    -DLLAMA_HTTPLIB=OFF \
     -DLLAMA_BUILD_SERVER=OFF \
     -S .
 cmake --build build-visionos-sim --config Release -- -quiet
@@ -112,11 +112,7 @@ endif()

 # TODO: use list(APPEND LLAMA_COMMON_EXTRA_LIBS ...)
 set(LLAMA_COMMON_EXTRA_LIBS build_info)

-if (LLAMA_HTTPLIB)
-    target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_HTTPLIB)
-    set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib)
-endif()
+set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib)

 if (LLAMA_LLGUIDANCE)
     include(ExternalProject)
@@ -879,7 +879,8 @@ std::string fs_get_cache_directory() {
     if (getenv("LLAMA_CACHE")) {
         cache_directory = std::getenv("LLAMA_CACHE");
     } else {
-#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
+#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || \
+    defined(__OpenBSD__) || defined(__NetBSD__)
         if (std::getenv("XDG_CACHE_HOME")) {
             cache_directory = std::getenv("XDG_CACHE_HOME");
         } else if (std::getenv("HOME")) {
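The lookup order above (LLAMA_CACHE first, then XDG_CACHE_HOME, then HOME) can be summarized with a small Python sketch; this illustrates the logic only and is not the project's code, and the appended subdirectory name is an assumption:

```python
import os

def resolve_cache_dir() -> str:
    # LLAMA_CACHE overrides everything, as in the C++ above
    if os.environ.get("LLAMA_CACHE"):
        return os.environ["LLAMA_CACHE"]
    # on Linux/BSD-likes (now including NetBSD), fall back to XDG, then HOME
    base = os.environ.get("XDG_CACHE_HOME") or os.path.join(os.environ["HOME"], ".cache")
    return os.path.join(base, "llama.cpp")  # subdirectory name assumed for illustration
```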
@@ -19,9 +19,7 @@
 #include <thread>
 #include <vector>

-#if defined(LLAMA_USE_HTTPLIB)
 #include "http.h"
-#endif

 #ifndef __EMSCRIPTEN__
 #ifdef __linux__
@@ -142,8 +140,6 @@ std::pair<std::string, std::string> common_download_split_repo_tag(const std::st
     return {hf_repo, tag};
 }

-#if defined(LLAMA_USE_HTTPLIB)
-
 class ProgressBar {
     static inline std::mutex mutex;
     static inline std::map<const ProgressBar *, int> lines;
@@ -768,30 +764,6 @@ std::string common_docker_resolve_model(const std::string & docker) {
     }
 }

-#else
-
-common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) {
-    throw std::runtime_error("download functionality is not enabled in this build");
-}
-
-bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) {
-    throw std::runtime_error("download functionality is not enabled in this build");
-}
-
-std::string common_docker_resolve_model(const std::string &) {
-    throw std::runtime_error("download functionality is not enabled in this build");
-}
-
-int common_download_file_single(const std::string &,
-                                const std::string &,
-                                const std::string &,
-                                bool,
-                                const common_header_list &) {
-    throw std::runtime_error("download functionality is not enabled in this build");
-}
-
-#endif // defined(LLAMA_USE_HTTPLIB)
-
 std::vector<common_cached_model_info> common_list_cached_models() {
     std::vector<common_cached_model_info> models;
     const std::string cache_dir = fs_get_cache_directory();
@@ -2726,8 +2726,6 @@ class AfmoeModel(LlamaModel):
         super().set_gguf_parameters()

         # MoE parameters
-        if (n_experts := self.hparams.get("num_experts")) is not None:
-            self.gguf_writer.add_expert_count(n_experts)
         if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None:
             self.gguf_writer.add_expert_shared_count(n_shared_experts)
         if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
@@ -2749,7 +2747,7 @@ class AfmoeModel(LlamaModel):
         # Handle expert weights - they're already merged in the HF format
         # process the experts separately
         if name.find("mlp.experts") != -1:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None

             if self._experts is None:
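The recurring change in this and the following hunks swaps a direct `hparams["num_experts"]` lookup for `find_hparam`, a first-match lookup across several config key spellings. A hedged sketch of that behavior (the real implementation lives on the model base class in convert_hf_to_gguf.py; this standalone version is for illustration):

```python
def find_hparam(hparams: dict, keys: list, optional: bool = False):
    # return the value of the first key present; mirrors how the converter
    # accepts either "num_local_experts" or "num_experts"
    for key in keys:
        if key in hparams:
            return hparams[key]
    if optional:
        return None
    raise KeyError(f"could not find any of: {keys}")

assert find_hparam({"num_experts": 8}, ["num_local_experts", "num_experts"]) == 8
```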
@@ -4074,6 +4072,87 @@ class InternVisionModel(MmprojModel):
         yield from super().modify_tensors(data_torch, name, bid)


+@ModelBase.register(
+    "NemotronH_Nano_VL_V2",
+    "RADIOModel",
+)
+class NemotronNanoV2VLModel(MmprojModel):
+    # ViT-Huge architecture parameters for RADIO v2.5-h
+    _vit_hidden_size = 1280
+    _vit_intermediate_size = 5120
+    _vit_num_layers = 32
+    _vit_num_heads = 16
+
+    def get_vision_config(self) -> dict[str, Any] | None:
+        # RADIO config doesn't have standard ViT parameters, so they need to be constructed manually
+        vision_config = self.global_config.get("vision_config")
+        if vision_config is None:
+            return None
+        # Add ViT-H parameters
+        vision_config = {
+            **vision_config,
+            "hidden_size": self._vit_hidden_size,
+            "intermediate_size": self._vit_intermediate_size,
+            "num_hidden_layers": self._vit_num_layers,
+            "num_attention_heads": self._vit_num_heads,
+            "image_size": self.global_config.get("force_image_size", 512),
+        }
+        return vision_config
+
+    def set_gguf_parameters(self):
+        if "image_mean" not in self.preprocessor_config:
+            self.preprocessor_config["image_mean"] = [0.485, 0.456, 0.406]
+        if "image_std" not in self.preprocessor_config:
+            self.preprocessor_config["image_std"] = [0.229, 0.224, 0.225]
+
+        super().set_gguf_parameters()
+        hparams = self.global_config
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.NEMOTRON_V2_VL)
+        self.gguf_writer.add_vision_attention_layernorm_eps(1e-6)
+        self.gguf_writer.add_vision_use_gelu(True)
+        downsample_ratio = hparams.get("downsample_ratio", 0.5)
+        self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio))
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        if ".position_embd." in new_name or "pos_embed" in new_name:
+            return gguf.GGMLQuantizationType.F32
+        return super().tensor_force_quant(name, new_name, bid, n_dims)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if "input_conditioner" in name:
+            return
+
+        # RADIO's pos_embed doesn't have a .weight suffix, but clip.cpp expects it
+        if "patch_generator.pos_embed" in name:
+            if not name.endswith(".weight"):
+                name += ".weight"
+            # Downsample position embeddings for the fixed 512x512 image size
+            import torch.nn.functional as F
+            n_embd = self.hparams["hidden_size"]
+            image_size = self.global_config.get("force_image_size", 512)
+            patch_size = self.hparams["patch_size"]
+            target_patches_per_side = image_size // patch_size  # 32
+            max_patches_per_side = int((data_torch.shape[1]) ** 0.5)  # 128
+            if target_patches_per_side != max_patches_per_side:
+                # Reshape to grid, interpolate, flatten back
+                data_torch = data_torch.reshape(1, max_patches_per_side, max_patches_per_side, n_embd)
+                data_torch = data_torch.permute(0, 3, 1, 2).float()  # [1, n_embd, 128, 128]
+                data_torch = F.interpolate(data_torch, size=(target_patches_per_side, target_patches_per_side),
+                                           mode='bilinear', align_corners=True)
+                data_torch = data_torch.permute(0, 2, 3, 1)  # [1, 32, 32, n_embd]
+                data_torch = data_torch.reshape(1, target_patches_per_side * target_patches_per_side, n_embd)
+
+        # Reshape linear patch embedding to conv2d format for ggml_conv_2d
+        # From [n_embd, patch_size*patch_size*3] to [n_embd, 3, patch_size, patch_size]
+        if "patch_generator.embedder" in name:
+            patch_size = self.hparams["patch_size"]
+            n_embd = self.hparams["hidden_size"]
+            data_torch = data_torch.reshape(n_embd, 3, patch_size, patch_size)
+
+        if name.startswith("vision_model.radio_model.model.") or name.startswith("mlp1."):
+            yield from super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("WavTokenizerDec")
 class WavTokenizerDecModel(TextModel):
     model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
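As a quick check of the grid arithmetic in `modify_tensors` above (patch_size = 16 is an assumed value for illustration; the real one comes from the checkpoint's hparams):

```python
image_size = 512          # force_image_size default above
patch_size = 16           # assumed; read from hparams in the real code
target = image_size // patch_size         # 32 patches per side, as the inline comment notes
stored_positions = 128 * 128              # RADIO's stored pos_embed grid (from data_torch.shape[1])
print(target * target, stored_positions)  # 1024 positions kept after bilinear resize, down from 16384
```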
@@ -4116,8 +4195,6 @@ class Qwen2MoeModel(TextModel):

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        if (n_experts := self.hparams.get("num_experts")) is not None:
-            self.gguf_writer.add_expert_count(n_experts)
         if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
             self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
             logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
@@ -4162,7 +4239,7 @@ class Qwen2MoeModel(TextModel):
             return

         if name.find("experts") != -1:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None

             if self._experts is None:
@@ -4913,13 +4990,13 @@ class PhiMoeModel(Phi3MiniModel):

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
-        self.gguf_writer.add_expert_count(self.hparams["num_local_experts"])
+        self.gguf_writer.add_expert_used_count(self.find_hparam(["num_experts_per_tok", "num_experts_per_token"]))
+        self.gguf_writer.add_expert_count(self.find_hparam(["num_local_experts", "num_experts"]))

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # process the experts separately
         if name.find("block_sparse_moe.experts") != -1:
-            n_experts = self.hparams["num_local_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None

             if self._experts is None:
@@ -5331,7 +5408,7 @@ class KimiLinearModel(TextModel):

         # process the experts separately
         if name.find("block_sparse_moe.experts") != -1:
-            n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=False)
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None

             if self._experts is None:
@@ -5926,12 +6003,13 @@ class NomicBertModel(BertModel):
         if "mlp.experts.bias" in name:
             return  # Explicitly return.

+        n_experts = self.find_hparam(["num_local_experts", "num_experts"])
         if "mlp.experts.mlp.w1" in name:
-            data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
+            data_torch = data_torch.view(n_experts, self.hparams["n_inner"], self.hparams["n_embd"])
             name += ".weight"

         if "mlp.experts.mlp.w2" in name:
-            data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
+            data_torch = data_torch.view(n_experts, self.hparams["n_inner"], self.hparams["n_embd"])
             data_torch = data_torch.transpose(1, 2)
             name += ".weight"
@@ -5941,7 +6019,6 @@ class NomicBertModel(BertModel):
         super().set_gguf_parameters()
         if self.is_moe:
             self.gguf_writer.add_moe_every_n_layers(self.hparams["moe_every_n_layers"])
-            self.gguf_writer.add_expert_count(self.hparams["num_experts"])
             self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"])

     def _is_tokenizer_xlmroberta(self) -> bool:
@@ -7055,6 +7132,8 @@ class Mamba2Model(TextModel):
         if hparams is None:
             with open(dir_model / "config.json", "r", encoding="utf-8") as f:
                 hparams = json.load(f)
+        if "llm_config" in hparams:
+            hparams["text_config"] = hparams["llm_config"]
         super().__init__(dir_model, *args, hparams=hparams, **kwargs)
         self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
         self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
@@ -7176,8 +7255,8 @@ class JambaModel(TextModel):
         self.gguf_writer.add_ssm_state_size(d_state)
         self.gguf_writer.add_ssm_time_step_rank(dt_rank)
         self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
-        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
-        self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
+        self.gguf_writer.add_expert_count(self.find_hparam(["num_local_experts", "num_experts"]))
+        self.gguf_writer.add_expert_used_count(self.find_hparam(["num_experts_per_tok", "num_experts_per_token"]))
         self.gguf_writer.add_file_type(self.ftype)

     _experts: list[dict[str, Tensor]] | None = None
@@ -7195,7 +7274,7 @@ class JambaModel(TextModel):

         # process the experts separately
         if ".feed_forward.experts." in name:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])

             assert bid is not None
@@ -7343,8 +7422,6 @@ class OlmoeModel(TextModel):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_layer_norm_rms_eps(1e-5)
-        if (n_experts := self.hparams.get("num_experts")) is not None:
-            self.gguf_writer.add_expert_count(n_experts)

     _experts: list[dict[str, Tensor]] | None = None
@@ -7352,7 +7429,7 @@ class OlmoeModel(TextModel):

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # process the experts separately
         if name.find("experts") != -1:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None

             if self._experts is None:
@@ -7933,10 +8010,6 @@ class MiniMaxM2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.MINIMAXM2
     _experts_cache: dict[int, dict[str, Tensor]] = {}

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.hparams["num_experts"] = self.hparams["num_local_experts"]
-
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -7949,7 +8022,7 @@ class MiniMaxM2Model(TextModel):

         # merge expert weights
         if 'experts' in name:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None

             expert_cache = self._experts_cache.setdefault(bid, {})
@@ -9154,7 +9227,6 @@ class ExaoneMoEModel(Exaone4Model):

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
         moe_intermediate_size = self.hparams["moe_intermediate_size"]
         num_shared_experts = self.hparams["num_shared_experts"]
         self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
@@ -9195,7 +9267,7 @@ class ExaoneMoEModel(Exaone4Model):
         name = name.replace("e_score_correction_bias", "e_score_correction.bias")

         if name.find("mlp.experts") != -1:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None

             if self._experts is None:
@@ -9346,7 +9418,7 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
         # case, the model architecture needs to be updated to a standard
         # "granite" or "granitemoe" model
         if not self._ssm_layers:
-            has_experts = self.find_hparam(["num_experts_per_tok"], optional=True)
+            has_experts = self.find_hparam(["num_experts_per_tok", "num_experts_per_token"], optional=True)
             new_arch = (
                 gguf.MODEL_ARCH.GRANITE_MOE
                 if has_experts else
@@ -9542,6 +9614,14 @@ class NemotronHModel(GraniteHybridModel):
         self.gguf_writer.add_add_bos_token(True)

+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip vision model and projector tensors for VLM models (handled by mmproj) (e.g., Nemotron Nano 12B v2 VL)
+        if name.startswith(("vision_model.", "mlp1.")):
+            return
+
+        # Strip language_model. prefix for VLM models (e.g., Nemotron Nano 12B v2 VL)
+        if name.startswith("language_model."):
+            name = name[len("language_model."):]
+
         if self.is_moe and bid is not None:
             if name.endswith("mixer.gate.e_score_correction_bias"):
                 new_name = name.replace("e_score_correction_bias", "e_score_correction.bias")
@@ -9636,7 +9716,6 @@ class BailingMoeModel(TextModel):
         self.gguf_writer.add_vocab_size(hparams["vocab_size"])
         self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
         self.gguf_writer.add_expert_weights_scale(1.0)
-        self.gguf_writer.add_expert_count(hparams["num_experts"])
         self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
         self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
@@ -9670,7 +9749,7 @@ class BailingMoeModel(TextModel):
             yield from super().modify_tensors(v, self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), bid)
             return
         elif name.find("mlp.experts") != -1:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None

             if self._experts is None:
@@ -9741,7 +9820,6 @@ class BailingMoeV2Model(TextModel):
         self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
         self.gguf_writer.add_expert_shared_feed_forward_length(hparams.get("moe_shared_expert_intermediate_size", hparams["moe_intermediate_size"] * hparams["num_shared_experts"]))
         self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
-        self.gguf_writer.add_expert_count(hparams["num_experts"])
         self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
         self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
@@ -9752,7 +9830,7 @@ class BailingMoeV2Model(TextModel):

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         if "mlp.experts" in name:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None

             if self._experts is None:
@@ -9798,8 +9876,6 @@ class GroveMoeModel(TextModel):

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        if (n_experts := self.hparams.get("num_experts")) is not None:
-            self.gguf_writer.add_expert_count(n_experts)
         if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
             self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
             logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
@@ -9820,7 +9896,7 @@ class GroveMoeModel(TextModel):

         # process the experts separately
         if name.find("chunk_experts") != -1:
-            n_experts = self.hparams["num_experts"] // 2  # see add_experts_per_group
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"]) // 2  # see add_experts_per_group
             assert bid is not None

             if self._chunk_experts is None:
@@ -9847,7 +9923,7 @@ class GroveMoeModel(TextModel):
         else:
             return
         elif name.find("experts") != -1:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None

             if self._experts is None:
@@ -10240,7 +10316,6 @@ class HunYuanMoEModel(TextModel):
         super().set_gguf_parameters()
         hparams = self.hparams

-        self.gguf_writer.add_expert_count(hparams["num_experts"])
         self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])

         moe_intermediate_size = hparams["moe_intermediate_size"]
@@ -10283,7 +10358,7 @@ class HunYuanMoEModel(TextModel):
             return

         if name.find("mlp.experts") != -1:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None

             if self._experts is None:
@@ -10325,16 +10400,9 @@ class LLaDAMoEModel(TextModel):

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        if (n_experts := self.hparams.get("num_experts")) is not None:
-            self.gguf_writer.add_expert_count(n_experts)
-
         if (expert_intermediate_size := self.hparams.get("expert_intermediate_size")) is not None:
             self.gguf_writer.add_expert_feed_forward_length(expert_intermediate_size)

-        # number of experts used per token (top-k)
-        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
-            self.gguf_writer.add_expert_used_count(n_experts_used)
-
         self.gguf_writer.add_mask_token_id(156895)
         self.gguf_writer.add_causal_attention(False)
         self.gguf_writer.add_diffusion_shift_logits(False)
@@ -10345,7 +10413,7 @@ class LLaDAMoEModel(TextModel):

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # process the experts separately
         if name.find("experts") != -1:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None

             if self._experts is None:
@@ -10682,7 +10750,6 @@ class LFM2MoeModel(TextModel):

         super().set_gguf_parameters()

-        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
         self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
         self.gguf_writer.add_leading_dense_block_count(self.hparams["num_dense_layers"])
         self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
@@ -10703,7 +10770,7 @@ class LFM2MoeModel(TextModel):

         # merge expert weights
         if 'experts' in name:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None

             expert_cache = self._experts_cache.setdefault(bid, {})
@@ -10813,9 +10880,9 @@ class SmallThinkerModel(TextModel):

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        if (n_experts := self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))) is not None:
+        if (n_experts := self.hparams.get("moe_num_primary_experts")) is not None:
             self.gguf_writer.add_expert_count(n_experts)
-        if (n_experts_used := self.hparams.get("num_experts_per_tok", self.hparams.get("moe_num_active_primary_experts"))) is not None:
+        if (n_experts_used := self.hparams.get("moe_num_active_primary_experts")) is not None:
             self.gguf_writer.add_expert_used_count(n_experts_used)
         if (moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size")) is not None:
             self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
@@ -10840,7 +10907,7 @@ class SmallThinkerModel(TextModel):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # process the experts separately
         if name.find("experts") != -1:
-            n_experts = self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))
+            n_experts = self.hparams.get("moe_num_primary_experts") or self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None

             if self._experts is None:
@@ -242,10 +242,10 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
 |------------|-------------|------|-------|
 | FP32       | ✅          | ✅   | ❓    |
 | FP16       | ✅          | ✅   | ❓    |
-| BF16       | 🚫          | ✅   | ❓    |
+| BF16       | ✅          | ✅   | ❓    |
 | Q4_0       | ✅          | ❓   | ❓    |
 | Q4_1       | ✅          | ❓   | ❓    |
-| MXFP4      | 🚫          | ❓   | ❓    |
+| MXFP4      | ✅          | ❓   | ❓    |
 | Q5_0       | ✅          | ❓   | ❓    |
 | Q5_1       | ✅          | ❓   | ❓    |
 | Q8_0       | ✅          | ❓   | ❓    |

@@ -272,4 +272,4 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
 - 🚫 - acceleration unavailable, will still run using scalar implementation
 - ❓ - acceleration unknown, please contribute if you can test it yourself

-Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Sep 7, 2025.
+Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Feb 15, 2026.
examples/llama-eval/AGENTS.md (new file, 190 lines)
@@ -0,0 +1,190 @@
# llama-eval Codebase Guidelines

## Overview

This directory contains Python evaluation tools for llama.cpp:
- `llama-eval.py` - Main evaluation tool with multiple datasets (AIME, AIME2025, GSM8K, GPQA)
- `llama-server-simulator.py` - Flask-based server simulator for testing
- `test-simulator.sh` - Test script for the simulator

## Build/Run Commands

### Virtual Environment
The project uses a virtual environment located at `venv/`:
```bash
source venv/bin/activate
```

### Running the Main Evaluator
```bash
python llama-eval.py \
    --server http://127.0.0.1:8013 \
    --model gpt-oss-20b-hf-low \
    --dataset aime \
    --n_cases 10 \
    --grader-type llm \
    --seed 42
```

### Running the Simulator (for testing)
```bash
python llama-server-simulator.py --port 8033 --success-rate 0.8
```

### Running Tests
```bash
./test-simulator.sh
```

## Code Style Guidelines

### Imports
- Standard library imports first (argparse, json, os, re, subprocess, sys, time)
- Third-party imports (requests, tqdm, datasets, flask) after the standard library
- Relative imports are not used
- Group imports by category, with a blank line between groups

### Formatting
- 4-space indentation
- Max line length: 125 characters (per the parent project's .flake8)
- Use double quotes for strings
- Use triple double quotes for docstrings
- Binary operators at the beginning of continued lines

### Naming Conventions
- Classes: PascalCase (e.g., `AimeDataset`, `Grader`, `Processor`)
- Functions: snake_case (e.g., `normalize_number`, `get_prompt`)
- Variables: snake_case (e.g., `question_text`, `correct_count`)
- Constants: UPPER_SNAKE_CASE (e.g., `GRADER_PATTERNS`, `TEMPLATE_REGISTRY`)
- Private methods: prefix with underscore (e.g., `_load_dataset`, `_grade_regex`)

### Types
- Use type hints for all function signatures
- Import from the `typing` module: `Dict`, `List`, `Optional`, `Any`, `Tuple`
- Use `@dataclass` for data structures
- Prefer `Optional[T]` over `Union[T, None]`

### Error Handling
- Use try/except for network requests and file operations
- Return `None` or `False` on errors when appropriate
- Use `ValueError` for invalid arguments
- Use `FileNotFoundError` for missing files
- CLI scripts should handle exceptions gracefully

### Dataclasses
- Use `@dataclass` for structured data
- Define fields with explicit types
- Use `Optional[T]` for nullable fields
- Provide default values where appropriate

### String Formatting
- Use f-strings for formatting (Python 3.6+)
- Use triple double quotes for multi-line strings
- Escape backslashes in regex patterns: `r'\\boxed{(\d+)}'`

### File Paths
- Use `pathlib.Path` instead of string paths
- Create directories with `mkdir(parents=True, exist_ok=True)`
- Use `Path.home()` for the user home directory

### Logging
- Use `print()` for user-facing output
- Use `sys.stderr` for debug logging
- The simulator writes debug logs to `/tmp/simulator-debug.log`

### Testing

- The test script uses bash with `set -e` for strict error handling
- The simulator runs in the background with PID tracking
- Tests verify correct answers, error cases, and edge cases
- Use `curl` for HTTP testing in shell scripts

### Whitespace Cleanup
- Remove trailing whitespace from all lines
- When making edits, do not leave trailing whitespace

## Dataset Support

### AIME Dataset
- 90 questions from past AIME competitions
- Answers in `\boxed{answer}` format
- Supports regex, CLI, and LLM grading

### AIME2025 Dataset
- 30 questions from 2025 AIME I & II
- Answers in `\boxed{answer}` format
- Requires loading two config parts

### GSM8K Dataset
- 7473 math word problems
- Answers are numeric values after the `####` separator
- Supports regex, CLI, and LLM grading

### GPQA Dataset
- 198 questions from GPQA Diamond
- Multiple choice with shuffled options (A, B, C, D)
- **Requires LLM grader** (returns letter A/B/C/D)

## Grading Types

### Regex Grader
- Built-in patterns per dataset
- Prioritizes `\boxed{}` for AIME datasets
- Extracts the last number for GSM8K

### CLI Grader
- External script interface
- Call: `grader.sh --answer <pred> --expected <gold>`
- Exit code 0 = correct, non-zero = incorrect

### LLM Grader
- Uses a judge model for answer extraction
- Includes few-shot examples
- Case-insensitive comparison
- Required for GPQA

## Configuration

### Sampling Parameters (Optional)
- `--temperature`: Sampling temperature
- `--top-k`: Top K sampling
- `--top-p`: Top P sampling
- `--min-p`: Min P sampling
- Only passed to the API if explicitly specified

### Default Values
- `--n_predict`: -1 (infinite)
- `--grader-type`: llm
- `--seed`: 1234
- `--threads`: 32
- `--output`: llama-eval-state.json

## Output Format

### Progress Table
- Shows task ID, dataset, prompt (truncated to 43 chars), expected answer, status
- Uses `tqdm` for progress bars

### Results Summary
- Format: `Results: X/Y correct (Z%)`
- Displayed after all tasks complete

### JSON Output
- Complete eval state saved to the output file
- Contains: task IDs, correctness, prompts, extracted answers, sampling config
- Uses `dataclasses.asdict()` for serialization

## HuggingFace Datasets

- Cache directory: `~/.cache/huggingface/datasets`
- Set via the `HF_DATASETS_CACHE` environment variable
- Telemetry disabled via `HF_HUB_DISABLE_TELEMETRY=1`
- Datasets loaded with `datasets.load_dataset()`

## Flask Simulator

- Runs on a configurable port (default: 8033)
- Endpoint: `/v1/chat/completions` (OpenAI-compatible)
- Uses the Dice coefficient for question matching (see the sketch after this list)
- Configurable success rate for testing
- Debug logs to `/tmp/simulator-debug.log`
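A standalone sketch of the bigram Dice matching the simulator uses (reimplemented here for illustration; the 0.3 acceptance threshold matches the simulator code):

```python
from collections import Counter

def dice(s1: str, s2: str) -> float:
    # bigram Dice coefficient: 2*|intersection| / (|bigrams1| + |bigrams2|)
    b1 = Counter(s1[i:i + 2] for i in range(len(s1) - 1))
    b2 = Counter(s2[i:i + 2] for i in range(len(s2) - 1))
    if not b1 and not b2:
        return 1.0
    inter = sum(min(b1[bg], b2[bg]) for bg in b1)
    return 2 * inter / (sum(b1.values()) + sum(b2.values()))

assert dice("night", "night") == 1.0
print(dice("find p(0) + q(0)", "find p(0)+q(0)"))  # high similarity, well above the 0.3 threshold
```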
examples/llama-eval/IMPLEMENTATION.md (new file, 94 lines)
@@ -0,0 +1,94 @@
# llama-eval Implementation Summary

## Overview

Simple evaluation tool for llama.cpp with support for multiple datasets (AIME, GSM8K, GPQA) and flexible grading (regex, CLI, LLM).

## Key Features

- **Multiple Datasets**: AIME, GSM8K, GPQA with proper answer extraction
- **Flexible Grading**: Regex, CLI, or LLM-based grading
- **Parallel Processing**: Configurable thread count for concurrent requests
- **Sampling Parameters**: Temperature, Top K, Top P, Min P (optional)
- **Real-time Feedback**: Progress tracking with detailed output
- **JSON Output**: Complete eval state saved for debugging
- **GPQA Support**: Answer shuffling with reproducible results

## Architecture

### Eval State
```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class EvalState:
    id: str
    tasks: List[str]
    task_states: Dict[str, Dict[str, Any]]
    sampling_config: Dict[str, Any]
```
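For illustration, a state object for a single graded task might be built like this (the field values are hypothetical, not taken from a real run):

```python
state = EvalState(
    id="eval-run-001",
    tasks=["aime_000_001"],
    task_states={"aime_000_001": {"correct": True, "expected": "116", "predicted": "116"}},
    sampling_config={"temperature": 0.0, "seed": 42},
)
```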
### Processor
- Handles processing, grading, and state management
- Thread-safe concurrent execution
- Configurable sampling parameters

### Grader
- Abstract grading interface supporting multiple types
- Regex grader with dataset-specific patterns
- CLI grader with an external script interface
- LLM grader with configurable server and model

### Datasets
- `AimeDataset`: 90 AIME questions
- `Aime2025Dataset`: 30 AIME 2025 I & II questions
- `Gsm8kDataset`: 7473 math word problems
- `GpqaDataset`: 198 GPQA Diamond questions with shuffling

## Configuration

### Sampling Parameters (Optional)
- `--temperature`: Sampling temperature
- `--top-k`: Top K sampling
- `--top-p`: Top P sampling
- `--min-p`: Min P sampling
- Only passed if explicitly specified

### Grading Types
- **regex**: Built-in patterns for each dataset
- **cli**: External script with `--answer` and `--expected` args
- **llm**: LLM-based extraction with few-shot examples and a configurable server/model

### Dataset Requirements
- **AIME**: Supports regex, CLI, or LLM grader
- **AIME2025**: Supports regex, CLI, or LLM grader
- **GSM8K**: Supports regex, CLI, or LLM grader
- **GPQA**: Requires the LLM grader

## Output Format

### Progress Table
```
Task ID        Dataset    Prompt (first 43 chars)                       Expected    Status
aime_000_001   AIME       Complete the following reactions and sel...  A           pending
```

### Results Summary
```
============================================================
Results: 8/10 correct (80.0%)
============================================================
```

### JSON Output
Complete eval state with task IDs, correctness, prompts, extracted answers, and sampling configuration.

## Technical Details

- Default max tokens: -1 (infinite)
- Default grader type: llm
- Default seed: 1234
- Default threads: 32
- Prompt truncation: first 43 chars + padding + "..."
- Response truncation: last 10 lines for grading
- GPQA requires the LLM grader (returns letter A/B/C/D)
- The judge model defaults to the evaluated model if not specified
- Sample answers are defined in the SAMPLE_ANSWERS dict for few-shot learning
examples/llama-eval/README.md (new file, 112 lines)
@@ -0,0 +1,112 @@
# llama-eval Evaluation Tool

Simple evaluation tool for llama.cpp with support for multiple datasets.

## Features

- **Multiple Datasets**: AIME, GSM8K, GPQA
- **Flexible Grading**: Regex, CLI, or LLM-based grading
- **Parallel Processing**: Configurable thread count
- **Real-time Feedback**: Progress tracking with detailed output
- **Sampling Parameters**: Temperature, Top K, Top P, Min P
- **JSON Output**: Complete eval state saved for debugging

## Usage

```bash
python llama-eval.py \
    --server http://127.0.0.1:8013 \
    --model gpt-oss-20b-hf-low \
    --judge-model gpt-oss-20b-hf-medium \
    --dataset aime \
    --n_cases 10 \
    --grader-type llm \
    --seed 42
```

## CLI Arguments

- `--server`: llama-server URL (default: http://127.0.0.1:8013)
- `--model`: Model name for evaluation (default: llama)
- `--judge-model`: Model name for the LLM judge (default: same as the main model)
- `--judge-server`: Server URL for the LLM judge (default: same as the main server)
- `--dataset`: Dataset type (aime, aime2025, gsm8k, gpqa)
- `--n_cases`: Number of cases to evaluate (default: all)
- `--n_predict`: Max tokens to predict per prompt (default: -1, infinite)
- `--temperature`: Sampling temperature (default: not passed)
- `--top-k`: Top K sampling (default: not passed)
- `--top-p`: Top P sampling (default: not passed)
- `--min-p`: Min P sampling (default: not passed)
- `--threads`: Number of threads for parallel requests (default: 32)
- `--verbose`: Show detailed output for each case
- `--output`: Output file for the eval state (default: llama-eval-state.json)
- `--grader-type`: Grader type (regex, cli, llm; default: llm)
- `--grader-script`: Path to the CLI grader script (required for --grader-type cli)
- `--seed`: Random seed for shuffling (default: 1234)

## Datasets

### AIME
- 90 questions from past AIME competitions
- Answers in boxed format: `\boxed{answer}`
- Requires a regex grader or LLM grader

### AIME2025
- 30 questions from the 2025 AIME I & II competitions
- Answers in boxed format: `\boxed{answer}`
- Supports regex, CLI, or LLM grader

### GSM8K
- 7473 math word problems
- Answers are numeric values
- Requires a regex grader or LLM grader

### GPQA
- 198 questions from the GPQA Diamond dataset
- Multiple choice with shuffled options
- Requires the LLM grader (returns letter A, B, C, or D)

## Grading Types

### Regex Grader
Built-in patterns for different datasets (a sketch of the extraction follows this list):
- AIME: `\boxed{(\d+)}|\b(\d+)\b`
- AIME2025: `\boxed{(\d+)}|\b(\d+)\b`
- GSM8K: `\b(\d+)\b`
- GPQA: Letter extraction (A, B, C, D)
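A hedged sketch of what the boxed-answer extraction implies, written from the patterns listed above (the tool's actual function names may differ):

```python
import re

def extract_boxed(text: str) -> str | None:
    # Prefer \boxed{...}; fall back to the last bare integer, as described above.
    m = re.search(r"\\boxed\{(\d+)\}", text)
    if m:
        return m.group(1)
    numbers = re.findall(r"\b\d+\b", text)
    return numbers[-1] if numbers else None

print(extract_boxed(r"The answer is \boxed{116}."))  # -> 116
```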
### CLI Grader
External script interface:
```bash
./grader.sh --answer <pred> --expected <gold>
```
Returns exit code 0 if correct, non-zero if incorrect. A minimal caller sketch follows.
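This sketch shows invoking such a grader from Python; the `grader.sh` path stands in for whatever script you supply via `--grader-script`:

```python
import subprocess

def grade_cli(grader_script: str, predicted: str, expected: str) -> bool:
    # Exit code 0 = correct, non-zero = incorrect (interface described above)
    result = subprocess.run([grader_script, "--answer", predicted, "--expected", expected])
    return result.returncode == 0
```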
### LLM Grader
Uses an LLM to extract and compare answers:
- Configurable server and model
- Includes few-shot examples from sample answers
- Case-insensitive comparison
- Required for the GPQA dataset
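A simplified sketch of an LLM-judge call against the OpenAI-compatible endpoint; the prompt wording is illustrative and is not the tool's exact few-shot prompt:

```python
import requests

def llm_grade(server: str, model: str, response_text: str, expected: str) -> bool:
    prompt = (
        "Extract the final answer from the response below and reply with the answer only.\n\n"
        f"Response:\n{response_text}"
    )
    r = requests.post(
        f"{server}/v1/chat/completions",
        json={"model": model, "messages": [{"role": "user", "content": prompt}]},
        timeout=120,
    )
    extracted = r.json()["choices"][0]["message"]["content"].strip()
    return extracted.lower() == expected.lower()  # case-insensitive comparison
```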
## Output

### Progress Table
```
Task ID        Dataset    Prompt (first 43 chars)                       Expected    Status
aime_000_001   AIME       Complete the following reactions and sel...  A           pending
```

### Results
```
============================================================
Results: 8/10 correct (80.0%)
============================================================
```

### JSON Output
Complete eval state saved to the output file with:
- Task IDs and correctness status
- Prompts and extracted answers
- Sampling configuration
- Processing metadata
examples/llama-eval/llama-eval.py (new executable file, 1229 lines)
File diff suppressed because it is too large.
examples/llama-eval/llama-server-simulator-README.md (new file, 36 lines)
@@ -0,0 +1,36 @@
# llama-server-simulator

Standalone Python script simulating the llama-server HTTP endpoint for testing.

## Features

- HTTP server with an OpenAI-compatible `/v1/chat/completions` endpoint
- AIME dataset integration - loads 90 questions from HuggingFace
- Intelligent question matching - uses exact matching, LaTeX removal, and bigram Dice similarity
- Configurable success rate - controls correct/wrong answer generation (0-1)
- Debug logging - helps troubleshoot matching issues

## Usage

```bash
python llama-server-simulator.py --success-rate 0.8
```

## Arguments

- `--success-rate`: Probability of returning the correct answer (0.0-1.0, default: 0.8)
- `--port`: Server port (default: 8033)
- `--host`: Server host (default: localhost)
- `--dataset-split`: AIME dataset split to use (default: train)

## Testing

```bash
./test-simulator.sh
```

## Implementation Details

- Uses Dice similarity for partial matching (threshold: 0.3)
- Automatic caching via the HuggingFace datasets library
- Wrong answers are generated by incrementing the expected answer
- Debug output is written to stderr and `/tmp/simulator-debug.log`
examples/llama-eval/llama-server-simulator.py (new executable file, 283 lines)
@@ -0,0 +1,283 @@
#!/usr/bin/env python3

import argparse
import json
import random
import re
import time
import sys
import os
from typing import Dict, List, Optional
from dataclasses import dataclass, asdict
from pathlib import Path

import datasets
from flask import Flask, request, jsonify

# Set cache directory for HuggingFace datasets
cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"
cache_dir.mkdir(parents=True, exist_ok=True)
os.environ["HF_DATASETS_CACHE"] = str(cache_dir)


def dice(s1: str, s2: str) -> float:
    """Calculate Dice coefficient between two strings based on bigram overlap."""
    if not s1 and not s2:
        return 1.0

    def _bigrams(s: str):
        return [s[i : i + 2] for i in range(len(s) - 1)]

    bigrams1 = _bigrams(s1)
    bigrams2 = _bigrams(s2)

    if not bigrams1 and not bigrams2:
        return 1.0

    from collections import Counter

    freq1 = Counter(bigrams1)
    freq2 = Counter(bigrams2)

    intersection = sum(min(freq1[bg], freq2[bg]) for bg in freq1)
    dice_coeff = 2 * intersection / (len(bigrams1) + len(bigrams2))
    return dice_coeff


def debug_log(message: str):
    """Log debug messages to both stderr and a file"""
    print(message, file=sys.stderr)
    with open("/tmp/simulator-debug.log", "a") as f:
        f.write(message + "\n")


app = Flask(__name__)


@dataclass
class EvalState:
    id: str
    tasks: List[str]
    task_states: Dict[str, Dict]
    sampling_config: Dict


def normalize_number(s: str) -> Optional[int]:
    match = re.match(r"\d+", s)  # match digits from the start
    if not match:
        return None
    return int(match.group(0))


class AimeDataset:
    def __init__(self, split: str = "train"):
        self.split = split
        self.questions: List[Dict] = []
        self._load_dataset()

    def _load_dataset(self):
        print(f"Loading AIME dataset (split: {self.split})...")

        cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "AI-MO___aimo-validation-aime" / "default" / "0.0.0"
        if cache_path.exists():
            print(f"Using cached dataset from {cache_path}")
            ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path))
        else:
            ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split)

        self.questions = list(ds)
        print(f"AIME dataset loaded: {len(self.questions)} questions")

    def find_question(self, request_text: str) -> Optional[Dict]:
        best_match = None
        best_distance = -1
        best_index = -1

        for i, question in enumerate(self.questions):
            question_text = question["problem"]
            request_lower = request_text.lower()
            question_lower = question_text.lower()

            # Exact match
            if question_lower == request_lower:
                debug_log(f"DEBUG: Found exact match at index {i}")
                return question

            # Remove LaTeX formatting for more flexible matching
            question_no_latex = re.sub(r'\$[^$]+\$', '', question_text)
            if question_no_latex.lower() == request_lower:
                debug_log(f"DEBUG: Found match (no LaTeX) at index {i}")
                return question

            # Calculate Dice similarity for partial matches
            # Only consider if the request is at least 50% of the question length
            if len(request_lower) >= len(question_lower) * 0.5:
                distance = dice(question_lower, request_lower)

                if distance > best_distance:
                    best_distance = distance
                    best_match = question
                    best_index = i

        if best_match and best_distance > 0.3:  # Threshold for partial match
            debug_log(f"DEBUG: Found best partial match at index {best_index} with distance {best_distance:.3f}")
            return best_match

        debug_log(f"DEBUG: No matching question found for: {request_text[:100]}...")
        return None

    def get_answer(self, question: Dict) -> str:
        answer = question["answer"]
        if isinstance(answer, str):
            normalized = normalize_number(answer)
            return str(normalized) if normalized is not None else answer
        return str(answer)


class Simulator:
    def __init__(
        self,
        port: int = 8033,
        host: str = "localhost",
        success_rate: float = 0.8,
        dataset_split: str = "train"
    ):
        self.port = port
        self.host = host
        self.success_rate = success_rate
        self.dataset = AimeDataset(dataset_split)
        self.eval_state = EvalState(
            id="aime-2025",
            tasks=["aime"],
            task_states={},
            sampling_config={"temperature": 0, "max_tokens": 2048}
        )

    def _generate_response(
        self,
        question: Dict,
        should_be_correct: bool
    ) -> Dict:
        expected_answer = self.dataset.get_answer(question)

        if should_be_correct:
            response_text = expected_answer
        else:
            response_text = self._generate_wrong_answer(question)

        return {
            "id": f"chatcmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "llama",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response_text
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150
            }
        }

    def _generate_wrong_answer(self, question: Dict) -> str:
        expected_answer = self.dataset.get_answer(question)

        if expected_answer.isdigit():
            wrong_answer = str(int(expected_answer) + 1)
        else:
            wrong_answer = expected_answer + " (wrong)"

        return wrong_answer

    def _process_request(self, request_data: Dict) -> Dict:
        messages = request_data.get("messages", [])
        if not messages:
            return {"error": "No messages in request"}

        request_text = messages[0].get("content", "")
        debug_log(f"DEBUG: Received request with content: {request_text[:150]}...")

        question = self.dataset.find_question(request_text)
        if not question:
            debug_log("DEBUG: find_question returned None")
            return {"error": "No matching question found"}

        should_be_correct = random.random() < self.success_rate

        response = self._generate_response(question, should_be_correct)

        task_id = "aime"
        self.eval_state.task_states[task_id] = {
            "correct": should_be_correct,
            "expected": self.dataset.get_answer(question),
            "predicted": response["choices"][0]["message"]["content"]
        }

        return response


@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
    try:
        request_data = request.get_json()

        if not request_data:
            return jsonify({"error": "Invalid JSON"}), 400

        response = simulator._process_request(request_data)

        return jsonify(response)

    except Exception as e:
        print(f"Error processing request: {e}")
        return jsonify({"error": str(e)}), 500


def main():
    parser = argparse.ArgumentParser(
        description="llama-server simulator for testing eval scripts"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8033,
        help="Server port (default: 8033)"
    )
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help="Server host (default: localhost)"
    )
    parser.add_argument(
        "--success-rate",
        type=float,
        default=0.8,
        help="Success rate 0-1 (default: 0.8)"
    )
    parser.add_argument(
        "--dataset-split",
        type=str,
        default="train",
        help="AIME dataset split to use (default: train)"
    )

    args = parser.parse_args()

    global simulator
    simulator = Simulator(
        port=args.port,
        host=args.host,
        success_rate=args.success_rate,
        dataset_split=args.dataset_split
    )

    print("\n=== llama-server-simulator ===")
    print(f"Server running on http://{args.host}:{args.port}")
    print(f"Success rate: {args.success_rate}")
    print(f"AIME dataset loaded: {len(simulator.dataset.questions)} questions")
    print("\nPress Ctrl+C to stop\n")

    app.run(host=args.host, port=args.port, debug=False)


if __name__ == "__main__":
    main()
examples/llama-eval/test-simulator.sh (new executable file, 86 lines)
@@ -0,0 +1,86 @@
#!/bin/bash

set -e

# Get the directory where this script is located
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

echo "=== llama-server-simulator Test Script ==="
echo ""

PORT=8033
SUCCESS_RATE=0.8

echo "Starting simulator on port $PORT with success rate $SUCCESS_RATE..."
source "$SCRIPT_DIR/venv/bin/activate"
python3 "$SCRIPT_DIR/llama-server-simulator.py" --port $PORT --success-rate $SUCCESS_RATE > /tmp/simulator-test.log 2>&1 &
SIMULATOR_PID=$!

echo "Waiting for simulator to start..."
sleep 5

# Helper function to make a request and extract the answer
make_request() {
    local question="$1"
    curl -s -X POST http://localhost:$PORT/v1/chat/completions \
        -H "Content-Type: application/json" \
        -d "{
            \"model\": \"llama\",
            \"messages\": [
                {\"role\": \"user\", \"content\": \"$question\"}
            ],
            \"temperature\": 0,
            \"max_tokens\": 2048
        }" | python3 -c "import sys, json; data = json.load(sys.stdin); print(data.get('choices', [{}])[0].get('message', {}).get('content', data.get('error', 'No response')))"
}

# Test question (repeated in multiple tests)
TEST_QUESTION="Quadratic polynomials P(x) and Q(x) have leading coefficients 2 and -2, respectively. The graphs of both polynomials pass through the two points (16,54) and (20,53). Find P(0) + Q(0)."

echo ""
echo "=== Test 1: Known Question ==="
echo "Sending request with known question..."
answer=$(make_request "$TEST_QUESTION")
echo "Answer: $answer"
echo "Expected: 116"
echo "Correct: $([ "$answer" == "116" ] && echo "Yes" || echo "No")"

echo ""
echo "=== Test 2: Repeat Request (wrong answers possible at success rate $SUCCESS_RATE) ==="
echo "Sending the same request again..."
answer=$(make_request "$TEST_QUESTION")
echo "Answer: $answer"
echo "Expected: 116"
echo "Correct: $([ "$answer" == "116" ] && echo "Yes" || echo "No")"

echo ""
echo "=== Test 3: No Matching Question ==="
echo "Sending request with non-matching text..."
response=$(make_request "What is the capital of France?")
echo "Response: $response"
echo "Expected: No matching question found"
echo "Correct: $([ "$response" == "No matching question found" ] && echo "Yes" || echo "No")"

echo ""
echo "=== Test 4: Success Rate Verification ==="
echo "Sending 10 requests to test success rate..."
correct_count=0
for i in {1..10}; do
    answer=$(make_request "$TEST_QUESTION")
    if [ "$answer" == "116" ]; then
        correct_count=$((correct_count + 1))
    fi
    echo "  Request $i: Answer = $answer"
done
echo "Correct answers: $correct_count/10"
echo "Expected: ~8/10 (80% success rate)"
echo "Success rate: $(echo "scale=1; $correct_count * 10" | bc)%"

echo ""
echo "=== Test Complete ==="
echo "Stopping simulator..."
kill $SIMULATOR_PID 2>/dev/null
wait $SIMULATOR_PID 2>/dev/null || true

echo "Simulator stopped."
@@ -569,27 +569,24 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
            cmake_policy(SET CMP0135 NEW)
        endif()

        # TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+
        # Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28
        FetchContent_Declare(KleidiAI_Download
            URL ${KLEIDIAI_DOWNLOAD_URL}
            DOWNLOAD_EXTRACT_TIMESTAMP NEW
            URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})

        FetchContent_MakeAvailable(KleidiAI_Download)
        FetchContent_GetProperties(KleidiAI_Download
            SOURCE_DIR  KLEIDIAI_SRC
            POPULATED   KLEIDIAI_POPULATED)

        if (NOT KLEIDIAI_POPULATED)
            message(FATAL_ERROR "KleidiAI source download failed.")
            FetchContent_Populate(KleidiAI_Download)
            FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC)
        endif()

        add_compile_definitions(GGML_USE_CPU_KLEIDIAI)

        # Remove kleidiai target after fetching it
        if (TARGET kleidiai)
            set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE)
        endif()

        list(APPEND GGML_CPU_SOURCES
            ggml-cpu/kleidiai/kleidiai.cpp
            ggml-cpu/kleidiai/kernels.cpp
@@ -6,8 +6,8 @@
#include "ggml-impl.h"
#include "simd-mappings.h"

#define GGML_FA_TILE_Q  32
#define GGML_FA_TILE_KV 16
#define GGML_FA_TILE_Q  64
#define GGML_FA_TILE_KV 64

#ifdef __cplusplus

@@ -2874,8 +2874,8 @@ struct ggml_cplan ggml_graph_plan(
                const int64_t DV = node->src[2]->ne[0];

                // Tiled flash attention scratch (tile sizes defined in common.h)
                // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
                size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
                // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding
                size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV + GGML_FA_TILE_KV*DK)*n_tasks;

                // Decode path: n_kv_chunks = n_tasks (one chunk per thread)
                // Per-thread: VKQ accumulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
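To make the size change concrete, here is a small standalone sketch (mine, not from the patch) that evaluates the per-thread prefill scratch for assumed head sizes DK = DV = 128 with the new 64x64 tiles:

#include <cstdio>

int main() {
    constexpr long Q  = 64;  // GGML_FA_TILE_Q
    constexpr long KV = 64;  // GGML_FA_TILE_KV
    constexpr long DK = 128; // assumed K head size, not from the patch
    constexpr long DV = 128; // assumed V head size, not from the patch
    // Q_q + KQ + mask + VKQ32 + V32 + K_f32, in floats (cache-line padding excluded)
    constexpr long floats = Q*DK + 2*Q*KV + Q*DV + KV*DV + KV*DK;
    std::printf("%ld floats = %ld KiB per thread\n", floats, floats * 4 / 1024);
    // prints: 40960 floats = 160 KiB per thread
    return 0;
}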
@@ -2947,7 +2947,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
        /*.use_ref =*/ cplan->use_ref,
    };

    GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
#ifdef GGML_USE_OPENMP
    GGML_PRINT_DEBUG("thread #%d compute-start cplan %p\n", state->ith, (const void *)cplan);
#else
    GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph);
#endif

    for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
        struct ggml_tensor * node = cgraph->nodes[node_n];
@@ -2974,7 +2978,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
        }
    }

    GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
#ifdef GGML_USE_OPENMP
    GGML_PRINT_DEBUG("thread #%d compute-done cplan %p\n", state->ith, (const void *)cplan);
#else
    GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph);
#endif

    ggml_barrier(state->threadpool);

@@ -3,6 +3,7 @@
#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "binary-ops.h"
#include "simd-gemm.h"
#include "ggml.h"
#include "unary-ops.h"
#include "vec.h"
@@ -8389,10 +8390,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
    GGML_ASSERT(k->type == v->type);
    const ggml_type kv_type = k->type;

    const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
    const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
    const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot;
    const size_t kv_type_size = ggml_type_size(kv_type);

    // broadcast factors
    const int64_t rk2 = neq2/nek2;
@@ -8424,8 +8421,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
    static constexpr int Q_TILE_SZ  = ggml_fa_tile_config::Q;
    static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;

    GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");

    int ir = ir0;
    while (ir < ir1) {
        // q indices for the start of this tile
@@ -8452,18 +8447,20 @@
        }

        // Per-thread scratch layout:
        //   Q_q:    Q_TILE_SZ * DK          (converted Q tile in KV type)
        //   Q_q:    Q_TILE_SZ * DK          (converted Q tile — F32 for GEMM, KV type for scalar)
        //   KQ:     Q_TILE_SZ * KV_TILE_SZ  (attention scores in float)
        //   mask:   Q_TILE_SZ * KV_TILE_SZ  (mask in float)
        //   VKQ32:  Q_TILE_SZ * DV          (FP32 output accumulator)
        //   V32:    KV_TILE_SZ * DV         (F32 buffer for V tile - used for f16 conversion)
        float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
        //   V32:    KV_TILE_SZ * DV         (F32 buffer for V tile)
        //   K_f32:  KV_TILE_SZ * DK         (F32 buffer for K tile — GEMM path)
        float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);

        void  * Q_q    = base;
        float * KQ     = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
        float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
        float * VKQ32  = mask32 + Q_TILE_SZ * KV_TILE_SZ;
        float * V32    = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile
        float * V32    = VKQ32 + Q_TILE_SZ * DV;
        float * K_f32  = V32 + KV_TILE_SZ * DV;

        memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
        memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
@@ -8476,28 +8473,38 @@
        const int iv3 = iq3 / rv3;
        const int iv2 = iq2 / rv2;

        for (int tq = 0; tq < tile_rows; tq++) {
            const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
            kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
        }
        // Zero-pad remaining rows
        for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
            memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
        {
            float * Q_f32 = (float *)Q_q;
            for (int tq = 0; tq < tile_rows; tq++) {
                const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
                memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
            }
            for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
                memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
            }
        }

        memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
        memset(V32, 0, KV_TILE_SZ * DV * sizeof(float));

        for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
            const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);

            // skip the tile entirely if all the masks are -inf
            if (mask) {
                bool can_skip = true;
                for (int tq = 0; tq < tile_rows; tq++) {
                    const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
                    for (int tk = 0; tk < KV_TILE_SZ; tk++) {
                    for (int tk = 0; tk < kv_tile; tk++) {
                        mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
                        if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
                            can_skip = false;
                        }
                    }
                    // Pad remaining mask entries with -inf
                    for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
                        mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
                    }
                }

                if (can_skip) {
@@ -8505,13 +8512,32 @@
                }
            }

            for (int tq = 0; tq < Q_TILE_SZ; tq++) {
                const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
                for (int tk = 0; tk < KV_TILE_SZ; tk++) {
                    const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
                    float s;
                    kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
                    KQ[tq * KV_TILE_SZ + tk] = s * scale;
            // Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
            // Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
            for (int tk = 0; tk < kv_tile; tk++) {
                const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
                if (kv_type == GGML_TYPE_F16) {
                    const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
                    for (int64_t dk = 0; dk < DK; dk++) {
                        K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
                    }
                } else {
                    const float * k_f32_src = (const float *)k_data;
                    for (int64_t dk = 0; dk < DK; dk++) {
                        K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
                    }
                }
            }
            memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
            simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
            ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);

            // Set padded KQ entries to -inf so softmax gives them zero weight
            if (kv_tile < KV_TILE_SZ) {
                for (int tq = 0; tq < Q_TILE_SZ; tq++) {
                    for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
                        KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
                    }
                }
            }

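In GEMM terms, the packing above maps the score computation onto simd_gemm's C[M x N] += A[M x K] * B[K x N] contract with M = Q_TILE_SZ, K = DK, N = KV_TILE_SZ; since K is stored transposed as K_f32[dk][tk], the accumulated entry is (my notation, read off the code above):

$$\mathrm{KQ}[t_q][t_k] \;=\; \mathrm{scale}\cdot \sum_{d=0}^{D_K-1} Q[t_q][d]\; K[t_k][d]$$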
@@ -8551,33 +8577,22 @@
                S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
            }

            // Convert V tile to F32 first (if F16), then do MAD
            // On x86, ggml_vec_mad_f16 internally converts F16<->F32 on every load/store, so pre-converting is faster.
            // TODO: on ARM, native f16 should be faster
            if (kv_type == GGML_TYPE_F16) {
                for (int tk = 0; tk < KV_TILE_SZ; tk++) {
                    const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
                    ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
                }
                for (int tq = 0; tq < Q_TILE_SZ; tq++) {
                    if (skip[tq]) continue;
                    float * vkq_row = VKQ32 + tq * DV;
                    for (int tk = 0; tk < KV_TILE_SZ; tk++) {
                        const float p = KQ[tq * KV_TILE_SZ + tk];
                        ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
                    }
                }
            } else {
                for (int tq = 0; tq < Q_TILE_SZ; tq++) {
                    if (skip[tq]) continue;
                    float * vkq_row = VKQ32 + tq * DV;
                    for (int tk = 0; tk < KV_TILE_SZ; tk++) {
                        const float p = KQ[tq * KV_TILE_SZ + tk];
                        const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
                        ggml_vec_mad_f32(DV, vkq_row, v_row, p);
                    }
                }
            // V accumulation: VKQ32 += softmax(KQ) * V
            // Pack V tile to contiguous F32, zero-padded
            for (int tk = 0; tk < kv_tile; tk++) {
                const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
                if (kv_type == GGML_TYPE_F16) {
                    ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
                } else {
                    memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
                }
            }
            for (int tq = 0; tq < Q_TILE_SZ; tq++) {
                if (skip[tq]) {
                    memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
                }
            }
            simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
        }

        // sinks (apply only to valid rows in the tile)
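The second GEMM plays the same role for the value accumulation: rows flagged in skip[] have their KQ row zeroed first so they contribute nothing, and the accumulate-into-C contract of simd_gemm implements (my notation):

$$\mathrm{VKQ32}[t_q][d] \;\mathrel{+}=\; \sum_{t_k=0}^{KV-1} \mathrm{KQ}[t_q][t_k]\;\mathrm{V32}[t_k][d]$$

where at this point the KQ entries hold the softmax weights produced just above.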
@@ -8794,15 +8809,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(

    const int64_t dr = (nr + nchunk - 1) / nchunk;

    static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
    static constexpr int64_t Q_TILE_SZ  = ggml_fa_tile_config::Q;
    const bool use_tiled = !use_ref &&
    bool use_tiled = !use_ref &&
        (q->type == GGML_TYPE_F32 &&
         kv_is_f32_or_f16 &&
         k->type == v->type &&
         nek1 % KV_TILE_SZ == 0 &&
         neq1 >= Q_TILE_SZ);

#ifdef GGML_SIMD
    use_tiled &= (DV % GGML_F32_EPR == 0);
#endif
    int current_chunk = ith;

    while (current_chunk < nchunk) {

ggml/src/ggml-cpu/simd-gemm.h (new file, 136 lines)
@@ -0,0 +1,136 @@
#pragma once

// Computes C[M x N] += A[M x K] * B[K x N]

#include "simd-mappings.h"

// TODO: add support for sizeless vector types
#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic)

// TODO: untested on avx512
// These are in units of GGML_F32_EPR
#if defined(__AVX512F__) || defined (__ARM_NEON__)
static constexpr int GEMM_RM = 4;
static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32
#elif defined(__AVX2__) || defined(__AVX__)
static constexpr int GEMM_RM = 6;
static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16
#else
static constexpr int GEMM_RM = 2;
static constexpr int GEMM_RN = 2;
#endif

template <int RM, int RN>
static inline void simd_gemm_ukernel(
    float       * GGML_RESTRICT C,
    const float * GGML_RESTRICT A,
    const float * GGML_RESTRICT B,
    int K, int N)
{
    static constexpr int KN = GGML_F32_EPR;

    GGML_F32_VEC acc[RM][RN];
    for (int64_t i = 0; i < RM; i++) {
        for (int r = 0; r < RN; r++) {
            acc[i][r] = GGML_F32_VEC_LOAD(C + i * N + r * KN);
        }
    }

    for (int64_t kk = 0; kk < K; kk++) {
        GGML_F32_VEC Bv[RN];
        for (int r = 0; r < RN; r++) {
            Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + r * KN);
        }
        for (int64_t i = 0; i < RM; i++) {
            GGML_F32_VEC p = GGML_F32_VEC_SET1(A[i * K + kk]);
            for (int r = 0; r < RN; r++) {
                acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p);
            }
        }
    }

    for (int64_t i = 0; i < RM; i++) {
        for (int r = 0; r < RN; r++) {
            GGML_F32_VEC_STORE(C + i * N + r * KN, acc[i][r]);
        }
    }
}

// C[M x N] += A[M x K] * B[K x N]
static void simd_gemm(
    float       * GGML_RESTRICT C,
    const float * GGML_RESTRICT A,
    const float * GGML_RESTRICT B,
    int M, int K, int N)
{
    static constexpr int KN = GGML_F32_EPR;

    int64_t ii = 0;
    for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
        int64_t jj = 0;
        for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
            simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C + jj, A, B + jj, K, N);
        }
        for (; jj + KN <= N; jj += KN) {
            simd_gemm_ukernel<GEMM_RM, 1>(C + jj, A, B + jj, K, N);
        }
        for (; jj < N; jj++) {
            for (int64_t i = 0; i < GEMM_RM; i++) {
                float a = C[i * N + jj];
                for (int64_t kk = 0; kk < K; kk++) {
                    a += A[i * K + kk] * B[kk * N + jj];
                }
                C[i * N + jj] = a;
            }
        }

        A += GEMM_RM * K;
        C += GEMM_RM * N;
    }

    // Tail rows: one at a time
    for (; ii < M; ii++) {
        int64_t jj = 0;
        for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
            simd_gemm_ukernel<1, GEMM_RN>(C + jj, A, B + jj, K, N);
        }
        for (; jj + KN <= N; jj += KN) {
            simd_gemm_ukernel<1, 1>(C + jj, A, B + jj, K, N);
        }
        for (; jj < N; jj++) {
            float a = C[jj];
            for (int64_t kk = 0; kk < K; kk++) {
                a += A[kk] * B[kk * N + jj];
            }
            C[jj] = a;
        }

        A += K;
        C += N;
    }
}

#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif

#else // scalar path

static void simd_gemm(
    float       * GGML_RESTRICT C,
    const float * GGML_RESTRICT A,
    const float * GGML_RESTRICT B,
    int M, int K, int N)
{
    for (int64_t i = 0; i < M; i++) {
        for (int64_t j = 0; j < N; j++) {
            float sum = C[i * N + j];
            for (int64_t kk = 0; kk < K; kk++) {
                sum += A[i * K + kk] * B[kk * N + j];
            }
            C[i * N + j] = sum;
        }
    }
}

#endif // GGML_SIMD
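As a mental model for the contract, a minimal self-contained reference (gemm_ref and the toy sizes are illustrative, not part of the file) that mirrors the scalar path, including the accumulate-into-C behavior:

#include <cstdio>
#include <vector>

// Reference for C[M x N] += A[M x K] * B[K x N], row-major.
static void gemm_ref(float * C, const float * A, const float * B, int M, int K, int N) {
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            float sum = C[i*N + j]; // accumulates into C rather than overwriting it
            for (int k = 0; k < K; k++) {
                sum += A[i*K + k] * B[k*N + j];
            }
            C[i*N + j] = sum;
        }
    }
}

int main() {
    const int M = 3, K = 4, N = 5;
    std::vector<float> A(M*K, 1.0f), B(K*N, 2.0f), C(M*N, 0.5f);
    gemm_ref(C.data(), A.data(), B.data(), M, K, N);
    std::printf("C[0] = %.1f (expected 0.5 + 4*1*2 = 8.5)\n", C[0]);
    return 0;
}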
@@ -1160,6 +1160,14 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
    float32x4_t tmp = x[0] + vec_reve(x[0]);    \
    res = tmp[0] + tmp[1];                      \
}
#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3)    \
{                                                   \
    float32x4_t v = vec_add(vec_add(s0, s1),        \
                            vec_add(s2, s3));       \
    v = vec_add(v, vec_sld(v, v, 8));               \
    v = vec_add(v, vec_sld(v, v, 4));               \
    res += (ggml_float)vec_extract(v, 0);           \
}

#define GGML_F32_VEC      GGML_F32x4
#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
@@ -1209,6 +1217,24 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
#define GGML_F16_VEC_MUL    GGML_F32x4_MUL
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE

// BF16 s390x
#define GGML_BF16_STEP 16
#define GGML_BF16_EPR  8

#define GGML_BF16x8         __vector unsigned short
#define GGML_BF16x8_ZERO    vec_splats((unsigned short)0)
#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p))

#define GGML_BF16_VEC      GGML_BF16x8
#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO
#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD
#define GGML_BF16_TO_F32_LO(v) ((float32x4_t) vec_mergel((v), GGML_BF16_VEC_ZERO))
#define GGML_BF16_TO_F32_HI(v) ((float32x4_t) vec_mergeh((v), GGML_BF16_VEC_ZERO))
#define GGML_BF16_FMA_LO(acc, x, y) \
    (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y))
#define GGML_BF16_FMA_HI(acc, x, y) \
    (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y))
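The merge-with-zero widening works because a bf16 value is exactly the high 16 bits of the corresponding f32 bit pattern; a host-side sketch of the same conversion (independent of the s390x intrinsics, values illustrative):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    const uint16_t bf   = 0x3F80;             // bf16 bit pattern for 1.0f
    const uint32_t bits = (uint32_t)bf << 16; // pair with 16 zero bits, like vec_mergeh/vec_mergel above
    float f;
    std::memcpy(&f, &bits, sizeof(f));        // reinterpret the widened pattern as f32
    std::printf("%f\n", f);                   // prints 1.000000
    return 0;
}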

#elif defined(__riscv_v_intrinsic)

// compatible with vlen >= 128

@@ -236,8 +236,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
    vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
    sumf += __riscv_vfmv_f_s_f32m1_f32(redsum);

#endif
#if defined(__POWER9_VECTOR__)
#elif defined(__POWER9_VECTOR__) || defined(__VXE__) || defined(__VXE2__)
    const int np = (n & ~(GGML_BF16_STEP - 1));
    if (np > 0) {
        GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO};

@@ -2872,6 +2872,7 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
    const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
    const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
    const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
    const std::string delta_net_prefix = "dnet_add";

    for (int i = 0; i < cgraph->n_nodes; i++) {
        ggml_tensor * node = cgraph->nodes[i];
@@ -2902,7 +2903,8 @@
            strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
            strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
            strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
            strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
            strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0 &&
            strncmp(node->name, delta_net_prefix.c_str(), delta_net_prefix.size()) != 0) {
            // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
            // by means of matching node names. See
            // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
@@ -4544,6 +4546,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
        case GGML_UNARY_OP_CEIL:
        case GGML_UNARY_OP_ROUND:
        case GGML_UNARY_OP_TRUNC:
            // TODO: should become:
            //return ggml_is_contiguous_rows(op->src[0]);
            return ggml_is_contiguous(op->src[0]);
        default:
            return false;

@@ -2715,14 +2715,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa

#pragma unroll
    for (int l = 0; l < QR2_XXS; ++l) {
        const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
        const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
        const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]];
        const uint32_t signs = unpack_ksigns(aux32 >> (7 * l));

        const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
        const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
        const int signs0 = __vcmpne4(signs & 0x08040201, 0);
        const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);

        const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
        const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
        const int signs1 = __vcmpne4(signs & 0x80402010, 0);
        const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
@@ -2733,12 +2733,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }

    const int ls = aux32 >> 28;
    const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
    const float d = bxi->d;
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
    x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
#else
    x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
    x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }
}
@@ -2776,11 +2776,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa

#pragma unroll
    for (int l = 0; l < QR2_XS; ++l) {
        const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
        const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
        const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF];
        const uint32_t signs = unpack_ksigns(q2[l] >> 9);

        const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
        const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
        const int signs0 = __vcmpne4(signs & 0x08040201, 0);
        const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);

        const int signs1 = __vcmpne4(signs & 0x80402010, 0);
        const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
@@ -2904,11 +2907,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
#pragma unroll
    for (int l = 0; l < QR3_XXS; ++l) {
        const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
        const uint32_t signs = unpack_ksigns(aux32 >> (7*l));

        const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
        const int signs0 = __vcmpne4(signs & 0x08040201, 0);
        const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);

        const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
        const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
        const int signs1 = __vcmpne4(signs & 0x80402010, 0);
        const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;

@@ -94,6 +94,15 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
#endif
}

static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) {
    // v is a 7 bit int, with the 8th sign being encodable as popcnt
    // with xor we can "correct" the bit instead of having to mask
    const uint32_t p = __popc(v) & 1;
    const uint32_t s = v ^ p << 7;
    // broadcast over uint to allow for 0x08040201 / 0x80402010 as selectors
    return s * 0x01010101;
}

// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q

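A host-side sketch of the same trick (std::popcount standing in for the device-side __popc; values illustrative): the eighth sign is recovered as the parity of the seven stored bits, and the corrected byte is broadcast so that the masks 0x08040201 / 0x80402010 can each select one sign bit per byte lane in __vcmpne4:

#include <bit>
#include <cstdint>
#include <cstdio>

static uint32_t unpack_ksigns_host(uint8_t v) {
    const uint32_t p = std::popcount((unsigned)v) & 1; // parity of the 7 stored sign bits
    const uint32_t s = v ^ (p << 7);                   // xor sets bit 7 exactly when parity is odd
    return s * 0x01010101u;                            // replicate the byte into all four lanes
}

int main() {
    std::printf("%08x\n", unpack_ksigns_host(0x03)); // even parity -> 03030303
    std::printf("%08x\n", unpack_ksigns_host(0x01)); // odd parity  -> 81818181
    return 0;
}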
@@ -905,22 +914,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
    int sumi = 0;
#pragma unroll
    for (int k0 = 0; k0 < 8; k0 += 2) {
        const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]);
        const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F];
        const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[k0/2]];
        const uint32_t signs = unpack_ksigns(aux32 >> (7 * k0 / 2));

        const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
        const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
        const int signs0 = __vcmpne4(signs & 0x08040201, 0);
        const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
        const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0);
        sumi = ggml_cuda_dp4a(grid0, u0, sumi);

        const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
        const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
        const int signs1 = __vcmpne4(signs & 0x80402010, 0);
        const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
        const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1);
        sumi = ggml_cuda_dp4a(grid1, u1, sumi);
    }

    const int ls = aux32 >> 28;
    sumi = (ls*sumi + sumi/2)/4;
    const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
    sumi = sumi * ls / 8; // (sumi * scale + sumi / 2) / 4
    const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
    return d * sumi;
}
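The rewritten scale arithmetic is an exact rewrite rather than an approximation. With the 4-bit scale s stored in the top bits, aux32 >> 27 | 1 evaluates to 2s + 1, because the shift by 27 doubles s and the OR forces the low bit either way; the payoff is

$$\frac{s \cdot \mathrm{sumi} + \mathrm{sumi}/2}{4} \;=\; \frac{\mathrm{sumi}\,(2s + 1)}{8},$$

which holds in exact arithmetic and, by a short even/odd case check on sumi, under C integer truncation as well.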
@@ -942,13 +951,15 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
    int sumi1 = 0;
#pragma unroll
    for (int l0 = 0; l0 < 8; l0 += 2) {
        const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF));
        const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l0/2] >> 9));

        const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
        const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
        const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l0/2] & 0x1FF];
        const uint32_t signs = unpack_ksigns(q2[l0/2] >> 9);

        const int signs0 = __vcmpne4(signs & 0x08040201, 0);
        const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
        const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);

        const int signs1 = __vcmpne4(signs & 0x80402010, 0);
        const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
        const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);

        if (l0 < 4) {
@@ -1028,13 +1039,16 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
#pragma unroll
    for (int l0 = 0; l0 < 8; l0 += 2) {
        const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]);
        const uint32_t signs = unpack_ksigns(aux32 >> (7*l0/2));

        const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F));

        const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
        const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
        const int signs0 = __vcmpne4(signs & 0x08040201, 0);
        const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);

        const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);

        const int signs1 = __vcmpne4(signs & 0x80402010, 0);
        const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);

        const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);

        sumi = ggml_cuda_dp4a(grid_l, u0, sumi);

@@ -273,6 +273,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
        case GGML_OP_DIAG:
        case GGML_OP_MUL:
        case GGML_OP_ADD:
        case GGML_OP_SUB:
        case GGML_OP_DIV:
        case GGML_OP_GLU:
        case GGML_OP_SCALE:

@@ -3830,6 +3830,7 @@ class VisionProjectorType:
    MUSIC_FLAMINGO = "musicflamingo" # audio
    GLM4V = "glm4v"
    YOUTUVL = "youtuvl"
    NEMOTRON_V2_VL = "nemotron_v2_vl"


# Items here are (block size, type size)

@@ -1346,6 +1346,7 @@ class TensorNameMap:
            "model.vision_tower.embeddings.cls_token",                         # Intern-S1
            "vision_model.class_embedding",                                    # llama 4
            "model.vision.patch_embedding.cls_embedding",                      # cogvlm
            "vision_model.radio_model.model.patch_generator.cls_token.token",  # Nemotron Nano v2 VL
        ),

        MODEL_TENSOR.V_ENC_EMBD_PATCH: (
@@ -1360,6 +1361,7 @@ class TensorNameMap:
            "vision_tower.patch_embed.proj",                            # kimi-vl
            "model.vision.patch_embedding.proj",                        # cogvlm
            "siglip2.vision_model.embeddings.patch_embedding",
            "vision_model.radio_model.model.patch_generator.embedder",  # Nemotron Nano v2 VL
        ),

        MODEL_TENSOR.V_ENC_EMBD_NORM: (
@@ -1376,12 +1378,14 @@ class TensorNameMap:
            "visual.pos_embed",                                          # qwen3vl
            "model.vision.patch_embedding.position_embedding",           # cogvlm
            "visual.embeddings.position_embedding",                      # glm4v
            "vision_model.radio_model.model.patch_generator.pos_embed",  # Nemotron Nano v2 VL
        ),

        MODEL_TENSOR.V_ENC_ATTN_QKV: (
            "visual.blocks.{bid}.attn.qkv",                                      # qwen3vl
            "model.vision.transformer.layers.{bid}.attention.query_key_value",   # cogvlm
            "vision_tower.encoder.blocks.{bid}.wqkv"                             # Kimi-K2.5
            "vision_tower.encoder.blocks.{bid}.wqkv",                            # Kimi-K2.5
            "vision_model.radio_model.model.blocks.{bid}.attn.qkv",              # Nemotron Nano v2 VL
        ),

        MODEL_TENSOR.V_ENC_ATTN_Q: (
@@ -1446,6 +1450,7 @@ class TensorNameMap:
            "vision_tower.encoder.blocks.{bid}.norm0",                 # kimi-vl (norm0/norm1)
            "model.vision.transformer.layers.{bid}.input_layernorm",   # cogvlm
            "siglip2.vision_model.encoder.layers.{bid}.layer_norm1",
            "vision_model.radio_model.model.blocks.{bid}.norm1",       # Nemotron Nano v2 VL
        ),

        MODEL_TENSOR.V_ENC_ATTN_O: (
@@ -1462,6 +1467,7 @@ class TensorNameMap:
            "vision_tower.encoder.blocks.{bid}.wo",                          # kimi-vl
            "model.vision.transformer.layers.{bid}.attention.dense",         # cogvlm
            "siglip2.vision_model.encoder.layers.{bid}.self_attn.out_proj",  # youtuvl
            "vision_model.radio_model.model.blocks.{bid}.attn.proj",         # Nemotron Nano v2 VL
        ),

        MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
@@ -1477,6 +1483,7 @@ class TensorNameMap:
            "vision_tower.encoder.blocks.{bid}.norm1",                          # kimi-vl (norm0/norm1)
            "model.vision.transformer.layers.{bid}.post_attention_layernorm",   # cogvlm
            "siglip2.vision_model.encoder.layers.{bid}.layer_norm2",
            "vision_model.radio_model.model.blocks.{bid}.norm2",                # Nemotron Nano v2 VL
        ),

        MODEL_TENSOR.V_ENC_FFN_UP: (
@@ -1493,6 +1500,7 @@ class TensorNameMap:
            "vision_tower.encoder.blocks.{bid}.mlp.fc0",              # kimi-vl (fc0/fc1)
            "model.vision.transformer.layers.{bid}.mlp.fc1",          # cogvlm
            "siglip2.vision_model.encoder.layers.{bid}.mlp.fc1",
            "vision_model.radio_model.model.blocks.{bid}.mlp.fc1",    # Nemotron Nano v2 VL
        ),

        MODEL_TENSOR.V_ENC_FFN_GATE: (
@@ -1515,6 +1523,7 @@ class TensorNameMap:
            "vision_tower.encoder.blocks.{bid}.mlp.fc1",              # kimi-vl (fc0/fc1)
            "model.vision.transformer.layers.{bid}.mlp.fc2",          # cogvlm
            "siglip2.vision_model.encoder.layers.{bid}.mlp.fc2",
            "vision_model.radio_model.model.blocks.{bid}.mlp.fc2",    # Nemotron Nano v2 VL
        ),

        MODEL_TENSOR.V_LAYER_SCALE_1: (

@@ -5,7 +5,7 @@ import os
import sys
import subprocess

HTTPLIB_VERSION = "f80864ca031932351abef49b74097c67f14719c6"
HTTPLIB_VERSION = "d4180e923f846b44a3d30acd938438d6e64fc9f6"

vendor = {
    "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",

@@ -878,6 +878,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
        }
    } catch (const std::exception & err) {
        // fallback to full vocab list
        GGML_UNUSED(err);
    }

    return sampling.token_ids_full_vocab.data();
@@ -1809,7 +1810,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
//

uint32_t llama_context::output_reserve(int32_t n_outputs) {

    const auto & hparams = model.hparams;
    const auto & vocab   = model.vocab;

@@ -1893,11 +1893,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
    embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
    offset += embd.size * sizeof(float);

    sampling.logits     = {nullptr, 0};
    sampling.probs      = {nullptr, 0};
    sampling.sampled    = {nullptr, 0};
    sampling.candidates = {nullptr, 0};

    if (has_sampling) {
        sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
        offset += sampling.logits.size * sizeof(float);
@@ -1923,6 +1918,15 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
        std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);

        std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL);
    } else {
        sampling.logits     = {nullptr, 0};
        sampling.probs      = {nullptr, 0};
        sampling.sampled    = {nullptr, 0};
        sampling.candidates = {nullptr, 0};

        sampling.logits_count.clear();
        sampling.probs_count.clear();
        sampling.candidates_count.clear();
    }

    // set all ids as invalid (negative)
@@ -1953,37 +1957,30 @@ void llama_context::output_reorder() {
        }
    }

    if (sampling.logits.has_data()) {
    if (!sampling.samplers.empty()) {
        assert(sampling.logits.size     > 0);
        assert(sampling.probs.size      > 0);
        assert(sampling.candidates.size > 0);
        assert(sampling.sampled.size    > 0);
        assert(sampling.logits_count.size()     > 0);
        assert(sampling.probs_count.size()      > 0);
        assert(sampling.candidates_count.size() > 0);

        for (uint64_t k = 0; k < n_vocab; ++k) {
            std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]);
        }
    }

    if (sampling.probs.has_data()) {
        for (uint64_t k = 0; k < n_vocab; ++k) {
            std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]);
        }
    }

    if (sampling.candidates.has_data()) {
        for (uint64_t k = 0; k < n_vocab; ++k) {
            std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]);
        }
    }

    if (sampling.sampled.has_data()) {
        std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
    }

    if (!sampling.logits_count.empty()) {
        std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
    }

    if (!sampling.probs_count.empty()) {
        std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
    }

    if (!sampling.candidates_count.empty()) {
        std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
        std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
        std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
        std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
    }
}

@@ -265,24 +265,26 @@ private:
    std::unique_ptr<llama_memory_i> memory;

    // decode output (2-dimensional array: [n_outputs][n_vocab])
    struct buffer_view<float> logits = {nullptr, 0};
    buffer_view<float> logits = {nullptr, 0};

    // embeddings output (2-dimensional array: [n_outputs][n_embd])
    // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
    struct buffer_view<float> embd = {nullptr, 0};
    buffer_view<float> embd = {nullptr, 0};

    struct sampling_info {
        // !samplers.empty() to check if any samplers are active
        std::map<llama_seq_id, llama_sampler *> samplers;

        struct buffer_view<float>       logits     = {nullptr, 0};
        struct buffer_view<llama_token> sampled    = {nullptr, 0};
        struct buffer_view<float>       probs      = {nullptr, 0};
        struct buffer_view<llama_token> candidates = {nullptr, 0};
        buffer_view<float>       logits     = {nullptr, 0};
        buffer_view<llama_token> sampled    = {nullptr, 0};
        buffer_view<float>       probs      = {nullptr, 0};
        buffer_view<llama_token> candidates = {nullptr, 0};

        std::vector<uint32_t> logits_count;
        std::vector<uint32_t> probs_count;
        std::vector<uint32_t> candidates_count;

        // optimization
        std::vector<llama_token> token_ids_full_vocab;
    };


@@ -489,9 +489,6 @@ private:
    ggml_tensor * build_layer_attn_linear(
        llm_graph_input_rs * inp,
        ggml_tensor * cur,
        ggml_tensor * causal_mask,
        ggml_tensor * identity,
        ggml_tensor * diag_mask,
        int il);

    ggml_tensor * build_layer_ffn(
@@ -506,9 +503,6 @@ private:
        ggml_tensor * g,
        ggml_tensor * beta,
        ggml_tensor * state,
        ggml_tensor * causal_mask,
        ggml_tensor * identity,
        ggml_tensor * diag_mask,
        int il);

    // returns pair of output and new state

@@ -16,17 +16,6 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
    ggml_tensor * inp_pos = build_inp_pos();
    ggml_tensor * inp_out_ids = build_inp_out_ids();

    ggml_tensor * causal_mask =
        ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
                 GGML_TRI_TYPE_LOWER);

    ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
    ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);

    ggml_build_forward_expand(gf, causal_mask);
    ggml_build_forward_expand(gf, identity);
    ggml_build_forward_expand(gf, diag_mask);

    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * inpSA = inpL;

@@ -36,7 +25,7 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
        // Determine layer type and build appropriate attention mechanism
        if (hparams.is_recurrent(il)) {
            // Linear attention layer (gated delta net)
            cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il);
            cur = build_layer_attn_linear(inp->get_recr(), cur, il);
        } else {
            // Full attention layer
            cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il);
@@ -99,11 +88,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
    ggml_tensor * k,
    ggml_tensor * v,
    ggml_tensor * g,
    ggml_tensor * beta,
    ggml_tensor * state,
    ggml_tensor * causal_mask,
    ggml_tensor * identity,
    ggml_tensor * diag_mask,
    ggml_tensor * b,
    ggml_tensor * s,
    int il) {
    const int64_t S_k = q->ne[0];
    const int64_t H_k = q->ne[1];
@@ -113,134 +99,123 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
    const int64_t S_v = v->ne[0];
    const int64_t H_v = v->ne[1];

    GGML_ASSERT(v->ne[2] == n_tokens);
    GGML_ASSERT(k->ne[2] == n_tokens);
    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
    GGML_ASSERT(S_k == S_v);
    GGML_ASSERT(H_v % H_k == 0);

    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);

    GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
    GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);

    const float eps_norm = hparams.f_norm_rms_eps;

    q = ggml_l2_norm(ctx0, q, eps_norm);
    k = ggml_l2_norm(ctx0, k, eps_norm);

    const float scale = 1.0f / sqrtf(S_v);
    const float scale = 1.0f / sqrtf(S_k);

    q = ggml_scale(ctx0, q, scale);

    beta = ggml_sigmoid(ctx0, beta);

    cb(q, "q_in", il);
    cb(k, "k_in", il);
    cb(v, "v_in", il);
    cb(beta, "beta_in", il);
    cb(b, "b_in", il);
    cb(g, "g_in", il);

    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
    g = ggml_permute(ctx0, g, 2, 1, 3, 0); // [ 1, n_tokens, H_v, n_seqs]
    b = ggml_permute(ctx0, b, 2, 0, 1, 3); // [ 1, n_tokens, H_v, n_seqs]

    beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
    const int CS = CHUNK_SIZE;

    cb(q, "q_perm", il);
    cb(k, "k_perm", il);
    cb(v, "v_perm", il);
    cb(beta, "beta_perm", il);
    cb(g, "g_perm", il);
    cb(state, "state_in", il);

    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);

    // Do padding
    const int64_t chunk_size = CHUNK_SIZE;

    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
    const int pad = (CS - n_tokens % CS) % CS;
    const int n_chunks = (n_tokens + pad) / CS;

    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
    g = ggml_pad(ctx0, g, pad, 0, 0, 0);
    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
    g = ggml_pad(ctx0, g, 0, pad, 0, 0);
    b = ggml_pad(ctx0, b, 0, pad, 0, 0);

    cb(q, "q_pad", il);
    cb(k, "k_pad", il);
    cb(v, "v_pad", il);
    cb(beta, "beta_pad", il);
    cb(g, "g_pad", il);
    ggml_tensor * v_b = ggml_mul(ctx0, v, b);
    ggml_tensor * k_b = ggml_mul(ctx0, k, b);

    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
    cb(v_b, "v_b", il);
    cb(k_b, "k_b", il);

    cb(v_beta, "v_beta", il);
    cb(k_beta, "k_beta", il);
    q   = ggml_reshape_4d(ctx0, q,   S_k, CS, n_chunks, H_k * n_seqs);
    k   = ggml_reshape_4d(ctx0, k,   S_k, CS, n_chunks, H_k * n_seqs);
    k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs);
    v   = ggml_reshape_4d(ctx0, v,   S_v, CS, n_chunks, H_v * n_seqs);
    v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs);

    q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs);
    k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs);
    k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
    v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs);
    v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
    g = ggml_reshape_4d(ctx0, g, CS, 1, n_chunks, H_v * n_seqs);
    b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs);

    g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
    beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
    // [CS, 1, n_chunks, H_v * n_seqs]
    ggml_tensor * g_cs = ggml_cumsum(ctx0, g);
    cb(g_cs, "g_cs", il);

    ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
    cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
    ggml_tensor * g_cs_i = g_cs;
    ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs);

    ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
    ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
    g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs);

    ggml_tensor * gcs_j_broadcast =
        ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);

    ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
    cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)

    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
    // [CS, CS, n_chunks, H_v * n_seqs]
    ggml_tensor * decay_mask;
    decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
    decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
    decay_mask = ggml_exp(ctx0, decay_mask);
    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
    cb(decay_mask, "decay_mask", il);

    ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
    // [CS, CS, n_chunks, H_k * n_seqs]
    ggml_tensor * kb;
    kb = ggml_mul_mat(ctx0, k, k_b);
    kb = ggml_mul    (ctx0, kb, decay_mask);

    ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
    ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
    cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
    // [CS, CS, n_chunks, H_k * n_seqs]
    ggml_tensor * attn;
    attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER);

    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
    ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
    ggml_tensor * identity;
    identity = ggml_view_1d(ctx0, attn, CS, 0);
    identity = ggml_fill   (ctx0, identity, 1.0f);
    identity = ggml_diag   (ctx0, identity);

    ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
    attn = ggml_mul(ctx0, lin_solve, causal_mask);
    attn = ggml_add(ctx0, attn, identity);
    cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
    ggml_tensor * lhs = ggml_add(ctx0, attn, identity);
    cb(lhs, "dnet_add_ch_lhs", il);

    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
    attn = ggml_neg(ctx0, attn);

    ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
    ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t);
    ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
    attn = ggml_add(ctx0, lin_solve, identity);
    cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs]

    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
    cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
    // [S_v, CS, n_chunks, H_v * n_seqs]
    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn);

    ggml_tensor * k_cumdecay =
        ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
    cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
    // [CS, 1, n_chunks, H_v * n_seqs]
    ggml_tensor * g_exp = ggml_exp(ctx0, g_cs);

    ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
    attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
    attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
    cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
    k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b));

    // [CS, S_k, n_chunks, H_k * n_seqs]
    ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp);
    cb(kbg, "k_beta_g_exp", il);

    // [S_k, CS, n_chunks, H_k * n_seqs]
    ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn);
    cb(k_cd, "k_cumdecay", il);

    // [S_k, CS, n_chunks, H_k * n_seqs]
    ggml_tensor * g_exp_t = ggml_transpose(ctx0, g_exp);
    ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t);

    // [CS, CS, n_chunks, H_k * n_seqs]
    ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
    kq = ggml_mul(ctx0, kq, decay_mask);
    kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG);
    cb(kq, "kq", il);

    // vectorized calculation of key_gdiff
    // improved from the chunked version:
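For orientation between these two hunks, the math the ops implement (my notation, read off the code; index orientation follows ggml_tri's lower-triangular convention): with c_t the cumulative gate sum within a chunk, the decay mask and per-chunk state update are

$$D_{ij} \;=\; \exp(c_i - c_j)\,\mathbf{1}[j \le i], \qquad S \;\leftarrow\; e^{\,c_{\mathrm{last}}}\, S \;+\; \big(K \odot e^{\,c_{\mathrm{last}} - c}\big)^{\!\top} V_{\mathrm{new}}$$

matching the key_gdiff and last_recurrent_state comments below.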
@@ -250,109 +225,98 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
    //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
    //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew

    // get last element in g_cumsum along chunk_size dimension (ne0)
    // get last element in g_cumsum along CS dimension (ne0)
    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
                                        g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
                                        (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
    // [1, 1, n_chunks, H_v * n_seqs]
    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, 1, g_cs->ne[2], g_cs->ne[3],
                                        g_cs->nb[1],
                                        g_cs->nb[2],
                                        g_cs->nb[3],
                                        ggml_row_size(g_cs->type, g_cs->ne[0] - 1));
    cb(g_last, "g_last", il);

    // TODO: remove this cont when CUDA supports non-cont unary ops
    g_last = ggml_cont(ctx0, g_last);
    cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs)

    // [1, 1, n_chunks, H_v * n_seqs]
    ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
    cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
    cb(g_last_exp, "g_last_exp", il);

    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
    cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
    // [CS, 1, n_chunks, H_v * n_seqs]
    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last));
    cb(g_diff, "g_diff", il);

    ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
    ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp,
                                                 1, chunk_size, n_chunks, g_diff_exp->ne[3]);
    ggml_tensor * g_diff_exp   = ggml_exp(ctx0, g_diff);
    ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp);

    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
    cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
    // [S_k, CS, n_chunks, H_v * n_seqs]
    ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t);
    cb(kg, "key_gdiff", il);

    ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
    cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs)
    // [CS, S_k, n_chunks, H_v * n_seqs]
    ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
    cb(kg_t, "key_gdiff_t", il);

    ggml_tensor * s_t = ggml_transpose(ctx0, s);
    s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs);
    cb(s_t, "dnet_add_ch_state", il);

    // state to be updated per chunk
    ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
    cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs)

    // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs)
    ggml_tensor * core_attn_out = nullptr;
    // [CS, S_v, n_chunks, H_v * n_seqs]
    ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));

    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
        // shape: (S_k, chunk_size, 1, H_k * n_seqs)
        ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul
        ggml_tensor * ch_k_cd    = get_slice_2d(ctx0, k_cd,    chunk); // [S_k, CS, 1, H_k * n_seqs]
        ggml_tensor * ch_v_t     = get_slice_2d(ctx0, v_t,     chunk); // [ CS, S_v, 1, H_v * n_seqs]
        ggml_tensor * ch_kq      = get_slice_2d(ctx0, kq,      chunk); // [ CS, CS, 1, H_k * n_seqs]
        ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k, CS, 1, H_k * n_seqs]
        ggml_tensor * ch_kg_t    = get_slice_2d(ctx0, kg_t,    chunk); // [ CS, S_k, 1, H_v * n_seqs]

        // shape: (S_v, chunk_size, 1, H_v * n_seqs)
        ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat
        // [CS, S_v, 1, H_v * n_seqs]
        ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t);
        cb(v_t_p, "v_prime", il);

        // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
        ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul
        // [CS, S_v, 1, H_v * n_seqs]
        ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p);
        cb(v_t_new, "v_t_new", il);

        // shape: (chunk_size, 1, H_v * n_seqs)
        ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat
        // [S_v, CS, 1, H_v * n_seqs]
        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq);
        cb(v_attn, "v_attn", il);

        // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        // replaced by precomputed attn_kq
        ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
        cb(attn_chunk, "attn_chunk", il);
        // [S_v, CS, 1, H_v * n_seqs]
        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp);
        cb(attn_inter, "attn_inter", il);

        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
        // [S_v, CS, 1, H_v * n_seqs]
        ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn);
        cb(o_ch, "dnet_add_ch_attn_out", il);

        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
        cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs)

        // v_new = v_i - v_prime
        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
        cb(v_new, "v_new_chunk", il);

        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk);
        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
        cb(attn_inter, "attn_inter_chunk", il);

        // core_attn_out[:, :, i] = attn_inter + attn @ v_new
        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
        cb(v_attn, "v_attn_chunk", il);

        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
        cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs)

        core_attn_out = core_attn_out == nullptr
            ? core_attn_out_chunk
            : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
        v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]);

        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
        ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk);
        //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);
        // TODO: head broadcast might not work here - probably will need a transpose
        ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs]

        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
        ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
        new_state = ggml_add(ctx0,
            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
        ggml_tensor * ch_g_last_exp = get_slice_2d(ctx0, g_last_exp, chunk)
|
||||
s_t = ggml_mul(ctx0, s_t, ch_g_last_exp);
|
||||
s_t = ggml_add(ctx0, s_t, kgv);
|
||||
cb(s_t, "dnet_add_ch_state", il);
|
||||
}
|
||||
|
||||
s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs);
|
||||
|
||||
// truncate padded tokens
|
||||
ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
|
||||
ggml_tensor * o = ggml_view_4d(ctx0, v,
|
||||
S_v, n_tokens, H_v, n_seqs,
|
||||
ggml_row_size(core_attn_out->type, S_v),
|
||||
ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
|
||||
ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
|
||||
output_tokens = ggml_cont(ctx0, output_tokens);
|
||||
cb(output_tokens, "output_tokens", il);
|
||||
ggml_row_size(v->type, S_v),
|
||||
ggml_row_size(v->type, S_v * CS * n_chunks),
|
||||
ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
|
||||
|
||||
// permute back to (S_v, H_v, n_tokens, n_seqs)
|
||||
output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
|
||||
output_tokens = ggml_cont(ctx0, output_tokens);
|
||||
o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
|
||||
s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs]
|
||||
|
||||
return {output_tokens, new_state};
|
||||
return {o, s};
|
||||
}
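
At scalar level, the per-chunk recurrence encoded above reduces to the following single-head sketch (plain C++ with illustrative names; k_cumdecay, attn and key_gdiff correspond to the precomputed tensors of the same names and are assumed given — this is an orientation aid, not the ggml implementation):

#include <vector>

using Vec = std::vector<float>;
using Mat = std::vector<Vec>; // row-major; Mat[t] is one token/row

// r = a @ S  ((1 x d) times (d x d))
static Vec vec_mat(const Vec & a, const Mat & S) {
    Vec r(S[0].size(), 0.0f);
    for (size_t i = 0; i < a.size(); i++)
        for (size_t j = 0; j < r.size(); j++)
            r[j] += a[i] * S[i][j];
    return r;
}

// One chunk of C tokens; S is the d x d recurrent state, g_last_exp the
// exponentiated total log-decay of the chunk. attn is the precomputed,
// causally masked, decay-weighted q.k^T of the chunk ("attn_kq" above).
static Mat delta_net_chunk(Mat & S, const Mat & q_gexp, const Mat & k_cumdecay,
                           const Mat & key_gdiff, const Mat & attn, const Mat & v,
                           float g_last_exp) {
    const size_t C = v.size(), d = S.size();
    Mat v_new(C), out(C);
    for (size_t t = 0; t < C; t++) {
        // v_prime = k_cumdecay[t] @ S;  v_new = v - v_prime
        Vec v_prime = vec_mat(k_cumdecay[t], S);
        v_new[t] = v[t];
        for (size_t j = 0; j < d; j++) v_new[t][j] -= v_prime[j];
        // out = (q * exp(g)) @ S + attn @ v_new  (attn is lower-triangular)
        out[t] = vec_mat(q_gexp[t], S);
        for (size_t u = 0; u <= t; u++)
            for (size_t j = 0; j < d; j++) out[t][j] += attn[t][u] * v_new[u][j];
    }
    // S <- S * exp(g_last) + key_gdiff^T @ v_new
    for (size_t i = 0; i < d; i++)
        for (size_t j = 0; j < d; j++) {
            S[i][j] *= g_last_exp;
            for (size_t t = 0; t < C; t++) S[i][j] += key_gdiff[t][i] * v_new[t][j];
        }
    return out;
}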

std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_autoregressive(
@@ -360,8 +324,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_aut
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * beta,
ggml_tensor * state,
ggml_tensor * b, // beta
ggml_tensor * s, // state
int il) {
const int64_t S_k = q->ne[0];
const int64_t H_k = q->ne[1];
@@ -371,75 +335,72 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_aut
const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[1];

GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing
GGML_ASSERT(v->ne[2] == n_tokens);
GGML_ASSERT(k->ne[2] == n_tokens);
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
GGML_ASSERT(n_tokens == 1);

GGML_ASSERT(S_k == S_v);
GGML_ASSERT(H_v % H_k == 0);

GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);

GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);

const float eps_norm = hparams.f_norm_rms_eps;
const float scale = 1.0f / sqrtf(S_k);

q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
q = ggml_scale(ctx0, q, scale);

const float scale = 1.0f / sqrtf(S_v);

q = ggml_scale(ctx0, q, scale);
beta = ggml_sigmoid(ctx0, beta);
q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]

cb(q, "q_in", il);
cb(k, "k_in", il);
cb(v, "v_in", il);
cb(beta, "beta_in", il);
cb(b, "b_in", il);
cb(g, "g_in", il);

state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
g = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs);
b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs);

ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
// [S_v, S_v, H_v, n_seqs]
g = ggml_exp(ctx0, g);
s = ggml_mul(ctx0, s, g);

// Apply exponential to g_t
g_t = ggml_exp(ctx0, g_t);
ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s));

// Apply the gated delta rule for the single timestep
// last_recurrent_state = last_recurrent_state * g_t
state = ggml_mul(ctx0, state, g_t);
// [1, S_v, H_v, n_seqs]
ggml_tensor * sk;
sk = ggml_mul (ctx0, s_t, k);
sk = ggml_sum_rows(ctx0, sk);

// kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed);
// we need to sum over dim=-2, so we transpose, sum, then transpose again
kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
// [S_v, 1, H_v, n_seqs]
ggml_tensor * d;
d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk));
d = ggml_mul(ctx0, d, b);

// v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
// delta = (v_t - kv_mem) * beta_t
ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs]
ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t);
// [1, S_v, H_v, n_seqs]
ggml_tensor * d_t;
d_t = ggml_transpose(ctx0, d);

// last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
state = ggml_add(ctx0, state, k_t_delta);
// [S_v, S_v, H_v, n_seqs]
ggml_tensor * kd;
k = ggml_repeat(ctx0, k, s);
kd = ggml_mul (ctx0, k, d_t);

// Compute the attention output
// core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t
ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed);
// again, since it's over dim = -2, transpose, sum, transpose back
ggml_tensor * core_attn_out =
ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));
s_t = ggml_add(ctx0, s_t, kd);

// core_attn_out should be [S_v, 1, H_v, n_seqs] after this
cb(core_attn_out, "output_tokens", il);
cb(state, "new_state", il);
cb(s_t, "dnet_add_ar_state", il);

return {core_attn_out, state};
ggml_tensor * s_q = ggml_mul (ctx0, s_t, q);
ggml_tensor * o = ggml_sum_rows(ctx0, s_q);

o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs]

return {o, s};
}
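
For a single token the gated delta rule above is, per head: decay the d x d state by exp(g), write a beta-scaled correction along k, then read the output out along q. A minimal scalar sketch (plain C++, illustrative names; it mirrors the pytorch-style comments, not the exact ggml graph):

#include <cmath>
#include <vector>

using Vec = std::vector<float>;

// S is row-major d x d; k, v, q are length-d; g and beta are per-head scalars.
static Vec gated_delta_step(std::vector<Vec> & S, const Vec & q, const Vec & k,
                            const Vec & v, float g, float beta) {
    const size_t d = S.size();
    const float g_exp = std::exp(g);
    Vec kv(d, 0.0f), out(d, 0.0f);
    for (size_t i = 0; i < d; i++)
        for (size_t j = 0; j < d; j++) {
            S[i][j] *= g_exp;            // last_recurrent_state *= g.exp()
            kv[j]   += S[i][j] * k[i];   // kv_mem = (S * k.unsqueeze(-1)).sum(-2)
        }
    for (size_t i = 0; i < d; i++)
        for (size_t j = 0; j < d; j++) {
            const float delta_j = (v[j] - kv[j]) * beta; // delta = (v - kv_mem) * beta
            S[i][j] += k[i] * delta_j;   // S += k.unsqueeze(-1) * delta
        }
    for (size_t i = 0; i < d; i++)
        for (size_t j = 0; j < d; j++)
            out[j] += S[i][j] * q[i];    // out = (S * q.unsqueeze(-1)).sum(-2)
    return out;
}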

ggml_tensor * llm_build_qwen3next::build_norm_gated(
@@ -472,39 +433,29 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
// Split Q projection into query and gate
// The split should be along dimension 0 (the feature dimension)
ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
cb(Qcur, "Qcur_view", il);

ggml_tensor * gate =
ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full));
cb(Qcur, "Qcur", il);
cb(gate, "gate", il);

// Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention
Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cb(Qcur, "Qcur_reshaped", il);

// Apply Q normalization
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);

ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);

ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);

// Apply K normalization
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);

Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
cb(Kcur, "Kcur_normed", il);

// Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
cb(gate, "gate_reshaped", il);

Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

// Apply RoPE
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -519,7 +470,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);

// Attention computation
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

cur = build_attn(inp,
@@ -527,10 +477,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
cb(cur, "attn_pregate", il);

ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
cb(gate_sigmoid, "gate_sigmoid", il);
// TODO: CUDA is missing non-contiguous unary ops. when implemented: remove this cont
gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);

cur = ggml_mul(ctx0, cur, gate_sigmoid);
gate = ggml_sigmoid(ctx0, gate);
cb(gate, "gate_sigmoid", il);

gate = ggml_reshape_2d(ctx0, gate, n_embd_head * n_head, n_tokens);

cur = ggml_mul(ctx0, cur, gate);
cb(cur, "attn_gated", il);

cur = build_lora_mm(model.layers[il].wo, cur);
@@ -560,7 +515,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_qkvz(
cb(z, "z", il);

return { qkv_mixed, z };

} else {
// legacy (slower) path
ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, input);
@@ -624,9 +578,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_qkvz(
ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
llm_graph_input_rs * inp,
ggml_tensor * cur,
ggml_tensor * causal_mask,
ggml_tensor * identity,
ggml_tensor * diag_mask,
int il) {
const auto * mctx_cur = inp->mctx;

@@ -671,7 +622,12 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
cb(a, "a", il);

ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
// TODO: CUDA is missing non-contiguous unary ops. when implemented: remove this cont
b = ggml_cont(ctx0, b);

ggml_tensor * beta = ggml_sigmoid(ctx0, b);

beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs);

// Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
@@ -679,6 +635,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
cb(alpha_softplus, "a_softplus", il);

ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus
cb(gate, "gate", il);
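
The gate computed here is the per-head log-decay g = -exp(A_log) * softplus(a + dt_bias); the sketch below assumes, as the comment suggests, that ssm_dt holds the dt bias and ssm_a holds -exp(A_log) folded in at conversion time (illustrative names):

#include <cmath>

// numerically stable softplus(x) = log(1 + exp(x))
static inline float softplus(float x) {
    return x > 20.0f ? x : std::log1p(std::exp(x));
}

// per-head, per-token log-decay of the delta-net state
static inline float delta_net_log_decay(float a, float dt_bias, float A_log) {
    return -std::exp(A_log) * softplus(a + dt_bias);
}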

@@ -686,8 +643,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);

// bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state();

// Build the convolution states tensor
ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
cb(conv_states, "conv_states", il);
@@ -696,11 +651,12 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
const int64_t conv_kernel_size = conv_kernel->ne[0];
const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);

conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
cb(conv_states, "conv_states_reshaped", il);

qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
cb(qkv_mixed, "qkv_mixed_permuted", il);
qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
cb(qkv_mixed, "qkv_mixed_transposed", il);

ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
cb(conv_input, "conv_input", il);
@@ -720,7 +676,10 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
cb(conv_states_all, "conv_states_updated", il);

// Apply SSM convolution
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
cb(state, "state_predelta", il);

ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
cb(conv_output_proper, "conv_output_raw", il);
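
The convolution here is causal and depthwise, with carried state: the last (conv_kernel_size - 1) inputs of the previous call are prepended so every output token sees a full kernel window. A per-channel sketch (plain C++, illustrative names — not the ggml_ssm_conv kernel itself):

#include <vector>

// conv_state holds the trailing K-1 inputs of the previous call and is
// updated in place for the next call; x are the new tokens, w the kernel.
static std::vector<float> causal_conv1d_channel(std::vector<float> & conv_state,
                                                const std::vector<float> & x,
                                                const std::vector<float> & w) {
    std::vector<float> in = conv_state;            // prepend carried state
    in.insert(in.end(), x.begin(), x.end());
    const size_t K = w.size();
    std::vector<float> y(x.size(), 0.0f);
    for (size_t t = 0; t < x.size(); t++)
        for (size_t j = 0; j < K; j++)
            y[t] += in[t + j] * w[j];              // window ends at token t
    conv_state.assign(in.end() - (K - 1), in.end()); // carry last K-1 inputs
    return y;
}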

@@ -734,26 +693,36 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);

// Extract the convolved Q, K, V from conv_output
ggml_tensor * q_conv =
ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
ggml_row_size(conv_qkv_mix->type, head_k_dim),
nb1_qkv,
nb1_qkv * n_seq_tokens,
0);

ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
ggml_row_size(conv_qkv_mix->type, head_k_dim),
nb1_qkv,
nb1_qkv * n_seq_tokens,
head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));

ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
ggml_row_size(conv_qkv_mix->type, head_v_dim),
nb1_qkv,
nb1_qkv * n_seq_tokens,
ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads));

cb(q_conv, "q_conv", il);
ggml_tensor * k_conv =
ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
cb(k_conv, "k_conv", il);
ggml_tensor * v_conv =
ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv,
2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
cb(v_conv, "v_conv", il);

// Unsqueeze them
q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
const float eps_norm = hparams.f_norm_rms_eps;

ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
cb(state, "state_predelta", il);
q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm);
k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm);

//q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
//k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);

// if head keys and value keys are different, repeat to force tensors into matching shapes
if (num_k_heads != num_v_heads) {
@@ -786,7 +755,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
if (n_seq_tokens == 1) {
attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
} else {
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il);
}
ggml_tensor * output = attn_out.first;
ggml_tensor * new_state = attn_out.second;
@@ -795,19 +764,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(

// Update the recurrent states
ggml_build_forward_expand(gf,
ggml_cpy(ctx0, new_state,
ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));

// Reshape both attn_out_final and z to 2D tensors for normalization
// attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
ggml_cpy(ctx0, new_state,
ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));

// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);

// Apply gated normalization: self.norm(core_attn_out, z)
ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il);

// Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
@@ -818,7 +783,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
cb(cur, "linear_attn_out", il);

// Reshape back to original dimensions
cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);

return cur;
}

@@ -839,7 +805,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
if (model.layers[il].ffn_up_shexp != nullptr) {
ggml_tensor * ffn_shexp =
build_ffn(cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
@@ -852,11 +818,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
cb(shared_gate, "shared_expert_gate", il);

// Apply sigmoid to the gate
shared_gate = ggml_sigmoid(ctx0, shared_gate);
cb(shared_gate, "shared_expert_gate_sigmoid", il);

// Apply the gate to the shared expert output
ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
cb(ffn_shexp, "ffn_shexp_gated", il);

@@ -8301,7 +8301,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
//for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) {
for (int kv : { 113, 512, 1024, }) {
if (nr2 != 1 && kv != 512) continue;
for (int nb : { 1, 3, 32, 35, }) {
for (int nb : { 1, 3, 32, 75, }) {
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {

@@ -20,6 +20,7 @@ add_library(mtmd
models/internvl.cpp
models/kimivl.cpp
models/kimik25.cpp
models/nemotron-v2-vl.cpp
models/llama4.cpp
models/llava.cpp
models/minicpmv.cpp

@@ -236,6 +236,7 @@ enum projector_type {
PROJECTOR_TYPE_GLM4V,
PROJECTOR_TYPE_YOUTUVL,
PROJECTOR_TYPE_KIMIK25,
PROJECTOR_TYPE_NEMOTRON_V2_VL,
PROJECTOR_TYPE_UNKNOWN,
};

@@ -270,6 +271,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_GLM4V, "glm4v"},
{ PROJECTOR_TYPE_YOUTUVL, "youtuvl"},
{ PROJECTOR_TYPE_KIMIK25, "kimik25"},
{ PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"},
};

static projector_type clip_projector_type_from_string(const std::string & str) {

@@ -15,6 +15,7 @@ enum ffn_op_type {
FFN_GELU_ERF,
FFN_SILU,
FFN_GELU_QUICK,
FFN_RELU_SQR,
};

enum norm_type {

@@ -559,6 +559,12 @@ ggml_tensor * clip_graph::build_ffn(
cur = ggml_gelu_quick(ctx0, cur);
cb(cur, "ffn_gelu_quick", il);
} break;
case FFN_RELU_SQR:
{
cur = ggml_relu(ctx0, cur);
cur = ggml_sqr(ctx0, cur);
cb(cur, "ffn_relu_sqr", il);
} break;
}
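
The new FFN_RELU_SQR case is squared ReLU, relu(x)^2 applied elementwise, expressed above as ggml_relu followed by ggml_sqr. Scalar equivalent for reference:

#include <algorithm>

// squared ReLU: max(x, 0)^2
static inline float relu_sqr(float x) {
    const float r = std::max(x, 0.0f);
    return r * r;
}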

if (down) {

@@ -810,6 +816,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
builder = std::make_unique<clip_graph_internvl>(ctx, img);
} break;
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
{
builder = std::make_unique<clip_graph_nemotron_v2_vl>(ctx, img);
} break;
case PROJECTOR_TYPE_LLAMA4:
{
builder = std::make_unique<clip_graph_llama4>(ctx, img);
@@ -1110,6 +1120,7 @@ struct clip_model_loader {
}
} break;
case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
{
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
} break;
@@ -1767,6 +1778,12 @@ struct clip_model_loader {
model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
} break;
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
{
model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
} break;
case PROJECTOR_TYPE_GLMA:
{
model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
@@ -3088,6 +3105,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
case PROJECTOR_TYPE_GLM_EDGE:
case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_INTERNVL: // TODO @ngxson : support dynamic resolution
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
{
clip_image_u8 resized_image;
int sz = params.image_size;
@@ -3397,6 +3415,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_IDEFICS3:
case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
case PROJECTOR_TYPE_LLAMA4:
{
// both X and Y are downscaled by the scale factor
@@ -3805,6 +3824,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
case PROJECTOR_TYPE_GEMMA3NV:
case PROJECTOR_TYPE_IDEFICS3:
case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_GLMA:
case PROJECTOR_TYPE_ULTRAVOX:
@@ -3968,6 +3988,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
return ctx->model.mm_3_w->ne[1];
case PROJECTOR_TYPE_LLAMA4:
return ctx->model.mm_model_proj->ne[1];

@@ -42,6 +42,11 @@ struct clip_graph_internvl : clip_graph {
ggml_cgraph * build() override;
};

struct clip_graph_nemotron_v2_vl : clip_graph {
clip_graph_nemotron_v2_vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;
};

struct clip_graph_llama4 : clip_graph {
clip_graph_llama4(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;

tools/mtmd/models/nemotron-v2-vl.cpp (new file, 35 lines)
@@ -0,0 +1,35 @@
#include "models.h"

ggml_cgraph * clip_graph_nemotron_v2_vl::build() {
GGML_ASSERT(model.class_embedding != nullptr);
GGML_ASSERT(model.position_embeddings != nullptr);

const int n_registers = model.class_embedding->ne[1];
const int n_pos = n_patches + n_registers;

ggml_tensor * inp = build_inp();

// add position embeddings (pre-downsampled during GGUF conversion for fixed 512x512 input)
inp = ggml_add(ctx0, inp, model.position_embeddings);
cb(inp, "inp_pos", -1);

inp = ggml_concat(ctx0, model.class_embedding, inp, 1);

ggml_tensor * cur = build_vit(inp, n_pos, NORM_TYPE_NORMAL, hparams.ffn_op, nullptr, nullptr);

cur = ggml_view_2d(ctx0, cur,
n_embd, n_patches,
ggml_row_size(cur->type, n_embd),
n_registers * ggml_row_size(cur->type, n_embd));

cur = build_patch_merge_permute(cur, model.hparams.n_merge);

{
cur = build_norm(cur, model.mm_0_w, nullptr, NORM_TYPE_RMS, 1e-6, -1);
cur = build_ffn(cur, model.mm_1_w, nullptr, nullptr, nullptr, model.mm_3_w, nullptr, FFN_RELU_SQR, -1);
}

ggml_build_forward_expand(gf, cur);

return gf;
}
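
The ggml_view_2d above drops the prepended register tokens by starting the view n_registers rows into the ViT output, so only the n_patches patch embeddings reach the projector. A pointer-level equivalent (illustrative, assuming a contiguous row-major [n_pos, n_embd] buffer):

#include <cstddef>

// Skip the first n_registers rows; the remaining rows are the patch embeddings.
static const float * drop_register_tokens(const float * embd, // [n_pos][n_embd]
                                          size_t n_embd, size_t n_registers) {
    return embd + n_registers * n_embd; // same row stride, offset by n_registers rows
}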

@@ -132,7 +132,8 @@ static std::string fs_get_cache_directory() {
if (getenv("LLAMA_CACHE")) {
cache_directory = std::getenv("LLAMA_CACHE");
} else {
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || \
defined(__OpenBSD__) || defined(__NetBSD__)
if (std::getenv("XDG_CACHE_HOME")) {
cache_directory = std::getenv("XDG_CACHE_HOME");
} else if (std::getenv("HOME")) {

@@ -28,10 +28,6 @@ target_link_libraries(${TARGET} PUBLIC common mtmd ${CMAKE_THREAD_LIBS_INIT})

set(TARGET llama-server)

if (NOT LLAMA_HTTPLIB)
message(FATAL_ERROR "LLAMA_HTTPLIB is OFF, cannot build llama-server. Hint: to skip building server, set -DLLAMA_BUILD_SERVER=OFF")
endif()

set(TARGET_SRCS
server.cpp
server-http.cpp

Binary file not shown.

@@ -1,17 +1,24 @@
import type { StorybookConfig } from '@storybook/sveltekit';
import { dirname, resolve } from 'path';
import { fileURLToPath } from 'url';

const __dirname = dirname(fileURLToPath(import.meta.url));

const config: StorybookConfig = {
stories: ['../tests/stories/**/*.mdx', '../tests/stories/**/*.stories.@(js|ts|svelte)'],
addons: [
'@storybook/addon-svelte-csf',
'@chromatic-com/storybook',
'@storybook/addon-docs',
'@storybook/addon-vitest',
'@storybook/addon-a11y',
'@storybook/addon-vitest'
'@storybook/addon-docs'
],
framework: {
name: '@storybook/sveltekit',
options: {}
framework: '@storybook/sveltekit',
viteFinal: async (config) => {
config.server = config.server || {};
config.server.fs = config.server.fs || {};
config.server.fs.allow = [...(config.server.fs.allow || []), resolve(__dirname, '../tests')];
return config;
}
};
export default config;

@@ -13,7 +13,7 @@ const preview: Preview = {
},

backgrounds: {
disable: true
disabled: true
},

a11y: {

@@ -49,14 +49,20 @@ sequenceDiagram
settingsStore->>serverStore: defaultParams
serverStore-->>settingsStore: {temperature, top_p, top_k, ...}

settingsStore->>ParamSvc: extractServerDefaults(defaultParams)
ParamSvc-->>settingsStore: Record<string, value>
loop each SYNCABLE_PARAMETER
alt key NOT in userOverrides
settingsStore->>settingsStore: config[key] = serverDefault[key]
Note right of settingsStore: Non-overridden params adopt server default
else key in userOverrides
Note right of settingsStore: Keep user value, skip server default
end
end

settingsStore->>ParamSvc: mergeWithServerDefaults(config, serverDefaults)
Note right of ParamSvc: For each syncable parameter:<br/>- If NOT in userOverrides → use server default<br/>- If in userOverrides → keep user value
ParamSvc-->>settingsStore: mergedConfig
alt serverStore.props has webuiSettings
settingsStore->>settingsStore: Apply webuiSettings from server
Note right of settingsStore: Server-provided UI settings<br/>(e.g. showRawOutputSwitch)
end

settingsStore->>settingsStore: config = mergedConfig
settingsStore->>settingsStore: saveConfig()
deactivate settingsStore

@@ -67,11 +73,18 @@ sequenceDiagram
UI->>settingsStore: updateConfig(key, value)
activate settingsStore
settingsStore->>settingsStore: config[key] = value
settingsStore->>settingsStore: userOverrides.add(key)
Note right of settingsStore: Mark as user-modified (won't be overwritten by server)

alt value matches server default for key
settingsStore->>settingsStore: userOverrides.delete(key)
Note right of settingsStore: Matches server default, remove override
else value differs from server default
settingsStore->>settingsStore: userOverrides.add(key)
Note right of settingsStore: Mark as user-modified (won't be overwritten)
end

settingsStore->>settingsStore: saveConfig()
settingsStore->>LS: set("llama-config", config)
settingsStore->>LS: set("llama-userOverrides", [...userOverrides])
settingsStore->>LS: set(CONFIG_LOCALSTORAGE_KEY, config)
settingsStore->>LS: set(USER_OVERRIDES_LOCALSTORAGE_KEY, [...userOverrides])
deactivate settingsStore

UI->>settingsStore: updateMultipleConfig({key1: val1, key2: val2})

@@ -88,10 +101,9 @@ sequenceDiagram

UI->>settingsStore: resetConfig()
activate settingsStore
settingsStore->>settingsStore: config = SETTING_CONFIG_DEFAULT
settingsStore->>settingsStore: config = {...SETTING_CONFIG_DEFAULT}
settingsStore->>settingsStore: userOverrides.clear()
settingsStore->>settingsStore: syncWithServerDefaults()
Note right of settingsStore: Apply server defaults for syncable params
Note right of settingsStore: All params reset to defaults<br/>Next syncWithServerDefaults will adopt server values
settingsStore->>settingsStore: saveConfig()
deactivate settingsStore
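
The override rule the diagram describes is compact; the following is a schematic rendering (in C++ for consistency with the rest of this page — the actual store is TypeScript, and all names below are illustrative):

#include <map>
#include <set>
#include <string>

using Config = std::map<std::string, double>;

// Parameters not marked as user overrides adopt the server default.
static Config merge_with_server_defaults(const Config & config, const Config & server_defaults,
                                         const std::set<std::string> & user_overrides) {
    Config merged = config;
    for (const auto & [key, def] : server_defaults) {
        if (user_overrides.count(key) == 0) merged[key] = def;
    }
    return merged;
}

// Setting a value back to the server default drops the override again.
static void update_config(Config & config, std::set<std::string> & user_overrides,
                          const Config & server_defaults, const std::string & key, double value) {
    config[key] = value;
    auto it = server_defaults.find(key);
    if (it != server_defaults.end() && it->second == value) {
        user_overrides.erase(key);  // matches server default: remove override
    } else {
        user_overrides.insert(key); // user-modified: won't be overwritten
    }
}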
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
<script lang="ts">
|
||||
import { Eye } from '@lucide/svelte';
|
||||
import ActionIconCopyToClipboard from '$lib/components/app/actions/ActionIconCopyToClipboard.svelte';
|
||||
import { ActionIconCopyToClipboard } from '$lib/components/app';
|
||||
import { FileTypeText } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
|
||||
@@ -57,13 +57,13 @@
|
||||
let currentConfig = $derived(config());
|
||||
let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined);
|
||||
let isRecording = $state(false);
|
||||
let message = $state(initialMessage);
|
||||
let message = $derived(initialMessage);
|
||||
let pasteLongTextToFileLength = $derived.by(() => {
|
||||
const n = Number(currentConfig.pasteLongTextToFileLen);
|
||||
return Number.isNaN(n) ? Number(SETTING_CONFIG_DEFAULT.pasteLongTextToFileLen) : n;
|
||||
});
|
||||
let previousIsLoading = $state(isLoading);
|
||||
let previousInitialMessage = $state(initialMessage);
|
||||
let previousIsLoading = $derived(isLoading);
|
||||
let previousInitialMessage = $derived(initialMessage);
|
||||
let recordingSupported = $state(false);
|
||||
let textareaRef: ChatFormTextarea | undefined = $state(undefined);
|
||||
|
||||
@@ -289,7 +289,7 @@
|
||||
|
||||
<form
|
||||
onsubmit={handleSubmit}
|
||||
class="{INPUT_CLASSES} border-radius-bottom-none mx-auto max-w-[48rem] overflow-hidden rounded-3xl backdrop-blur-md {disabled
|
||||
class="relative {INPUT_CLASSES} border-radius-bottom-none mx-auto max-w-[48rem] overflow-hidden rounded-3xl backdrop-blur-md {disabled
|
||||
? 'cursor-not-allowed opacity-60'
|
||||
: ''} {className}"
|
||||
data-slot="chat-form"
|
||||
@@ -304,10 +304,11 @@
|
||||
/>
|
||||
|
||||
<div
|
||||
class="flex-column relative min-h-[48px] items-center rounded-3xl px-5 py-3 shadow-sm transition-all focus-within:shadow-md"
|
||||
class="flex-column relative min-h-[48px] items-center rounded-3xl py-2 pb-2.25 shadow-sm transition-all focus-within:shadow-md md:!py-3"
|
||||
onpaste={handlePaste}
|
||||
>
|
||||
<ChatFormTextarea
|
||||
class="px-5 py-1.5 md:pt-0"
|
||||
bind:this={textareaRef}
|
||||
bind:value={message}
|
||||
onKeydown={handleKeydown}
|
||||
@@ -315,6 +316,7 @@
|
||||
/>
|
||||
|
||||
<ChatFormActions
|
||||
class="px-3"
|
||||
bind:this={chatFormActionsRef}
|
||||
canSend={message.trim().length > 0 || uploadedFiles.length > 0}
|
||||
hasText={message.trim().length > 0}
|
||||
|
||||
@@ -0,0 +1,189 @@
|
||||
<script lang="ts">
|
||||
import { page } from '$app/state';
|
||||
import { MessageSquare, Plus } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { FILE_TYPE_ICONS } from '$lib/constants/icons';
|
||||
import { TOOLTIP_DELAY_DURATION } from '$lib/constants/tooltip-config';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
hasAudioModality?: boolean;
|
||||
hasVisionModality?: boolean;
|
||||
onFileUpload?: () => void;
|
||||
onSystemPromptClick?: () => void;
|
||||
}
|
||||
|
||||
type AttachmentActionId = 'images' | 'audio' | 'text' | 'pdf' | 'system';
|
||||
|
||||
interface AttachmentAction {
|
||||
id: AttachmentActionId;
|
||||
label: string;
|
||||
disabled?: boolean;
|
||||
disabledReason?: string;
|
||||
tooltip?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
disabled = false,
|
||||
hasAudioModality = false,
|
||||
hasVisionModality = false,
|
||||
onFileUpload,
|
||||
onSystemPromptClick
|
||||
}: Props = $props();
|
||||
|
||||
let isNewChat = $derived(!page.params.id);
|
||||
let systemMessageTooltip = $derived(
|
||||
isNewChat
|
||||
? 'Add custom system message for a new conversation'
|
||||
: 'Inject custom system message at the beginning of the conversation'
|
||||
);
|
||||
|
||||
let actions = $derived.by<AttachmentAction[]>(() => [
|
||||
{
|
||||
id: 'images',
|
||||
label: 'Images',
|
||||
disabled: !hasVisionModality,
|
||||
disabledReason: !hasVisionModality
|
||||
? 'Images require vision models to be processed'
|
||||
: undefined
|
||||
},
|
||||
{
|
||||
id: 'audio',
|
||||
label: 'Audio Files',
|
||||
disabled: !hasAudioModality,
|
||||
disabledReason: !hasAudioModality
|
||||
? 'Audio files require audio models to be processed'
|
||||
: undefined
|
||||
},
|
||||
{
|
||||
id: 'text',
|
||||
label: 'Text Files'
|
||||
},
|
||||
{
|
||||
id: 'pdf',
|
||||
label: 'PDF Files',
|
||||
tooltip: !hasVisionModality
|
||||
? 'PDFs will be converted to text. Image-based PDFs may not work properly.'
|
||||
: undefined
|
||||
},
|
||||
{
|
||||
id: 'system',
|
||||
label: 'System Message',
|
||||
tooltip: systemMessageTooltip
|
||||
}
|
||||
]);
|
||||
|
||||
function handleActionClick(id: AttachmentActionId) {
|
||||
if (id === 'system') {
|
||||
onSystemPromptClick?.();
|
||||
return;
|
||||
}
|
||||
|
||||
onFileUpload?.();
|
||||
}
|
||||
|
||||
const triggerTooltipText = 'Add files or system message';
|
||||
const itemClass = 'flex cursor-pointer items-center gap-2';
|
||||
</script>
|
||||
|
||||
<div class="flex items-center gap-1 {className}">
|
||||
<DropdownMenu.Root>
|
||||
<DropdownMenu.Trigger name="Attach files" {disabled}>
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<Button
|
||||
class="file-upload-button h-8 w-8 rounded-full p-0"
|
||||
{disabled}
|
||||
variant="secondary"
|
||||
type="button"
|
||||
>
|
||||
<span class="sr-only">{triggerTooltipText}</span>
|
||||
|
||||
<Plus class="h-4 w-4" />
|
||||
</Button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
<p>{triggerTooltipText}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
</DropdownMenu.Trigger>
|
||||
|
||||
<DropdownMenu.Content align="start" class="w-56">
|
||||
{#each actions as item (item.id)}
|
||||
{@const hasDisabledTooltip = !!item.disabled && !!item.disabledReason}
|
||||
{@const hasEnabledTooltip = !item.disabled && !!item.tooltip}
|
||||
|
||||
{#if hasDisabledTooltip}
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<DropdownMenu.Item class={itemClass} disabled>
|
||||
{#if item.id === 'images'}
|
||||
<FILE_TYPE_ICONS.image class="h-4 w-4" />
|
||||
{:else if item.id === 'audio'}
|
||||
<FILE_TYPE_ICONS.audio class="h-4 w-4" />
|
||||
{:else if item.id === 'text'}
|
||||
<FILE_TYPE_ICONS.text class="h-4 w-4" />
|
||||
{:else if item.id === 'pdf'}
|
||||
<FILE_TYPE_ICONS.pdf class="h-4 w-4" />
|
||||
{:else}
|
||||
<MessageSquare class="h-4 w-4" />
|
||||
{/if}
|
||||
|
||||
<span>{item.label}</span>
|
||||
</DropdownMenu.Item>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side="right">
|
||||
<p>{item.disabledReason}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{:else if hasEnabledTooltip}
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<DropdownMenu.Item class={itemClass} onclick={() => handleActionClick(item.id)}>
|
||||
{#if item.id === 'images'}
|
||||
<FILE_TYPE_ICONS.image class="h-4 w-4" />
|
||||
{:else if item.id === 'audio'}
|
||||
<FILE_TYPE_ICONS.audio class="h-4 w-4" />
|
||||
{:else if item.id === 'text'}
|
||||
<FILE_TYPE_ICONS.text class="h-4 w-4" />
|
||||
{:else if item.id === 'pdf'}
|
||||
<FILE_TYPE_ICONS.pdf class="h-4 w-4" />
|
||||
{:else}
|
||||
<MessageSquare class="h-4 w-4" />
|
||||
{/if}
|
||||
|
||||
<span>{item.label}</span>
|
||||
</DropdownMenu.Item>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side="right">
|
||||
<p>{item.tooltip}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{:else}
|
||||
<DropdownMenu.Item class={itemClass} onclick={() => handleActionClick(item.id)}>
|
||||
{#if item.id === 'images'}
|
||||
<FILE_TYPE_ICONS.image class="h-4 w-4" />
|
||||
{:else if item.id === 'audio'}
|
||||
<FILE_TYPE_ICONS.audio class="h-4 w-4" />
|
||||
{:else if item.id === 'text'}
|
||||
<FILE_TYPE_ICONS.text class="h-4 w-4" />
|
||||
{:else if item.id === 'pdf'}
|
||||
<FILE_TYPE_ICONS.pdf class="h-4 w-4" />
|
||||
{:else}
|
||||
<MessageSquare class="h-4 w-4" />
|
||||
{/if}
|
||||
|
||||
<span>{item.label}</span>
|
||||
</DropdownMenu.Item>
|
||||
{/if}
|
||||
{/each}
|
||||
</DropdownMenu.Content>
|
||||
</DropdownMenu.Root>
|
||||
</div>
|
||||
@@ -2,7 +2,7 @@
|
||||
import { Square } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import {
|
||||
ChatFormActionFileAttachments,
|
||||
ChatFormActionAttachmentsDropdown,
|
||||
ChatFormActionRecord,
|
||||
ChatFormActionSubmit,
|
||||
ModelsSelector
|
||||
@@ -157,7 +157,7 @@
|
||||
|
||||
const { handleModelChange } = useModelChangeValidation({
|
||||
getRequiredModalities: () => usedModalities(),
|
||||
onValidationFailure: async (previousModelId) => {
|
||||
onValidationFailure: async (previousModelId: string | null) => {
|
||||
if (previousModelId) {
|
||||
await modelsStore.selectModelById(previousModelId);
|
||||
}
|
||||
@@ -166,32 +166,39 @@
|
||||
</script>
|
||||
|
||||
<div class="flex w-full items-center gap-3 {className}" style="container-type: inline-size">
|
||||
<ChatFormActionFileAttachments
|
||||
class="mr-auto"
|
||||
{disabled}
|
||||
{hasAudioModality}
|
||||
{hasVisionModality}
|
||||
{onFileUpload}
|
||||
{onSystemPromptClick}
|
||||
/>
|
||||
<div class="mr-auto flex items-center gap-2">
|
||||
<ChatFormActionAttachmentsDropdown
|
||||
{disabled}
|
||||
{hasAudioModality}
|
||||
{hasVisionModality}
|
||||
{onFileUpload}
|
||||
{onSystemPromptClick}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<ModelsSelector
|
||||
{disabled}
|
||||
bind:this={selectorModelRef}
|
||||
currentModel={conversationModel}
|
||||
forceForegroundText={true}
|
||||
useGlobalSelection={true}
|
||||
onModelChange={handleModelChange}
|
||||
/>
|
||||
<div class="ml-auto flex items-center gap-1.5">
|
||||
<ModelsSelector
|
||||
{disabled}
|
||||
bind:this={selectorModelRef}
|
||||
currentModel={conversationModel}
|
||||
forceForegroundText={true}
|
||||
useGlobalSelection={true}
|
||||
onModelChange={handleModelChange}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{#if isLoading}
|
||||
<Button
|
||||
type="button"
|
||||
variant="secondary"
|
||||
onclick={onStop}
|
||||
class="h-8 w-8 bg-transparent p-0 hover:bg-destructive/20"
|
||||
class="group h-8 w-8 rounded-full p-0 hover:bg-destructive/10!"
|
||||
>
|
||||
<span class="sr-only">Stop</span>
|
||||
<Square class="h-8 w-8 fill-destructive stroke-destructive" />
|
||||
|
||||
<Square
|
||||
class="h-8 w-8 fill-muted-foreground stroke-muted-foreground group-hover:fill-destructive group-hover:stroke-destructive hover:fill-destructive hover:stroke-destructive"
|
||||
/>
|
||||
</Button>
|
||||
{:else if shouldShowRecordButton}
|
||||
<ChatFormActionRecord {disabled} {hasAudioModality} {isLoading} {isRecording} {onMicClick} />
|
||||
|
||||
@@ -62,8 +62,8 @@
|
||||
assistantMessages: number;
|
||||
messageTypes: string[];
|
||||
} | null>(null);
|
||||
let editedContent = $state(message.content);
|
||||
let editedExtras = $state<DatabaseMessageExtra[]>(message.extra ? [...message.extra] : []);
|
||||
let editedContent = $derived(message.content);
|
||||
let editedExtras = $derived<DatabaseMessageExtra[]>(message.extra ? [...message.extra] : []);
|
||||
let editedUploadedFiles = $state<ChatUploadedFile[]>([]);
|
||||
let isEditing = $state(false);
|
||||
let showDeleteDialog = $state(false);
|
||||
|
||||
@@ -105,7 +105,7 @@
|
||||
|
||||
const { handleModelChange } = useModelChangeValidation({
|
||||
getRequiredModalities: () => conversationsStore.getModalitiesUpToMessage(message.id),
|
||||
onSuccess: (modelName) => onRegenerate(modelName)
|
||||
onSuccess: (modelName: string) => onRegenerate(modelName)
|
||||
});
|
||||
|
||||
function handleCopyModel() {
|
||||
|
||||
@@ -133,7 +133,7 @@
|
||||
|
||||
const { handleModelChange } = useModelChangeValidation({
|
||||
getRequiredModalities,
|
||||
onValidationFailure: async (previousModelId) => {
|
||||
onValidationFailure: async (previousModelId: string | null) => {
|
||||
if (previousModelId) {
|
||||
await modelsStore.selectModelById(previousModelId);
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@
|
||||
initialView = ChatMessageStatsView.GENERATION
|
||||
}: Props = $props();
|
||||
|
||||
let activeView: ChatMessageStatsView = $state(initialView);
|
||||
let activeView: ChatMessageStatsView = $derived(initialView);
|
||||
let hasAutoSwitchedToGeneration = $state(false);
|
||||
|
||||
// In live mode: auto-switch to GENERATION tab when prompt processing completes
|
||||
|
||||
@@ -35,6 +35,7 @@
|
||||
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
|
||||
import { isFileTypeSupported, filterFilesByModalities } from '$lib/utils';
|
||||
import { parseFilesToMessageExtras, processFilesToChatUploaded } from '$lib/utils/browser-only';
|
||||
import { ErrorDialogType } from '$lib/enums';
|
||||
import { onMount } from 'svelte';
|
||||
import { fade, fly, slide } from 'svelte/transition';
|
||||
import { Trash2, AlertTriangle, RefreshCw } from '@lucide/svelte';
|
||||
@@ -616,7 +617,7 @@
|
||||
contextInfo={activeErrorDialog?.contextInfo}
|
||||
onOpenChange={handleErrorDialogOpenChange}
|
||||
open={Boolean(activeErrorDialog)}
|
||||
type={activeErrorDialog?.type ?? 'server'}
|
||||
type={(activeErrorDialog?.type as ErrorDialogType) ?? ErrorDialogType.SERVER}
|
||||
/>
|
||||
|
||||
<style>
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
<script lang="ts">
|
||||
import ChatForm from '$lib/components/app/chat/ChatForm/ChatForm.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
initialMessage?: string;
|
||||
isLoading?: boolean;
|
||||
onFileRemove?: (fileId: string) => void;
|
||||
onFileUpload?: (files: File[]) => void;
|
||||
onSend?: (message: string, files?: ChatUploadedFile[]) => Promise<boolean>;
|
||||
onStop?: () => void;
|
||||
onSystemPromptAdd?: (draft: { message: string; files: ChatUploadedFile[] }) => void;
|
||||
showHelperText?: boolean;
|
||||
uploadedFiles?: ChatUploadedFile[];
|
||||
}
|
||||
|
||||
let {
|
||||
class: className,
|
||||
disabled = false,
|
||||
initialMessage = '',
|
||||
isLoading = false,
|
||||
onFileRemove,
|
||||
onFileUpload,
|
||||
onSend,
|
||||
onStop,
|
||||
onSystemPromptAdd,
|
||||
showHelperText = true,
|
||||
uploadedFiles = $bindable([])
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="relative mx-auto max-w-[48rem]">
|
||||
<ChatForm
|
||||
class={className}
|
||||
{disabled}
|
||||
{initialMessage}
|
||||
{isLoading}
|
||||
{onFileRemove}
|
||||
{onFileUpload}
|
||||
{onSend}
|
||||
{onStop}
|
||||
{onSystemPromptAdd}
|
||||
{showHelperText}
|
||||
bind:uploadedFiles
|
||||
/>
|
||||
</div>
|
||||
@@ -18,19 +18,24 @@
|
||||
} from '$lib/components/app';
|
||||
import { ScrollArea } from '$lib/components/ui/scroll-area';
|
||||
import { config, settingsStore } from '$lib/stores/settings.svelte';
|
||||
import {
|
||||
SETTINGS_SECTION_TITLES,
|
||||
type SettingsSectionTitle
|
||||
} from '$lib/constants/settings-sections';
|
||||
import { setMode } from 'mode-watcher';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
onSave?: () => void;
|
||||
initialSection?: SettingsSectionTitle;
|
||||
}
|
||||
|
||||
let { onSave }: Props = $props();
|
||||
let { onSave, initialSection }: Props = $props();
|
||||
|
||||
const settingSections: Array<{
|
||||
fields: SettingsFieldConfig[];
|
||||
icon: Component;
|
||||
title: string;
|
||||
title: SettingsSectionTitle;
|
||||
}> = [
|
||||
{
|
||||
title: 'General',
|
||||
@@ -285,7 +290,9 @@
|
||||
// }
|
||||
];
|
||||
|
||||
let activeSection = $state('General');
|
||||
let activeSection = $derived<SettingsSectionTitle>(
|
||||
initialSection ?? SETTINGS_SECTION_TITLES.GENERAL
|
||||
);
|
||||
let currentSection = $derived(
|
||||
settingSections.find((section) => section.title === activeSection) || settingSections[0]
|
||||
);
|
||||
@@ -295,6 +302,16 @@
|
||||
let canScrollRight = $state(false);
|
||||
let scrollContainer: HTMLDivElement | undefined = $state();
|
||||
|
||||
$effect(() => {
|
||||
if (!initialSection) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (settingSections.some((section) => section.title === initialSection)) {
|
||||
activeSection = initialSection;
|
||||
}
|
||||
});
|
||||
|
||||
function handleThemeChange(newTheme: string) {
|
||||
localConfig.theme = newTheme;
|
||||
|
||||
|
||||
@@ -142,7 +142,7 @@
|
||||
{
|
||||
icon: Download,
|
||||
label: 'Export',
|
||||
onclick: (e) => {
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
conversationsStore.downloadConversation(conversation.id);
|
||||
},
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer';
|
||||
import { rehypeEnhanceLinks } from '$lib/markdown/enhance-links';
|
||||
import { rehypeEnhanceCodeBlocks } from '$lib/markdown/enhance-code-blocks';
|
||||
import { rehypeResolveAttachmentImages } from '$lib/markdown/resolve-attachment-images';
|
||||
import { remarkLiteralHtml } from '$lib/markdown/literal-html';
|
||||
import { copyCodeToClipboard, preprocessLaTeX, getImageErrorFallbackHtml } from '$lib/utils';
|
||||
import {
|
||||
@@ -23,6 +24,7 @@
|
||||
DATA_ERROR_HANDLED_ATTR,
|
||||
BOOL_TRUE_STRING
|
||||
} from '$lib/constants/markdown';
|
||||
import { UrlPrefix } from '$lib/enums';
|
||||
import { FileTypeText } from '$lib/enums/files';
|
||||
import {
|
||||
highlightCode,
|
||||
@@ -33,8 +35,7 @@
|
||||
import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
|
||||
import githubLightCss from 'highlight.js/styles/github.css?inline';
|
||||
import { mode } from 'mode-watcher';
|
||||
import ActionIconsCodeBlock from '$lib/components/app/actions/ActionIconsCodeBlock.svelte';
|
||||
import DialogCodePreview from '$lib/components/app/misc/CodePreviewDialog.svelte';
|
||||
import { ActionIconsCodeBlock, DialogCodePreview } from '$lib/components/app';
|
||||
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
|
||||
import type { DatabaseMessageExtra } from '$lib/types/database';
|
||||
|
||||
@@ -100,6 +101,7 @@
|
||||
.use(rehypeRestoreTableHtml) // Restore limited HTML (e.g., <br>, <ul>) inside Markdown tables
|
||||
.use(rehypeEnhanceLinks) // Add target="_blank" to links
|
||||
.use(rehypeEnhanceCodeBlocks) // Wrap code blocks with header and actions
|
||||
.use(rehypeResolveAttachmentImages, { attachments })
|
||||
.use(rehypeStringify, { allowDangerousHtml: true }); // Convert to HTML string
|
||||
});
|
||||
|
||||
@@ -500,7 +502,10 @@
		if (!img || !img.src) return;

		// Don't handle data URLs or already-handled images
		if (img.src.startsWith('data:') || img.dataset[DATA_ERROR_HANDLED_ATTR] === BOOL_TRUE_STRING)
		if (
			img.src.startsWith(UrlPrefix.DATA) ||
			img.dataset[DATA_ERROR_HANDLED_ATTR] === BOOL_TRUE_STRING
		)
			return;
		img.dataset[DATA_ERROR_HANDLED_ATTR] = BOOL_TRUE_STRING;

@@ -1,10 +1,11 @@
<script lang="ts">
	import * as AlertDialog from '$lib/components/ui/alert-dialog';
	import { AlertTriangle, TimerOff } from '@lucide/svelte';
	import { ErrorDialogType } from '$lib/enums';

	interface Props {
		open: boolean;
		type: 'timeout' | 'server';
		type: ErrorDialogType;
		message: string;
		contextInfo?: { n_prompt_tokens: number; n_ctx: number };
		onOpenChange?: (open: boolean) => void;

@@ -12,7 +13,7 @@
	let { open = $bindable(), type, message, contextInfo, onOpenChange }: Props = $props();

	const isTimeout = $derived(type === 'timeout');
	const isTimeout = $derived(type === ErrorDialogType.TIMEOUT);
	const title = $derived(isTimeout ? 'TCP Timeout' : 'Server Error');
	const description = $derived(
		isTimeout
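The hunks above swap the 'timeout' | 'server' string union for an ErrorDialogType enum imported from $lib/enums. The enum definition itself is not shown in this diff; judging from the replaced union and the TIMEOUT member used above, it presumably looks like this (assumed, for orientation only):

// Assumed shape of the enum in $lib/enums, inferred from the replaced union type
export enum ErrorDialogType {
	TIMEOUT = 'timeout',
	SERVER = 'server'
}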
@@ -58,7 +59,12 @@
			<span class="font-medium">Prompt tokens:</span>
			{contextInfo.n_prompt_tokens.toLocaleString()}
		</p>
		<p><span class="font-medium">Context size:</span> {contextInfo.n_ctx.toLocaleString()}</p>
		{#if contextInfo.n_ctx}
			<p>
				<span class="font-medium">Context size:</span>
				{contextInfo.n_ctx.toLocaleString()}
			</p>
		{/if}
	</div>
{/if}
</div>

@@ -1,13 +1,15 @@
<script lang="ts">
	import * as Dialog from '$lib/components/ui/dialog';
	import { ChatSettings } from '$lib/components/app';
	import type { SettingsSectionTitle } from '$lib/constants/settings-sections';

	interface Props {
		onOpenChange?: (open: boolean) => void;
		open?: boolean;
		initialSection?: SettingsSectionTitle;
	}

	let { onOpenChange, open = false }: Props = $props();
	let { onOpenChange, open = false, initialSection }: Props = $props();

	let chatSettingsRef: ChatSettings | undefined = $state();

@@ -28,10 +30,9 @@

<Dialog.Root {open} onOpenChange={handleClose}>
	<Dialog.Content
		class="z-999999 flex h-[100dvh] max-h-[100dvh] min-h-[100dvh] flex-col gap-0 rounded-none p-0
		md:h-[64vh] md:max-h-[64vh] md:min-h-0 md:rounded-lg"
		style="max-width: 48rem;"
		class="z-999999 flex h-[100dvh] max-h-[100dvh] min-h-[100dvh] max-w-4xl! flex-col gap-0 rounded-none
		p-0 md:h-[64vh] md:max-h-[64vh] md:min-h-0 md:rounded-lg"
	>
		<ChatSettings bind:this={chatSettingsRef} onSave={handleSave} />
		<ChatSettings bind:this={chatSettingsRef} onSave={handleSave} {initialSection} />
	</Dialog.Content>
</Dialog.Root>
@@ -37,7 +37,7 @@
<iframe
	bind:this={iframeRef}
	title="Preview {language}"
	sandbox="allow-scripts"
	sandbox="allow-scripts allow-same-origin"
	class="code-preview-iframe"
></iframe>

@@ -1,6 +1,7 @@
<script lang="ts">
	import * as AlertDialog from '$lib/components/ui/alert-dialog';
	import type { Component } from 'svelte';
	import { KeyboardKey } from '$lib/enums';

	interface Props {
		open: boolean;

@@ -29,7 +30,7 @@
	}: Props = $props();

	function handleKeydown(event: KeyboardEvent) {
		if (event.key === 'Enter') {
		if (event.key === KeyboardKey.ENTER) {
			event.preventDefault();
			onConfirm();
		}

@@ -1,7 +1,7 @@
<script lang="ts">
	import * as Dialog from '$lib/components/ui/dialog';
	import * as Table from '$lib/components/ui/table';
	import { BadgeModality, CopyToClipboardIcon } from '$lib/components/app';
	import { BadgeModality, ActionIconCopyToClipboard } from '$lib/components/app';
	import { serverStore } from '$lib/stores/server.svelte';
	import { modelsStore, modelOptions, modelsLoading } from '$lib/stores/models.svelte';
	import { formatFileSize, formatParameters, formatNumber } from '$lib/utils';

@@ -47,6 +47,7 @@

<Dialog.Header>
	<Dialog.Title>Model Information</Dialog.Title>

	<Dialog.Description>Current model details and capabilities</Dialog.Description>
</Dialog.Header>

@@ -73,7 +74,7 @@
	{modelName}
</span>

<CopyToClipboardIcon
<ActionIconCopyToClipboard
	text={modelName || ''}
	canCopy={!!modelName}
	ariaLabel="Copy model name to clipboard"

@@ -97,7 +98,7 @@
	{serverProps.model_path}
</span>

<CopyToClipboardIcon
<ActionIconCopyToClipboard
	text={serverProps.model_path}
	ariaLabel="Copy model path to clipboard"
/>

@@ -105,17 +106,29 @@
</Table.Row>

<!-- Context Size -->
<Table.Row>
	<Table.Cell class="h-10 align-middle font-medium">Context Size</Table.Cell>
	<Table.Cell
		>{formatNumber(serverProps.default_generation_settings.n_ctx)} tokens</Table.Cell
	>
</Table.Row>
{#if serverProps?.default_generation_settings?.n_ctx}
	<Table.Row>
		<Table.Cell class="h-10 align-middle font-medium">Context Size</Table.Cell>

		<Table.Cell
			>{formatNumber(serverProps.default_generation_settings.n_ctx)} tokens</Table.Cell
		>
	</Table.Row>
{:else}
	<Table.Row>
		<Table.Cell class="h-10 align-middle font-medium text-red-500"
			>Context Size</Table.Cell
		>

		<Table.Cell class="text-red-500">Not available</Table.Cell>
	</Table.Row>
{/if}

<!-- Training Context -->
{#if modelMeta?.n_ctx_train}
	<Table.Row>
		<Table.Cell class="h-10 align-middle font-medium">Training Context</Table.Cell>

		<Table.Cell>{formatNumber(modelMeta.n_ctx_train)} tokens</Table.Cell>
	</Table.Row>
{/if}

@@ -124,6 +137,7 @@
{#if modelMeta?.size}
	<Table.Row>
		<Table.Cell class="h-10 align-middle font-medium">Model Size</Table.Cell>

		<Table.Cell>{formatFileSize(modelMeta.size)}</Table.Cell>
	</Table.Row>
{/if}

@@ -132,6 +146,7 @@
{#if modelMeta?.n_params}
	<Table.Row>
		<Table.Cell class="h-10 align-middle font-medium">Parameters</Table.Cell>

		<Table.Cell>{formatParameters(modelMeta.n_params)}</Table.Cell>
	</Table.Row>
{/if}

@@ -140,6 +155,7 @@
{#if modelMeta?.n_embd}
	<Table.Row>
		<Table.Cell class="align-middle font-medium">Embedding Size</Table.Cell>

		<Table.Cell>{formatNumber(modelMeta.n_embd)}</Table.Cell>
	</Table.Row>
{/if}

@@ -148,6 +164,7 @@
{#if modelMeta?.n_vocab}
	<Table.Row>
		<Table.Cell class="align-middle font-medium">Vocabulary Size</Table.Cell>

		<Table.Cell>{formatNumber(modelMeta.n_vocab)} tokens</Table.Cell>
	</Table.Row>
{/if}

@@ -163,6 +180,7 @@
<!-- Total Slots -->
<Table.Row>
	<Table.Cell class="align-middle font-medium">Parallel Slots</Table.Cell>

	<Table.Cell>{serverProps.total_slots}</Table.Cell>
</Table.Row>

@@ -170,6 +188,7 @@
{#if modalities.length > 0}
	<Table.Row>
		<Table.Cell class="align-middle font-medium">Modalities</Table.Cell>

		<Table.Cell>
			<div class="flex flex-wrap gap-1">
				<BadgeModality {modalities} />

@@ -181,6 +200,7 @@
<!-- Build Info -->
<Table.Row>
	<Table.Cell class="align-middle font-medium">Build Info</Table.Cell>

	<Table.Cell class="align-middle font-mono text-xs"
		>{serverProps.build_info}</Table.Cell
	>

@@ -190,6 +210,7 @@
{#if serverProps.chat_template}
	<Table.Row>
		<Table.Cell class="align-middle font-medium">Chat Template</Table.Cell>

		<Table.Cell class="py-10">
			<div class="max-h-120 overflow-y-auto rounded-md bg-muted p-4">
				<pre
@@ -0,0 +1,110 @@
<script lang="ts">
	import { Plus, Trash2 } from '@lucide/svelte';
	import { Input } from '$lib/components/ui/input';
	import { autoResizeTextarea } from '$lib/utils';
	import type { KeyValuePair } from '$lib/types';

	interface Props {
		class?: string;
		pairs: KeyValuePair[];
		onPairsChange: (pairs: KeyValuePair[]) => void;
		keyPlaceholder?: string;
		valuePlaceholder?: string;
		addButtonLabel?: string;
		emptyMessage?: string;
		sectionLabel?: string;
		sectionLabelOptional?: boolean;
	}

	let {
		class: className = '',
		pairs,
		onPairsChange,
		keyPlaceholder = 'Key',
		valuePlaceholder = 'Value',
		addButtonLabel = 'Add',
		emptyMessage = 'No items configured.',
		sectionLabel,
		sectionLabelOptional = true
	}: Props = $props();

	function addPair() {
		onPairsChange([...pairs, { key: '', value: '' }]);
	}

	function removePair(index: number) {
		onPairsChange(pairs.filter((_, i) => i !== index));
	}

	function updatePairKey(index: number, key: string) {
		const newPairs = [...pairs];
		newPairs[index] = { ...newPairs[index], key };
		onPairsChange(newPairs);
	}

	function updatePairValue(index: number, value: string) {
		const newPairs = [...pairs];
		newPairs[index] = { ...newPairs[index], value };
		onPairsChange(newPairs);
	}
</script>

<div class={className}>
	<div class="mb-2 flex items-center justify-between">
		{#if sectionLabel}
			<span class="text-xs font-medium">
				{sectionLabel}
				{#if sectionLabelOptional}
					<span class="text-muted-foreground">(optional)</span>
				{/if}
			</span>
		{/if}

		<button
			type="button"
			class="inline-flex cursor-pointer items-center gap-1 rounded-md px-1.5 py-1 text-xs text-muted-foreground hover:bg-muted hover:text-foreground"
			onclick={addPair}
		>
			<Plus class="h-3 w-3" />
			{addButtonLabel}
		</button>
	</div>
	{#if pairs.length > 0}
		<div class="space-y-3">
			{#each pairs as pair, index (index)}
				<div class="flex items-start gap-2">
					<Input
						type="text"
						placeholder={keyPlaceholder}
						value={pair.key}
						oninput={(e) => updatePairKey(index, e.currentTarget.value)}
						class="flex-1"
					/>

					<textarea
						use:autoResizeTextarea
						placeholder={valuePlaceholder}
						value={pair.value}
						oninput={(e) => {
							updatePairValue(index, e.currentTarget.value);
							autoResizeTextarea(e.currentTarget);
						}}
						class="flex-1 resize-none rounded-md border border-input bg-transparent px-3 py-2 text-sm leading-5 placeholder:text-muted-foreground focus-visible:ring-1 focus-visible:ring-ring focus-visible:outline-none"
						rows="1"
					></textarea>

					<button
						type="button"
						class="mt-1.5 shrink-0 cursor-pointer rounded-md p-1 text-muted-foreground hover:bg-destructive/10 hover:text-destructive"
						onclick={() => removePair(index)}
						aria-label="Remove item"
					>
						<Trash2 class="h-3.5 w-3.5" />
					</button>
				</div>
			{/each}
		</div>
	{:else}
		<p class="text-xs text-muted-foreground">{emptyMessage}</p>
	{/if}
</div>

@@ -46,7 +46,7 @@

<div class="relative {className}">
	<Search
		class="absolute top-1/2 left-3 h-4 w-4 -translate-y-1/2 transform text-muted-foreground"
		class="absolute top-1/2 left-3 z-10 h-4 w-4 -translate-y-1/2 transform text-muted-foreground"
	/>

	<Input
tools/server/webui/src/lib/components/app/forms/index.ts (new file, 30 lines)
@@ -0,0 +1,30 @@
/**
 *
 * FORMS & INPUTS
 *
 * Form-related utility components.
 *
 */

/**
 * **SearchInput** - Search field with clear button
 *
 * Input field optimized for search with clear button and keyboard handling.
 * Supports placeholder, autofocus, and change callbacks.
 */
export { default as SearchInput } from './SearchInput.svelte';

/**
 * **KeyValuePairs** - Editable key-value list
 *
 * Dynamic list of key-value pairs with add/remove functionality.
 * Used for HTTP headers, metadata, and configuration.
 *
 * **Features:**
 * - Add new pairs with button
 * - Remove individual pairs
 * - Customizable placeholders and labels
 * - Empty state message
 * - Auto-resize value textarea
 */
export { default as KeyValuePairs } from './KeyValuePairs.svelte';
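Given the Props interface of the new KeyValuePairs.svelte above, consuming the component would look roughly like this. This is an illustrative sketch, not code from this changeset; the headers state, labels, and placeholder strings are made up:

<script lang="ts">
	import { KeyValuePairs } from '$lib/components/app';
	import type { KeyValuePair } from '$lib/types';

	// Illustrative local state; the component itself is fully controlled via props
	let headers = $state<KeyValuePair[]>([{ key: 'Authorization', value: 'Bearer ...' }]);
</script>

<KeyValuePairs
	pairs={headers}
	onPairsChange={(next) => (headers = next)}
	keyPlaceholder="Header name"
	valuePlaceholder="Header value"
	addButtonLabel="Add header"
	sectionLabel="Custom HTTP headers"
	emptyMessage="No headers configured."
/>

Because edits flow out through onPairsChange rather than mutating the pairs prop, the parent stays the single owner of the list state.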
@@ -1,12 +1,20 @@
// Chat
export * from './actions';
export * from './badges';
export * from './content';
export * from './forms';
export * from './misc';
export * from './models';
export * from './navigation';
export * from './server';

// Chat
export { default as ChatAttachmentPreview } from './chat/ChatAttachments/ChatAttachmentPreview.svelte';
export { default as ChatAttachmentThumbnailFile } from './chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte';
export { default as ChatAttachmentThumbnailImage } from './chat/ChatAttachments/ChatAttachmentThumbnailImage.svelte';
export { default as ChatAttachmentsList } from './chat/ChatAttachments/ChatAttachmentsList.svelte';
export { default as ChatAttachmentsViewAll } from './chat/ChatAttachments/ChatAttachmentsViewAll.svelte';

export { default as ChatForm } from './chat/ChatForm/ChatForm.svelte';
export { default as ChatFormActionAttachmentsDropdown } from './chat/ChatForm/ChatFormActions/ChatFormActionAttachmentsDropdown.svelte';
export { default as ChatFormActionFileAttachments } from './chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte';
export { default as ChatFormActionRecord } from './chat/ChatForm/ChatFormActions/ChatFormActionRecord.svelte';
export { default as ChatFormActions } from './chat/ChatForm/ChatFormActions/ChatFormActions.svelte';
@@ -14,36 +22,38 @@ export { default as ChatFormActionSubmit } from './chat/ChatForm/ChatFormActions
export { default as ChatFormFileInputInvisible } from './chat/ChatForm/ChatFormFileInputInvisible.svelte';
export { default as ChatFormHelperText } from './chat/ChatForm/ChatFormHelperText.svelte';
export { default as ChatFormTextarea } from './chat/ChatForm/ChatFormTextarea.svelte';

export { default as ChatMessage } from './chat/ChatMessages/ChatMessage.svelte';
export { default as ChatMessageActions } from './chat/ChatMessages/ChatMessageActions.svelte';
export { default as ChatMessageAssistant } from './chat/ChatMessages/ChatMessageAssistant.svelte';
export { default as ChatMessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte';
export { default as ChatMessageEditForm } from './chat/ChatMessages/ChatMessageEditForm.svelte';
export { default as ChatMessageStatistics } from './chat/ChatMessages/ChatMessageStatistics.svelte';
export { default as ChatMessageSystem } from './chat/ChatMessages/ChatMessageSystem.svelte';
export { default as ChatMessageThinkingBlock } from './chat/ChatMessages/ChatMessageThinkingBlock.svelte';
export { default as ChatMessageUser } from './chat/ChatMessages/ChatMessageUser.svelte';
export { default as ChatMessages } from './chat/ChatMessages/ChatMessages.svelte';
export { default as MessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte';

export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte';
export { default as ChatScreenDragOverlay } from './chat/ChatScreen/ChatScreenDragOverlay.svelte';
export { default as ChatScreenForm } from './chat/ChatScreen/ChatScreenForm.svelte';
export { default as ChatScreenHeader } from './chat/ChatScreen/ChatScreenHeader.svelte';
export { default as ChatScreenProcessingInfo } from './chat/ChatScreen/ChatScreenProcessingInfo.svelte';

export { default as ChatSettings } from './chat/ChatSettings/ChatSettings.svelte';
export { default as ChatSettingsFooter } from './chat/ChatSettings/ChatSettingsFooter.svelte';
export { default as ChatSettingsFields } from './chat/ChatSettings/ChatSettingsFields.svelte';
export { default as ChatSettingsImportExportTab } from './chat/ChatSettings/ChatSettingsImportExportTab.svelte';
export { default as ChatSettingsParameterSourceIndicator } from './chat/ChatSettings/ChatSettingsParameterSourceIndicator.svelte';

export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte';
export { default as ChatSidebarActions } from './chat/ChatSidebar/ChatSidebarActions.svelte';
export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte';
export { default as ChatSidebarSearch } from './chat/ChatSidebar/ChatSidebarSearch.svelte';

// Dialogs

export { default as DialogChatAttachmentPreview } from './dialogs/DialogChatAttachmentPreview.svelte';
export { default as DialogChatAttachmentsViewAll } from './dialogs/DialogChatAttachmentsViewAll.svelte';
export { default as DialogChatError } from './dialogs/DialogChatError.svelte';
export { default as DialogChatSettings } from './dialogs/DialogChatSettings.svelte';
export { default as DialogCodePreview } from './dialogs/DialogCodePreview.svelte';
export { default as DialogConfirmation } from './dialogs/DialogConfirmation.svelte';
export { default as DialogConversationSelection } from './dialogs/DialogConversationSelection.svelte';
export { default as DialogConversationTitleUpdate } from './dialogs/DialogConversationTitleUpdate.svelte';
@@ -51,25 +61,8 @@ export { default as DialogEmptyFileAlert } from './dialogs/DialogEmptyFileAlert.
export { default as DialogModelInformation } from './dialogs/DialogModelInformation.svelte';
export { default as DialogModelNotAvailable } from './dialogs/DialogModelNotAvailable.svelte';

// Miscellaneous

export { default as ActionButton } from './misc/ActionButton.svelte';
export { default as ActionDropdown } from './misc/ActionDropdown.svelte';
export { default as BadgeChatStatistic } from './misc/BadgeChatStatistic.svelte';
export { default as BadgeInfo } from './misc/BadgeInfo.svelte';
export { default as ModelBadge } from './models/ModelBadge.svelte';
export { default as BadgeModality } from './misc/BadgeModality.svelte';
export { default as ConversationSelection } from './misc/ConversationSelection.svelte';
export { default as CopyToClipboardIcon } from './misc/CopyToClipboardIcon.svelte';
export { default as KeyboardShortcutInfo } from './misc/KeyboardShortcutInfo.svelte';
export { default as MarkdownContent } from './misc/MarkdownContent.svelte';
export { default as RemoveButton } from './misc/RemoveButton.svelte';
export { default as SearchInput } from './misc/SearchInput.svelte';
export { default as SyntaxHighlightedCode } from './misc/SyntaxHighlightedCode.svelte';
export { default as ModelsSelector } from './models/ModelsSelector.svelte';

// Server

export { default as ServerStatus } from './server/ServerStatus.svelte';
export { default as ServerErrorSplash } from './server/ServerErrorSplash.svelte';
export { default as ServerLoadingSplash } from './server/ServerLoadingSplash.svelte';
// Compatibility aliases
export { default as ActionButton } from './actions/ActionIcon.svelte';
export { default as ActionDropdown } from './navigation/DropdownMenuActions.svelte';
export { default as CopyToClipboardIcon } from './actions/ActionIconCopyToClipboard.svelte';
export { default as RemoveButton } from './actions/ActionIconRemove.svelte';
@@ -1,47 +0,0 @@
<script lang="ts">
	import { Button } from '$lib/components/ui/button';
	import * as Tooltip from '$lib/components/ui/tooltip';
	import type { Component } from 'svelte';

	interface Props {
		icon: Component;
		tooltip: string;
		variant?: 'default' | 'destructive' | 'outline' | 'secondary' | 'ghost' | 'link';
		size?: 'default' | 'sm' | 'lg' | 'icon';
		class?: string;
		disabled?: boolean;
		onclick: () => void;
		'aria-label'?: string;
	}

	let {
		icon,
		tooltip,
		variant = 'ghost',
		size = 'sm',
		class: className = '',
		disabled = false,
		onclick,
		'aria-label': ariaLabel
	}: Props = $props();
</script>

<Tooltip.Root>
	<Tooltip.Trigger>
		<Button
			{variant}
			{size}
			{disabled}
			{onclick}
			class="h-6 w-6 p-0 {className} flex"
			aria-label={ariaLabel || tooltip}
		>
			{@const IconComponent = icon}
			<IconComponent class="h-3 w-3" />
		</Button>
	</Tooltip.Trigger>

	<Tooltip.Content>
		<p>{tooltip}</p>
	</Tooltip.Content>
</Tooltip.Root>

@@ -1,86 +0,0 @@
<script lang="ts">
	import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
	import * as Tooltip from '$lib/components/ui/tooltip';
	import { KeyboardShortcutInfo } from '$lib/components/app';
	import type { Component } from 'svelte';

	interface ActionItem {
		icon: Component;
		label: string;
		onclick: (event: Event) => void;
		variant?: 'default' | 'destructive';
		disabled?: boolean;
		shortcut?: string[];
		separator?: boolean;
	}

	interface Props {
		triggerIcon: Component;
		triggerTooltip?: string;
		triggerClass?: string;
		actions: ActionItem[];
		align?: 'start' | 'center' | 'end';
		open?: boolean;
	}

	let {
		triggerIcon,
		triggerTooltip,
		triggerClass = '',
		actions,
		align = 'end',
		open = $bindable(false)
	}: Props = $props();
</script>

<DropdownMenu.Root bind:open>
	<DropdownMenu.Trigger
		class="flex h-6 w-6 cursor-pointer items-center justify-center rounded-md p-0 text-sm font-medium transition-colors hover:bg-accent hover:text-accent-foreground focus:bg-accent focus:text-accent-foreground focus:outline-none disabled:pointer-events-none disabled:opacity-50 data-[state=open]:bg-accent data-[state=open]:text-accent-foreground {triggerClass}"
		onclick={(e) => e.stopPropagation()}
	>
		{#if triggerTooltip}
			<Tooltip.Root>
				<Tooltip.Trigger>
					{@render iconComponent(triggerIcon, 'h-3 w-3')}
					<span class="sr-only">{triggerTooltip}</span>
				</Tooltip.Trigger>
				<Tooltip.Content>
					<p>{triggerTooltip}</p>
				</Tooltip.Content>
			</Tooltip.Root>
		{:else}
			{@render iconComponent(triggerIcon, 'h-3 w-3')}
		{/if}
	</DropdownMenu.Trigger>

	<DropdownMenu.Content {align} class="z-[999999] w-48">
		{#each actions as action, index (action.label)}
			{#if action.separator && index > 0}
				<DropdownMenu.Separator />
			{/if}

			<DropdownMenu.Item
				onclick={action.onclick}
				variant={action.variant}
				disabled={action.disabled}
				class="flex items-center justify-between hover:[&>kbd]:opacity-100"
			>
				<div class="flex items-center gap-2">
					{@render iconComponent(
						action.icon,
						`h-4 w-4 ${action.variant === 'destructive' ? 'text-destructive' : ''}`
					)}
					{action.label}
				</div>

				{#if action.shortcut}
					<KeyboardShortcutInfo keys={action.shortcut} variant={action.variant} />
				{/if}
			</DropdownMenu.Item>
		{/each}
	</DropdownMenu.Content>
</DropdownMenu.Root>

{#snippet iconComponent(IconComponent: Component, className: string)}
	<IconComponent class={className} />
{/snippet}

@@ -1,44 +0,0 @@
<script lang="ts">
	import { BadgeInfo } from '$lib/components/app';
	import * as Tooltip from '$lib/components/ui/tooltip';
	import { copyToClipboard } from '$lib/utils';
	import type { Component } from 'svelte';

	interface Props {
		class?: string;
		icon: Component;
		value: string | number;
		tooltipLabel?: string;
	}

	let { class: className = '', icon: Icon, value, tooltipLabel }: Props = $props();

	function handleClick() {
		void copyToClipboard(String(value));
	}
</script>

{#if tooltipLabel}
	<Tooltip.Root>
		<Tooltip.Trigger>
			<BadgeInfo class={className} onclick={handleClick}>
				{#snippet icon()}
					<Icon class="h-3 w-3" />
				{/snippet}

				{value}
			</BadgeInfo>
		</Tooltip.Trigger>
		<Tooltip.Content>
			<p>{tooltipLabel}</p>
		</Tooltip.Content>
	</Tooltip.Root>
{:else}
	<BadgeInfo class={className} onclick={handleClick}>
		{#snippet icon()}
			<Icon class="h-3 w-3" />
		{/snippet}

		{value}
	</BadgeInfo>
{/if}

@@ -1,27 +0,0 @@
<script lang="ts">
	import { cn } from '$lib/components/ui/utils';
	import type { Snippet } from 'svelte';

	interface Props {
		children: Snippet;
		class?: string;
		icon?: Snippet;
		onclick?: () => void;
	}

	let { children, class: className = '', icon, onclick }: Props = $props();
</script>

<button
	class={cn(
		'inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75',
		className
	)}
	{onclick}
>
	{#if icon}
		{@render icon()}
	{/if}

	{@render children()}
</button>

@@ -1,39 +0,0 @@
<script lang="ts">
	import { ModelModality } from '$lib/enums';
	import { MODALITY_ICONS, MODALITY_LABELS } from '$lib/constants/icons';
	import { cn } from '$lib/components/ui/utils';

	type DisplayableModality = ModelModality.VISION | ModelModality.AUDIO;

	interface Props {
		modalities: ModelModality[];
		class?: string;
	}

	let { modalities, class: className = '' }: Props = $props();

	// Filter to only modalities that have icons (VISION, AUDIO)
	const displayableModalities = $derived(
		modalities.filter(
			(m): m is DisplayableModality => m === ModelModality.VISION || m === ModelModality.AUDIO
		)
	);
</script>

{#each displayableModalities as modality, index (index)}
	{@const IconComponent = MODALITY_ICONS[modality]}
	{@const label = MODALITY_LABELS[modality]}

	<span
		class={cn(
			'inline-flex items-center gap-1 rounded-md bg-muted px-2 py-1 text-xs font-medium',
			className
		)}
	>
		{#if IconComponent}
			<IconComponent class="h-3 w-3" />
		{/if}

		{label}
	</span>
{/each}

@@ -1,18 +0,0 @@
<script lang="ts">
	import { Copy } from '@lucide/svelte';
	import { copyToClipboard } from '$lib/utils';

	interface Props {
		ariaLabel?: string;
		canCopy?: boolean;
		text: string;
	}

	let { ariaLabel = 'Copy to clipboard', canCopy = true, text }: Props = $props();
</script>

<Copy
	class="h-3 w-3 flex-shrink-0 cursor-{canCopy ? 'pointer' : 'not-allowed'}"
	aria-label={ariaLabel}
	onclick={() => canCopy && copyToClipboard(text)}
/>

@@ -1,88 +0,0 @@
<script lang="ts">
	import type { Snippet } from 'svelte';
	import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
	import { cn } from '$lib/components/ui/utils';
	import { SearchInput } from '$lib/components/app';

	interface Props {
		open?: boolean;
		onOpenChange?: (open: boolean) => void;
		placeholder?: string;
		searchValue?: string;
		onSearchChange?: (value: string) => void;
		onSearchKeyDown?: (event: KeyboardEvent) => void;
		align?: 'start' | 'center' | 'end';
		contentClass?: string;
		emptyMessage?: string;
		isEmpty?: boolean;
		disabled?: boolean;
		trigger: Snippet;
		children: Snippet;
		footer?: Snippet;
	}

	let {
		open = $bindable(false),
		onOpenChange,
		placeholder = 'Search...',
		searchValue = $bindable(''),
		onSearchChange,
		onSearchKeyDown,
		align = 'start',
		contentClass = 'w-72',
		emptyMessage = 'No items found',
		isEmpty = false,
		disabled = false,
		trigger,
		children,
		footer
	}: Props = $props();

	function handleOpenChange(newOpen: boolean) {
		open = newOpen;

		if (!newOpen) {
			searchValue = '';
			onSearchChange?.('');
		}

		onOpenChange?.(newOpen);
	}
</script>

<DropdownMenu.Root bind:open onOpenChange={handleOpenChange}>
	<DropdownMenu.Trigger
		{disabled}
		onclick={(e) => {
			e.preventDefault();
			e.stopPropagation();
		}}
	>
		{@render trigger()}
	</DropdownMenu.Trigger>

	<DropdownMenu.Content {align} class={cn(contentClass, 'pt-0')}>
		<div class="sticky top-0 z-10 mb-2 bg-popover p-1 pt-2">
			<SearchInput
				{placeholder}
				bind:value={searchValue}
				onInput={onSearchChange}
				onKeyDown={onSearchKeyDown}
			/>
		</div>

		<div class={cn('overflow-y-auto')}>
			{@render children()}

			{#if isEmpty}
				<div class="px-2 py-3 text-center text-sm text-muted-foreground">{emptyMessage}</div>
			{/if}
		</div>

		{#if footer}
			<DropdownMenu.Separator />

			{@render footer()}
		{/if}
	</DropdownMenu.Content>
</DropdownMenu.Root>
@@ -1,872 +0,0 @@
<script lang="ts">
	import { remark } from 'remark';
	import remarkBreaks from 'remark-breaks';
	import remarkGfm from 'remark-gfm';
	import remarkMath from 'remark-math';
	import rehypeHighlight from 'rehype-highlight';
	import remarkRehype from 'remark-rehype';
	import rehypeKatex from 'rehype-katex';
	import rehypeStringify from 'rehype-stringify';
	import type { Root as HastRoot, RootContent as HastRootContent } from 'hast';
	import type { Root as MdastRoot } from 'mdast';
	import { browser } from '$app/environment';
	import { onDestroy, tick } from 'svelte';
	import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer';
	import { rehypeEnhanceLinks } from '$lib/markdown/enhance-links';
	import { rehypeEnhanceCodeBlocks } from '$lib/markdown/enhance-code-blocks';
	import { remarkLiteralHtml } from '$lib/markdown/literal-html';
	import { copyCodeToClipboard, preprocessLaTeX } from '$lib/utils';
	import '$styles/katex-custom.scss';
	import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
	import githubLightCss from 'highlight.js/styles/github.css?inline';
	import { mode } from 'mode-watcher';
	import CodePreviewDialog from './CodePreviewDialog.svelte';

	interface Props {
		content: string;
		class?: string;
	}

	interface MarkdownBlock {
		id: string;
		html: string;
	}

	let { content, class: className = '' }: Props = $props();

	let containerRef = $state<HTMLDivElement>();
	let renderedBlocks = $state<MarkdownBlock[]>([]);
	let unstableBlockHtml = $state('');
	let previewDialogOpen = $state(false);
	let previewCode = $state('');
	let previewLanguage = $state('text');

	let pendingMarkdown: string | null = null;
	let isProcessing = false;

	const themeStyleId = `highlight-theme-${(window.idxThemeStyle = (window.idxThemeStyle ?? 0) + 1)}`;

	let processor = $derived(() => {
		return remark()
			.use(remarkGfm) // GitHub Flavored Markdown
			.use(remarkMath) // Parse $inline$ and $$block$$ math
			.use(remarkBreaks) // Convert line breaks to <br>
			.use(remarkLiteralHtml) // Treat raw HTML as literal text with preserved indentation
			.use(remarkRehype) // Convert Markdown AST to rehype
			.use(rehypeKatex) // Render math using KaTeX
			.use(rehypeHighlight) // Add syntax highlighting
			.use(rehypeRestoreTableHtml) // Restore limited HTML (e.g., <br>, <ul>) inside Markdown tables
			.use(rehypeEnhanceLinks) // Add target="_blank" to links
			.use(rehypeEnhanceCodeBlocks) // Wrap code blocks with header and actions
			.use(rehypeStringify, { allowDangerousHtml: true }); // Convert to HTML string
	});

	/**
	 * Removes click event listeners from copy and preview buttons.
	 * Called on component destroy.
	 */
	function cleanupEventListeners() {
		if (!containerRef) return;

		const copyButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
		const previewButtons = containerRef.querySelectorAll<HTMLButtonElement>('.preview-code-btn');

		for (const button of copyButtons) {
			button.removeEventListener('click', handleCopyClick);
		}

		for (const button of previewButtons) {
			button.removeEventListener('click', handlePreviewClick);
		}
	}

	/**
	 * Removes this component's highlight.js theme style from the document head.
	 * Called on component destroy to clean up injected styles.
	 */
	function cleanupHighlightTheme() {
		if (!browser) return;

		const existingTheme = document.getElementById(themeStyleId);
		existingTheme?.remove();
	}

	/**
	 * Loads the appropriate highlight.js theme based on dark/light mode.
	 * Injects a scoped style element into the document head.
	 * @param isDark - Whether to load the dark theme (true) or light theme (false)
	 */
	function loadHighlightTheme(isDark: boolean) {
		if (!browser) return;

		const existingTheme = document.getElementById(themeStyleId);
		existingTheme?.remove();

		const style = document.createElement('style');
		style.id = themeStyleId;
		style.textContent = isDark ? githubDarkCss : githubLightCss;

		document.head.appendChild(style);
	}

	/**
	 * Extracts code information from a button click target within a code block.
	 * @param target - The clicked button element
	 * @returns Object with rawCode and language, or null if extraction fails
	 */
	function getCodeInfoFromTarget(target: HTMLElement) {
		const wrapper = target.closest('.code-block-wrapper');

		if (!wrapper) {
			console.error('No wrapper found');
			return null;
		}

		const codeElement = wrapper.querySelector<HTMLElement>('code[data-code-id]');

		if (!codeElement) {
			console.error('No code element found in wrapper');
			return null;
		}

		const rawCode = codeElement.textContent ?? '';

		const languageLabel = wrapper.querySelector<HTMLElement>('.code-language');
		const language = languageLabel?.textContent?.trim() || 'text';

		return { rawCode, language };
	}

	/**
	 * Generates a unique identifier for a HAST node based on its position.
	 * Used for stable block identification during incremental rendering.
	 * @param node - The HAST root content node
	 * @param indexFallback - Fallback index if position is unavailable
	 * @returns Unique string identifier for the node
	 */
	function getHastNodeId(node: HastRootContent, indexFallback: number): string {
		const position = node.position;

		if (position?.start?.offset != null && position?.end?.offset != null) {
			return `hast-${position.start.offset}-${position.end.offset}`;
		}

		return `${node.type}-${indexFallback}`;
	}

	/**
	 * Handles click events on copy buttons within code blocks.
	 * Copies the raw code content to the clipboard.
	 * @param event - The click event from the copy button
	 */
	async function handleCopyClick(event: Event) {
		event.preventDefault();
		event.stopPropagation();

		const target = event.currentTarget as HTMLButtonElement | null;

		if (!target) {
			return;
		}

		const info = getCodeInfoFromTarget(target);

		if (!info) {
			return;
		}

		try {
			await copyCodeToClipboard(info.rawCode);
		} catch (error) {
			console.error('Failed to copy code:', error);
		}
	}

	/**
	 * Handles preview dialog open state changes.
	 * Clears preview content when dialog is closed.
	 * @param open - Whether the dialog is being opened or closed
	 */
	function handlePreviewDialogOpenChange(open: boolean) {
		previewDialogOpen = open;

		if (!open) {
			previewCode = '';
			previewLanguage = 'text';
		}
	}

	/**
	 * Handles click events on preview buttons within HTML code blocks.
	 * Opens a preview dialog with the rendered HTML content.
	 * @param event - The click event from the preview button
	 */
	function handlePreviewClick(event: Event) {
		event.preventDefault();
		event.stopPropagation();

		const target = event.currentTarget as HTMLButtonElement | null;

		if (!target) {
			return;
		}

		const info = getCodeInfoFromTarget(target);

		if (!info) {
			return;
		}

		previewCode = info.rawCode;
		previewLanguage = info.language;
		previewDialogOpen = true;
	}

	/**
	 * Processes markdown content into stable and unstable HTML blocks.
	 * Uses incremental rendering: stable blocks are cached, unstable block is re-rendered.
	 * @param markdown - The raw markdown string to process
	 */
	async function processMarkdown(markdown: string) {
		if (!markdown) {
			renderedBlocks = [];
			unstableBlockHtml = '';
			return;
		}

		const normalized = preprocessLaTeX(markdown);
		const processorInstance = processor();
		const ast = processorInstance.parse(normalized) as MdastRoot;
		const processedRoot = (await processorInstance.run(ast)) as HastRoot;
		const processedChildren = processedRoot.children ?? [];
		const stableCount = Math.max(processedChildren.length - 1, 0);
		const nextBlocks: MarkdownBlock[] = [];

		for (let index = 0; index < stableCount; index++) {
			const hastChild = processedChildren[index];
			const id = getHastNodeId(hastChild, index);
			const existing = renderedBlocks[index];

			if (existing && existing.id === id) {
				nextBlocks.push(existing);
				continue;
			}

			const html = stringifyProcessedNode(
				processorInstance,
				processedRoot,
				processedChildren[index]
			);

			nextBlocks.push({ id, html });
		}

		let unstableHtml = '';

		if (processedChildren.length > stableCount) {
			const unstableChild = processedChildren[stableCount];
			unstableHtml = stringifyProcessedNode(processorInstance, processedRoot, unstableChild);
		}

		renderedBlocks = nextBlocks;
		await tick(); // Force DOM sync before updating unstable HTML block
		unstableBlockHtml = unstableHtml;
	}

	/**
	 * Attaches click event listeners to copy and preview buttons in code blocks.
	 * Uses data-listener-bound attribute to prevent duplicate bindings.
	 */
	function setupCodeBlockActions() {
		if (!containerRef) return;

		const wrappers = containerRef.querySelectorAll<HTMLElement>('.code-block-wrapper');

		for (const wrapper of wrappers) {
			const copyButton = wrapper.querySelector<HTMLButtonElement>('.copy-code-btn');
			const previewButton = wrapper.querySelector<HTMLButtonElement>('.preview-code-btn');

			if (copyButton && copyButton.dataset.listenerBound !== 'true') {
				copyButton.dataset.listenerBound = 'true';
				copyButton.addEventListener('click', handleCopyClick);
			}

			if (previewButton && previewButton.dataset.listenerBound !== 'true') {
				previewButton.dataset.listenerBound = 'true';
				previewButton.addEventListener('click', handlePreviewClick);
			}
		}
	}

	/**
	 * Converts a single HAST node to an enhanced HTML string.
	 * Applies link and code block enhancements to the output.
	 * @param processorInstance - The remark/rehype processor instance
	 * @param processedRoot - The full processed HAST root (for context)
	 * @param child - The specific HAST child node to stringify
	 * @returns Enhanced HTML string representation of the node
	 */
	function stringifyProcessedNode(
		processorInstance: ReturnType<typeof processor>,
		processedRoot: HastRoot,
		child: unknown
	) {
		const root: HastRoot = {
			...(processedRoot as HastRoot),
			children: [child as never]
		};

		return processorInstance.stringify(root);
	}

	/**
	 * Queues markdown for processing with coalescing support.
	 * Only processes the latest markdown when multiple updates arrive quickly.
	 * @param markdown - The markdown content to render
	 */
	async function updateRenderedBlocks(markdown: string) {
		pendingMarkdown = markdown;

		if (isProcessing) {
			return;
		}

		isProcessing = true;

		try {
			while (pendingMarkdown !== null) {
				const nextMarkdown = pendingMarkdown;
				pendingMarkdown = null;

				await processMarkdown(nextMarkdown);
			}
		} catch (error) {
			console.error('Failed to process markdown:', error);
			renderedBlocks = [];
			unstableBlockHtml = markdown.replace(/\n/g, '<br>');
		} finally {
			isProcessing = false;
		}
	}

	$effect(() => {
		const currentMode = mode.current;
		const isDark = currentMode === 'dark';

		loadHighlightTheme(isDark);
	});

	$effect(() => {
		updateRenderedBlocks(content);
	});

	$effect(() => {
		const hasRenderedBlocks = renderedBlocks.length > 0;
		const hasUnstableBlock = Boolean(unstableBlockHtml);

		if ((hasRenderedBlocks || hasUnstableBlock) && containerRef) {
			setupCodeBlockActions();
		}
	});

	onDestroy(() => {
		cleanupEventListeners();
		cleanupHighlightTheme();
	});
</script>

<div bind:this={containerRef} class={className}>
	{#each renderedBlocks as block (block.id)}
		<div class="markdown-block" data-block-id={block.id}>
			<!-- eslint-disable-next-line no-at-html-tags -->
			{@html block.html}
		</div>
	{/each}

	{#if unstableBlockHtml}
		<div class="markdown-block markdown-block--unstable" data-block-id="unstable">
			<!-- eslint-disable-next-line no-at-html-tags -->
			{@html unstableBlockHtml}
		</div>
	{/if}
</div>

<CodePreviewDialog
	open={previewDialogOpen}
	code={previewCode}
	language={previewLanguage}
	onOpenChange={handlePreviewDialogOpenChange}
/>

<style>
	.markdown-block,
	.markdown-block--unstable {
		display: contents;
	}

	/* Base typography styles */
	div :global(p:not(:last-child)) {
		margin-bottom: 1rem;
		line-height: 1.75;
	}

	div :global(:is(h1, h2, h3, h4, h5, h6):first-child) {
		margin-top: 0;
	}

	/* Headers with consistent spacing */
	div :global(h1) {
		font-size: 1.875rem;
		font-weight: 700;
		line-height: 1.2;
		margin: 1.5rem 0 0.75rem 0;
	}

	div :global(h2) {
		font-size: 1.5rem;
		font-weight: 600;
		line-height: 1.3;
		margin: 1.25rem 0 0.5rem 0;
	}

	div :global(h3) {
		font-size: 1.25rem;
		font-weight: 600;
		margin: 1.5rem 0 0.5rem 0;
		line-height: 1.4;
	}

	div :global(h4) {
		font-size: 1.125rem;
		font-weight: 600;
		margin: 0.75rem 0 0.25rem 0;
	}

	div :global(h5) {
		font-size: 1rem;
		font-weight: 600;
		margin: 0.5rem 0 0.25rem 0;
	}

	div :global(h6) {
		font-size: 0.875rem;
		font-weight: 600;
		margin: 0.5rem 0 0.25rem 0;
	}

	/* Text formatting */
	div :global(strong) {
		font-weight: 600;
	}

	div :global(em) {
		font-style: italic;
	}

	div :global(del) {
		text-decoration: line-through;
		opacity: 0.7;
	}

	/* Inline code */
	div :global(code:not(pre code)) {
		background: var(--muted);
		color: var(--muted-foreground);
		padding: 0.125rem 0.375rem;
		border-radius: 0.375rem;
		font-size: 0.875rem;
		font-family:
			ui-monospace, SFMono-Regular, 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas,
			'Liberation Mono', Menlo, monospace;
	}

	/* Links */
	div :global(a) {
		color: var(--primary);
		text-decoration: underline;
		text-underline-offset: 2px;
		transition: color 0.2s ease;
		overflow-wrap: anywhere;
		word-break: break-all;
	}

	div :global(a:hover) {
		color: var(--primary);
	}

	/* Lists */
	div :global(ul) {
		list-style-type: disc;
		margin-left: 1.5rem;
		margin-bottom: 1rem;
	}

	div :global(ol) {
		list-style-type: decimal;
		margin-left: 1.5rem;
		margin-bottom: 1rem;
	}

	div :global(li) {
		margin-bottom: 0.25rem;
		padding-left: 0.5rem;
	}

	div :global(li::marker) {
		color: var(--muted-foreground);
	}

	/* Nested lists */
	div :global(ul ul) {
		list-style-type: circle;
		margin-top: 0.25rem;
		margin-bottom: 0.25rem;
	}

	div :global(ol ol) {
		list-style-type: lower-alpha;
		margin-top: 0.25rem;
		margin-bottom: 0.25rem;
	}

	/* Task lists */
	div :global(.task-list-item) {
		list-style: none;
		margin-left: 0;
		padding-left: 0;
	}

	div :global(.task-list-item-checkbox) {
		margin-right: 0.5rem;
		margin-top: 0.125rem;
	}

	/* Blockquotes */
	div :global(blockquote) {
		border-left: 4px solid var(--border);
		padding: 0.5rem 1rem;
		margin: 1.5rem 0;
		font-style: italic;
		color: var(--muted-foreground);
		background: var(--muted);
		border-radius: 0 0.375rem 0.375rem 0;
	}

	/* Tables */
	div :global(table) {
		width: 100%;
		margin: 1.5rem 0;
		border-collapse: collapse;
		border: 1px solid var(--border);
		border-radius: 0.375rem;
		overflow: hidden;
	}

	div :global(th) {
		background: hsl(var(--muted) / 0.3);
		border: 1px solid var(--border);
		padding: 0.5rem 0.75rem;
		text-align: left;
		font-weight: 600;
	}

	div :global(td) {
		border: 1px solid var(--border);
		padding: 0.5rem 0.75rem;
	}

	div :global(tr:nth-child(even)) {
		background: hsl(var(--muted) / 0.1);
	}

	/* User message markdown should keep table borders visible on light primary backgrounds */
	div.markdown-user-content :global(table),
	div.markdown-user-content :global(th),
	div.markdown-user-content :global(td),
	div.markdown-user-content :global(.table-wrapper) {
		border-color: currentColor;
	}

	/* Horizontal rules */
	div :global(hr) {
		border: none;
		border-top: 1px solid var(--border);
		margin: 1.5rem 0;
	}

	/* Images */
	div :global(img) {
		border-radius: 0.5rem;
		box-shadow:
			0 1px 3px 0 rgb(0 0 0 / 0.1),
			0 1px 2px -1px rgb(0 0 0 / 0.1);
		margin: 1.5rem 0;
		max-width: 100%;
		height: auto;
	}

	/* Code blocks */

	div :global(.code-block-wrapper) {
		margin: 1.5rem 0;
		border-radius: 0.75rem;
		overflow: hidden;
		border: 1px solid var(--border);
		background: var(--code-background);
	}

	div :global(.code-block-header) {
		display: flex;
		justify-content: space-between;
		align-items: center;
		padding: 0.5rem 1rem;
		background: hsl(var(--muted) / 0.5);
		border-bottom: 1px solid var(--border);
		font-size: 0.875rem;
	}

	div :global(.code-language) {
		color: var(--code-foreground);
		font-weight: 500;
		font-family:
			ui-monospace, SFMono-Regular, 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas,
			'Liberation Mono', Menlo, monospace;
		text-transform: uppercase;
		font-size: 0.75rem;
		letter-spacing: 0.05em;
	}

	div :global(.code-block-actions) {
		display: flex;
		align-items: center;
		gap: 0.5rem;
	}

	div :global(.copy-code-btn),
	div :global(.preview-code-btn) {
		display: flex;
		align-items: center;
		justify-content: center;
		padding: 0;
		background: transparent;
		color: var(--code-foreground);
		cursor: pointer;
		transition: all 0.2s ease;
	}

	div :global(.copy-code-btn:hover),
	div :global(.preview-code-btn:hover) {
		transform: scale(1.05);
	}

	div :global(.copy-code-btn:active),
	div :global(.preview-code-btn:active) {
		transform: scale(0.95);
	}

	div :global(.code-block-wrapper pre) {
		background: transparent;
		padding: 1rem;
		margin: 0;
		overflow-x: auto;
		border-radius: 0;
		border: none;
		font-size: 0.875rem;
		line-height: 1.5;
	}

	div :global(pre) {
		background: var(--muted);
		margin: 1.5rem 0;
		overflow-x: auto;
		border-radius: 1rem;
		border: none;
	}

	div :global(code) {
		background: transparent;
		color: var(--code-foreground);
	}

	/* Mentions and hashtags */
	div :global(.mention) {
		color: hsl(var(--primary));
		font-weight: 500;
		text-decoration: none;
	}

	div :global(.mention:hover) {
		text-decoration: underline;
	}

	div :global(.hashtag) {
		color: hsl(var(--primary));
		font-weight: 500;
		text-decoration: none;
	}

	div :global(.hashtag:hover) {
		text-decoration: underline;
	}

	/* Advanced table enhancements */
	div :global(table) {
		transition: all 0.2s ease;
	}

	div :global(table:hover) {
		box-shadow:
			0 4px 6px -1px rgb(0 0 0 / 0.1),
			0 2px 4px -2px rgb(0 0 0 / 0.1);
	}

	div :global(th:hover),
	div :global(td:hover) {
		background: var(--muted);
	}

	/* Disable hover effects when rendering user messages */
	.markdown-user-content :global(a),
	.markdown-user-content :global(a:hover) {
		color: var(--primary-foreground);
	}

	.markdown-user-content :global(table:hover) {
		box-shadow: none;
	}

	.markdown-user-content :global(th:hover),
	.markdown-user-content :global(td:hover) {
		background: inherit;
	}

	/* Enhanced blockquotes */
	div :global(blockquote) {
		transition: all 0.2s ease;
		position: relative;
	}

	div :global(blockquote:hover) {
		border-left-width: 6px;
		background: var(--muted);
		transform: translateX(2px);
	}

	div :global(blockquote::before) {
		content: '"';
		position: absolute;
		top: -0.5rem;
		left: 0.5rem;
		font-size: 3rem;
		color: var(--muted-foreground);
		font-family: serif;
		line-height: 1;
	}

	/* Enhanced images */
	div :global(img) {
		transition: all 0.3s ease;
		cursor: pointer;
	}

	div :global(img:hover) {
		transform: scale(1.02);
		box-shadow:
			0 10px 15px -3px rgb(0 0 0 / 0.1),
			0 4px 6px -4px rgb(0 0 0 / 0.1);
	}

	/* Image zoom overlay */
	div :global(.image-zoom-overlay) {
		position: fixed;
		top: 0;
		left: 0;
		right: 0;
		bottom: 0;
		background: rgba(0, 0, 0, 0.8);
		display: flex;
		align-items: center;
		justify-content: center;
		z-index: 1000;
		cursor: pointer;
	}

	div :global(.image-zoom-overlay img) {
		max-width: 90vw;
		max-height: 90vh;
		border-radius: 0.5rem;
		box-shadow: 0 25px 50px -12px rgb(0 0 0 / 0.25);
	}

	/* Enhanced horizontal rules */
	div :global(hr) {
		border: none;
		height: 2px;
		background: linear-gradient(to right, transparent, var(--border), transparent);
		margin: 2rem 0;
		position: relative;
	}

	div :global(hr::after) {
		content: '';
		position: absolute;
		top: 50%;
		left: 50%;
		transform: translate(-50%, -50%);
		width: 1rem;
		height: 1rem;
		background: var(--border);
		border-radius: 50%;
	}

	/* Scrollable tables */
	div :global(.table-wrapper) {
		overflow-x: auto;
		margin: 1.5rem 0;
		border-radius: 0.5rem;
		border: 1px solid var(--border);
	}

	div :global(.table-wrapper table) {
		margin: 0;
		border: none;
	}

	/* Responsive adjustments */
	@media (max-width: 640px) {
		div :global(h1) {
			font-size: 1.5rem;
		}

		div :global(h2) {
			font-size: 1.25rem;
		}

		div :global(h3) {
			font-size: 1.125rem;
		}

		div :global(table) {
			font-size: 0.875rem;
		}

		div :global(th),
		div :global(td) {
			padding: 0.375rem 0.5rem;
		}

		div :global(.table-wrapper) {
			margin: 0.5rem -1rem;
			border-radius: 0;
			border-left: none;
			border-right: none;
		}
	}

	/* Dark mode adjustments */
	@media (prefers-color-scheme: dark) {
		div :global(blockquote:hover) {
			background: var(--muted);
		}
	}
</style>
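The deleted misc/MarkdownContent.svelte above (apparently superseded by the copy edited near the top of this diff) coalesces rapid streaming updates in updateRenderedBlocks. Reduced to its core, the pattern is the following; this is a minimal standalone sketch distilled from the code above, not the component itself:

// Update-coalescing queue: while one render pass is in flight, newer input
// simply overwrites `pending`; only the latest value is rendered next.
let pending: string | null = null;
let busy = false;

async function enqueue(markdown: string, render: (md: string) => Promise<void>) {
	pending = markdown;
	if (busy) return;

	busy = true;
	try {
		while (pending !== null) {
			const next = pending;
			pending = null;
			await render(next);
		}
	} finally {
		busy = false;
	}
}

This guarantees at most one in-flight render while still converging on the most recent markdown, which is why the component never falls behind during token streaming.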
@@ -1,26 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { X } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
|
||||
interface Props {
|
||||
id: string;
|
||||
onRemove?: (id: string) => void;
|
||||
class?: string;
|
||||
}
|
||||
|
||||
let { id, onRemove, class: className = '' }: Props = $props();
|
||||
</script>
|
||||
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
class="h-6 w-6 bg-white/20 p-0 hover:bg-white/30 {className}"
|
||||
onclick={(e) => {
|
||||
e.stopPropagation();
|
||||
onRemove?.(id);
|
||||
}}
|
||||
aria-label="Remove file"
|
||||
>
|
||||
<X class="h-3 w-3" />
|
||||
</Button>
|
||||
@@ -1,97 +0,0 @@
<script lang="ts">
	import hljs from 'highlight.js';
	import { browser } from '$app/environment';
	import { mode } from 'mode-watcher';

	import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
	import githubLightCss from 'highlight.js/styles/github.css?inline';

	interface Props {
		code: string;
		language?: string;
		class?: string;
		maxHeight?: string;
		maxWidth?: string;
	}

	let {
		code,
		language = 'text',
		class: className = '',
		maxHeight = '60vh',
		maxWidth = ''
	}: Props = $props();

	let highlightedHtml = $state('');

	function loadHighlightTheme(isDark: boolean) {
		if (!browser) return;

		const existingThemes = document.querySelectorAll('style[data-highlight-theme-preview]');
		existingThemes.forEach((style) => style.remove());

		const style = document.createElement('style');
		style.setAttribute('data-highlight-theme-preview', 'true');
		style.textContent = isDark ? githubDarkCss : githubLightCss;

		document.head.appendChild(style);
	}

	$effect(() => {
		const currentMode = mode.current;
		const isDark = currentMode === 'dark';

		loadHighlightTheme(isDark);
	});

	$effect(() => {
		if (!code) {
			highlightedHtml = '';
			return;
		}

		try {
			// Check if the language is supported
			const lang = language.toLowerCase();
			const isSupported = hljs.getLanguage(lang);

			if (isSupported) {
				const result = hljs.highlight(code, { language: lang });
				highlightedHtml = result.value;
			} else {
				// Try auto-detection or fallback to plain text
				const result = hljs.highlightAuto(code);
				highlightedHtml = result.value;
			}
		} catch {
			// Fallback to escaped plain text
			highlightedHtml = code.replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;');
		}
	});
</script>

<div
	class="code-preview-wrapper overflow-auto rounded-lg border border-border bg-muted {className}"
	style="max-height: {maxHeight}; max-width: {maxWidth};"
>
	<!-- Needs to be formatted as single line for proper rendering -->
	<pre class="m-0 overflow-x-auto p-4"><code class="hljs text-sm leading-relaxed"
		>{@html highlightedHtml}</code
	></pre>
</div>

<style>
	.code-preview-wrapper {
		font-family:
			ui-monospace, SFMono-Regular, 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas,
			'Liberation Mono', Menlo, monospace;
	}

	.code-preview-wrapper pre {
		background: transparent;
	}

	.code-preview-wrapper code {
		background: transparent;
	}
</style>
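The removed component's highlighting logic is worth spelling out on its own: prefer the declared language when highlight.js recognizes it, fall back to auto-detection, and escape manually if highlighting throws. A standalone sketch of that chain (the function name is illustrative):

```ts
import hljs from 'highlight.js';

// Fallback chain used by the removed component: explicit language if
// highlight.js knows it, then auto-detection, then manual HTML escaping
// so the raw code is still safe to inject with {@html ...}.
function renderHighlighted(code: string, language = 'text'): string {
	try {
		const lang = language.toLowerCase();
		if (hljs.getLanguage(lang)) {
			return hljs.highlight(code, { language: lang }).value;
		}
		return hljs.highlightAuto(code).value;
	} catch {
		return code.replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;');
	}
}
```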
@@ -1,6 +1,6 @@
<script lang="ts">
	import { Package } from '@lucide/svelte';
	import { BadgeInfo, CopyToClipboardIcon } from '$lib/components/app';
	import { BadgeInfo, ActionIconCopyToClipboard } from '$lib/components/app';
	import { modelsStore } from '$lib/stores/models.svelte';
	import { serverStore } from '$lib/stores/server.svelte';
	import * as Tooltip from '$lib/components/ui/tooltip';
@@ -34,7 +34,7 @@
		{model}

		{#if showCopyIcon}
			<CopyToClipboardIcon text={model || ''} ariaLabel="Copy model name" />
			<ActionIconCopyToClipboard text={model || ''} ariaLabel="Copy model name" />
		{/if}
	</BadgeInfo>
{/snippet}
@@ -1,8 +1,8 @@
<script lang="ts">
	import { onMount, tick } from 'svelte';
	import { ChevronDown, EyeOff, Loader2, MicOff, Package, Power } from '@lucide/svelte';
	import { onMount } from 'svelte';
	import { ChevronDown, Loader2, Package, Power } from '@lucide/svelte';
	import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
	import * as Tooltip from '$lib/components/ui/tooltip';
	import * as Popover from '$lib/components/ui/popover';
	import { cn } from '$lib/components/ui/utils';
	import {
		modelsStore,
@@ -11,13 +11,15 @@
		modelsUpdating,
		selectedModelId,
		routerModels,
		propsCacheVersion,
		singleModelName
	} from '$lib/stores/models.svelte';
	import { usedModalities, conversationsStore } from '$lib/stores/conversations.svelte';
	import { ServerModelStatus } from '$lib/enums';
	import { KeyboardKey, ServerModelStatus } from '$lib/enums';
	import { isRouterMode } from '$lib/stores/server.svelte';
	import { DialogModelInformation, SearchInput } from '$lib/components/app';
	import {
		DialogModelInformation,
		DropdownMenuSearchable,
		TruncatedText
	} from '$lib/components/app';
	import type { ModelOption } from '$lib/types/models';

	interface Props {
@@ -29,11 +31,7 @@
		forceForegroundText?: boolean;
		/** When true, user's global selection takes priority over currentModel (for form selector) */
		useGlobalSelection?: boolean;
		/**
		 * When provided, only consider modalities from messages BEFORE this message.
		 * Used for regeneration - allows selecting models that don't support modalities
		 * used in later messages.
		 */
		/** Optional compatibility prop for context-aware selectors. */
		upToMessageId?: string;
	}

@@ -44,7 +42,8 @@
		disabled = false,
		forceForegroundText = false,
		useGlobalSelection = false,
		upToMessageId
		// eslint-disable-next-line @typescript-eslint/no-unused-vars
		upToMessageId: _upToMessageId = undefined
	}: Props = $props();

	let options = $derived(modelOptions());
@@ -57,74 +56,11 @@
	// Reactive router models state - needed for proper reactivity of status checks
	let currentRouterModels = $derived(routerModels());

	let requiredModalities = $derived(
		upToMessageId ? conversationsStore.getModalitiesUpToMessage(upToMessageId) : usedModalities()
	);

	function getModelStatus(modelId: string): ServerModelStatus | null {
		const model = currentRouterModels.find((m) => m.id === modelId);
		return (model?.status?.value as ServerModelStatus) ?? null;
	}

	/**
	 * Checks if a model supports all modalities used in the conversation.
	 * Returns true if the model can be selected, false if it should be disabled.
	 */
	function isModelCompatible(option: ModelOption): boolean {
		void propsCacheVersion();

		const modelModalities = modelsStore.getModelModalities(option.model);

		if (!modelModalities) {
			const status = getModelStatus(option.model);

			if (status === ServerModelStatus.LOADED) {
				if (requiredModalities.vision || requiredModalities.audio) return false;
			}

			return true;
		}

		if (requiredModalities.vision && !modelModalities.vision) return false;
		if (requiredModalities.audio && !modelModalities.audio) return false;

		return true;
	}

	/**
	 * Gets missing modalities for a model.
	 * Returns object with vision/audio booleans indicating what's missing.
	 */
	function getMissingModalities(option: ModelOption): { vision: boolean; audio: boolean } | null {
		void propsCacheVersion();

		const modelModalities = modelsStore.getModelModalities(option.model);

		if (!modelModalities) {
			const status = getModelStatus(option.model);

			if (status === ServerModelStatus.LOADED) {
				const missing = {
					vision: requiredModalities.vision,
					audio: requiredModalities.audio
				};

				if (missing.vision || missing.audio) return missing;
			}

			return null;
		}

		const missing = {
			vision: requiredModalities.vision && !modelModalities.vision,
			audio: requiredModalities.audio && !modelModalities.audio
		};

		if (!missing.vision && !missing.audio) return null;

		return missing;
	}

	let isHighlightedCurrentModelActive = $derived(
		!isRouter || !currentModel
			? false
@@ -142,7 +78,6 @@
	});

	let searchTerm = $state('');
	let searchInputRef = $state<HTMLInputElement | null>(null);
	let highlightedIndex = $state<number>(-1);

	let filteredOptions: ModelOption[] = $derived(
@@ -157,13 +92,6 @@
		})()
	);

	// Get indices of compatible options for keyboard navigation
	let compatibleIndices = $derived(
		filteredOptions
			.map((option, index) => (isModelCompatible(option) ? index : -1))
			.filter((i) => i !== -1)
	);

	// Reset highlighted index when search term changes
	$effect(() => {
		void searchTerm;
@@ -179,7 +107,7 @@
		});
	});

	// Handle changes to the model selector pop-down or the model dialog, depending on if the server is in
	// Handle changes to the model selector dropdown or the model dialog, depending on if the server is in
	// router mode or not.
	function handleOpenChange(open: boolean) {
		if (loading || updating) return;
@@ -190,11 +118,6 @@
		searchTerm = '';
		highlightedIndex = -1;

		// Focus search input after popover opens
		tick().then(() => {
			requestAnimationFrame(() => searchInputRef?.focus());
		});

		modelsStore.fetchRouterModels().then(() => {
			modelsStore.fetchModalitiesForLoadedModels();
		});
@@ -215,36 +138,32 @@
	function handleSearchKeyDown(event: KeyboardEvent) {
		if (event.isComposing) return;

		if (event.key === 'ArrowDown') {
		if (event.key === KeyboardKey.ARROW_DOWN) {
			event.preventDefault();
			if (compatibleIndices.length === 0) return;
			if (filteredOptions.length === 0) return;

			const currentPos = compatibleIndices.indexOf(highlightedIndex);
			if (currentPos === -1 || currentPos === compatibleIndices.length - 1) {
				highlightedIndex = compatibleIndices[0];
			if (highlightedIndex === -1 || highlightedIndex === filteredOptions.length - 1) {
				highlightedIndex = 0;
			} else {
				highlightedIndex = compatibleIndices[currentPos + 1];
				highlightedIndex += 1;
			}
		} else if (event.key === 'ArrowUp') {
		} else if (event.key === KeyboardKey.ARROW_UP) {
			event.preventDefault();
			if (compatibleIndices.length === 0) return;
			if (filteredOptions.length === 0) return;

			const currentPos = compatibleIndices.indexOf(highlightedIndex);
			if (currentPos === -1 || currentPos === 0) {
				highlightedIndex = compatibleIndices[compatibleIndices.length - 1];
			if (highlightedIndex === -1 || highlightedIndex === 0) {
				highlightedIndex = filteredOptions.length - 1;
			} else {
				highlightedIndex = compatibleIndices[currentPos - 1];
				highlightedIndex -= 1;
			}
		} else if (event.key === 'Enter') {
		} else if (event.key === KeyboardKey.ENTER) {
			event.preventDefault();
			if (highlightedIndex >= 0 && highlightedIndex < filteredOptions.length) {
				const option = filteredOptions[highlightedIndex];
				if (isModelCompatible(option)) {
					handleSelect(option.id);
				}
			} else if (compatibleIndices.length > 0) {
				// No selection - highlight first compatible option
				highlightedIndex = compatibleIndices[0];
				handleSelect(option.id);
			} else if (filteredOptions.length > 0) {
				// No selection - highlight first option
				highlightedIndex = 0;
			}
		}
	}
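The hunk above swaps navigation over `compatibleIndices` (which skipped incompatible models) for plain index stepping over `filteredOptions`, keyed on the new `KeyboardKey` enum. Both variants boil down to the same wrap-around pattern; a small sketch for clarity (the function name is illustrative, not from the diff):

```ts
// Wrap-around stepping as used by the ArrowDown/ArrowUp handlers above:
// -1 (nothing highlighted) and the boundary indices wrap to the other end.
function stepHighlight(current: number, count: number, delta: 1 | -1): number {
	if (count === 0) return -1;
	if (current === -1) return delta === 1 ? 0 : count - 1;
	return (current + delta + count) % count;
}
```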
@@ -347,68 +266,72 @@
{@const selectedOption = getDisplayOption()}

{#if isRouter}
	<Popover.Root bind:open={isOpen} onOpenChange={handleOpenChange}>
		<Popover.Trigger
			class={cn(
				`inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-muted-foreground/10 px-1.5 py-1 text-xs transition hover:text-foreground focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60`,
				!isCurrentModelInCache()
					? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
					: forceForegroundText
						? 'text-foreground'
						: isHighlightedCurrentModelActive
							? 'text-foreground'
							: 'text-muted-foreground',
				isOpen ? 'text-foreground' : ''
			)}
			style="max-width: min(calc(100cqw - 6.5rem), 32rem)"
	<DropdownMenu.Root bind:open={isOpen} onOpenChange={handleOpenChange}>
		<DropdownMenu.Trigger
			disabled={disabled || updating}
			onclick={(e) => {
				e.preventDefault();
				e.stopPropagation();
			}}
		>
			<Package class="h-3.5 w-3.5" />
			<button
				type="button"
				class={cn(
					`inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-muted-foreground/10 px-1.5 py-1 text-xs transition hover:text-foreground focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60`,
					!isCurrentModelInCache()
						? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
						: forceForegroundText
							? 'text-foreground'
							: isHighlightedCurrentModelActive
								? 'text-foreground'
								: 'text-muted-foreground',
					isOpen ? 'text-foreground' : ''
				)}
				style="max-width: min(calc(100cqw - 9rem), 20rem)"
				disabled={disabled || updating}
			>
				<Package class="h-3.5 w-3.5" />

			<span class="truncate font-medium">
				{selectedOption?.model || 'Select model'}
			</span>
				<TruncatedText
					text={selectedOption?.model || 'Select model'}
					class="min-w-0 font-medium"
				/>

			{#if updating}
				<Loader2 class="h-3 w-3.5 animate-spin" />
			{:else}
				<ChevronDown class="h-3 w-3.5" />
			{/if}
		</Popover.Trigger>
				{#if updating}
					<Loader2 class="h-3 w-3.5 animate-spin" />
				{:else}
					<ChevronDown class="h-3 w-3.5" />
				{/if}
			</button>
		</DropdownMenu.Trigger>

		<Popover.Content
			class="group/popover-content w-96 max-w-[calc(100vw-2rem)] p-0"
		<DropdownMenu.Content
			align="end"
			sideOffset={8}
			collisionPadding={16}
			class="w-full max-w-[100vw] pt-0 sm:w-max sm:max-w-[calc(100vw-2rem)]"
		>
			<div class="flex max-h-[50dvh] flex-col overflow-hidden">
				<div
					class="order-1 shrink-0 border-b p-4 group-data-[side=top]/popover-content:order-2 group-data-[side=top]/popover-content:border-t group-data-[side=top]/popover-content:border-b-0"
				>
					<SearchInput
						id="model-search"
						placeholder="Search models..."
						bind:value={searchTerm}
						bind:ref={searchInputRef}
						onClose={() => handleOpenChange(false)}
						onKeyDown={handleSearchKeyDown}
					/>
				</div>
				<div
					class="models-list order-2 min-h-0 flex-1 overflow-y-auto group-data-[side=top]/popover-content:order-1"
				>
			<DropdownMenuSearchable
				bind:searchValue={searchTerm}
				placeholder="Search models..."
				onSearchKeyDown={handleSearchKeyDown}
				emptyMessage="No models found."
				isEmpty={filteredOptions.length === 0 && isCurrentModelInCache()}
			>
				<div class="models-list">
					{#if !isCurrentModelInCache() && currentModel}
						<!-- Show unavailable model as first option (disabled) -->
						<button
							type="button"
							class="flex w-full cursor-not-allowed items-center bg-red-400/10 px-4 py-2 text-left text-sm text-red-400"
							class="flex w-full cursor-not-allowed items-center bg-red-400/10 p-2 text-left text-sm text-red-400"
							role="option"
							aria-selected="true"
							aria-disabled="true"
							disabled
						>
							<span class="truncate">{selectedOption?.name || currentModel}</span>
							<span
								class="min-w-0 flex-1 truncate text-left sm:overflow-visible sm:text-clip sm:whitespace-nowrap"
							>
								{selectedOption?.name || currentModel}
							</span>
							<span class="ml-2 text-xs whitespace-nowrap opacity-70">(not available)</span>
						</button>
						<div class="my-1 h-px bg-border"></div>
@@ -421,104 +344,78 @@
						{@const isLoaded = status === ServerModelStatus.LOADED}
						{@const isLoading = status === ServerModelStatus.LOADING}
						{@const isSelected = currentModel === option.model || activeId === option.id}
						{@const isCompatible = isModelCompatible(option)}
						{@const isHighlighted = index === highlightedIndex}
						{@const missingModalities = getMissingModalities(option)}

						<div
							class={cn(
								'group flex w-full items-center gap-2 px-4 py-2 text-left text-sm transition focus:outline-none',
								isCompatible
									? 'cursor-pointer hover:bg-muted focus:bg-muted'
									: 'cursor-not-allowed opacity-50',
								'group flex w-full items-center gap-2 rounded-sm p-2 text-left text-sm transition focus:outline-none',
								'cursor-pointer hover:bg-muted focus:bg-muted',
								isSelected || isHighlighted
									? 'bg-accent text-accent-foreground'
									: isCompatible
										? 'hover:bg-accent hover:text-accent-foreground'
										: '',
									: 'hover:bg-accent hover:text-accent-foreground',
								isLoaded ? 'text-popover-foreground' : 'text-muted-foreground'
							)}
							role="option"
							aria-selected={isSelected || isHighlighted}
							aria-disabled={!isCompatible}
							tabindex={isCompatible ? 0 : -1}
							onclick={() => isCompatible && handleSelect(option.id)}
							tabindex="0"
							onclick={() => handleSelect(option.id)}
							onmouseenter={() => (highlightedIndex = index)}
							onkeydown={(e) => {
								if (isCompatible && (e.key === 'Enter' || e.key === ' ')) {
								if (e.key === 'Enter' || e.key === ' ') {
									e.preventDefault();
									handleSelect(option.id);
								}
							}}
						>
							<span class="min-w-0 flex-1 truncate">{option.model}</span>
							<span
								class="min-w-0 flex-1 truncate text-left sm:overflow-visible sm:pr-2 sm:text-clip sm:whitespace-nowrap"
							>
								{option.model}
							</span>

							{#if missingModalities}
								<span class="flex shrink-0 items-center gap-1 text-muted-foreground/70">
									{#if missingModalities.vision}
										<Tooltip.Root>
											<Tooltip.Trigger>
												<EyeOff class="h-3.5 w-3.5" />
											</Tooltip.Trigger>
											<Tooltip.Content class="z-[9999]">
												<p>No vision support</p>
											</Tooltip.Content>
										</Tooltip.Root>
									{/if}
									{#if missingModalities.audio}
										<Tooltip.Root>
											<Tooltip.Trigger>
												<MicOff class="h-3.5 w-3.5" />
											</Tooltip.Trigger>
											<Tooltip.Content class="z-[9999]">
												<p>No audio support</p>
											</Tooltip.Content>
										</Tooltip.Root>
									{/if}
								</span>
							{/if}

							{#if isLoading}
								<Tooltip.Root>
									<Tooltip.Trigger>
										<Loader2 class="h-4 w-4 shrink-0 animate-spin text-muted-foreground" />
									</Tooltip.Trigger>
									<Tooltip.Content class="z-[9999]">
										<p>Loading model...</p>
									</Tooltip.Content>
								</Tooltip.Root>
							{:else if isLoaded}
								<Tooltip.Root>
									<Tooltip.Trigger>
										<button
											type="button"
											class="relative ml-2 flex h-4 w-4 shrink-0 items-center justify-center"
											onclick={(e) => {
												e.stopPropagation();
												modelsStore.unloadModel(option.model);
											}}
										>
											<span
												class="mr-2 h-2 w-2 rounded-full bg-green-500 transition-opacity group-hover:opacity-0"
											></span>
											<Power
												class="absolute mr-2 h-4 w-4 text-red-500 opacity-0 transition-opacity group-hover:opacity-100 hover:text-red-600"
											/>
										</button>
									</Tooltip.Trigger>
									<Tooltip.Content class="z-[9999]">
										<p>Unload model</p>
									</Tooltip.Content>
								</Tooltip.Root>
							{:else}
								<span class="mx-2 h-2 w-2 rounded-full bg-muted-foreground/50"></span>
							{/if}
							<div class="flex w-6 shrink-0 justify-center">
								{#if isLoading}
									<Tooltip.Root>
										<Tooltip.Trigger>
											<Loader2 class="h-4 w-4 animate-spin text-muted-foreground" />
										</Tooltip.Trigger>
										<Tooltip.Content class="z-[9999]">
											<p>Loading model...</p>
										</Tooltip.Content>
									</Tooltip.Root>
								{:else if isLoaded}
									<Tooltip.Root>
										<Tooltip.Trigger>
											<button
												type="button"
												class="relative flex h-4 w-4 items-center justify-center"
												onclick={(e) => {
													e.stopPropagation();
													modelsStore.unloadModel(option.model);
												}}
											>
												<span
													class="h-2 w-2 rounded-full bg-green-500 transition-opacity group-hover:opacity-0"
												></span>
												<Power
													class="absolute h-4 w-4 text-red-500 opacity-0 transition-opacity group-hover:opacity-100 hover:text-red-600"
												/>
											</button>
										</Tooltip.Trigger>
										<Tooltip.Content class="z-[9999]">
											<p>Unload model</p>
										</Tooltip.Content>
									</Tooltip.Root>
								{:else}
									<span class="h-2 w-2 rounded-full bg-muted-foreground/50"></span>
								{/if}
							</div>
						</div>
					{/each}
				</div>
			</div>
		</Popover.Content>
	</Popover.Root>
			</DropdownMenuSearchable>
		</DropdownMenu.Content>
	</DropdownMenu.Root>
{:else}
	<button
		class={cn(
@@ -538,9 +435,7 @@
	>
		<Package class="h-3.5 w-3.5" />

		<span class="truncate font-medium">
			{selectedOption?.model}
		</span>
		<TruncatedText text={selectedOption?.model || ''} class="min-w-0 font-medium" />

		{#if updating}
			<Loader2 class="h-3 w-3.5 animate-spin" />
tools/server/webui/src/lib/components/app/models/index.ts (new file, 73 lines)
@@ -0,0 +1,73 @@
/**
 *
 * MODELS
 *
 * Components for model selection and display. Supports two server modes:
 * - **Single model mode**: Server runs with one model, selector shows model info
 * - **Router mode**: Server runs with multiple models, selector enables switching
 *
 * Integrates with modelsStore for model data and serverStore for mode detection.
 *
 */

/**
 * **ModelsSelector** - Model selection dropdown
 *
 * Dropdown for selecting AI models with status indicators,
 * search, and model information display. Adapts UI based on server mode.
 *
 * **Architecture:**
 * - Uses DropdownMenuSearchable for model list
 * - Integrates with modelsStore for model options and selection
 * - Detects router vs single mode from serverStore
 * - Opens DialogModelInformation for model details
 *
 * **Features:**
 * - Searchable model list with keyboard navigation
 * - Model status indicators (loading/ready/error/updating)
 * - Model capabilities badges (vision, tools, etc.)
 * - Current/active model highlighting
 * - Model information dialog on info button click
 * - Router mode: shows all available models with status
 * - Single mode: shows current model name only
 * - Loading/updating skeleton states
 * - Global selection support for form integration
 *
 * @example
 * ```svelte
 * <ModelsSelector
 *   currentModel={conversation.modelId}
 *   onModelChange={(id, name) => updateModel(id)}
 *   useGlobalSelection
 * />
 * ```
 */
export { default as ModelsSelector } from './ModelsSelector.svelte';

/**
 * **ModelBadge** - Model name display badge
 *
 * Compact badge showing current model name with package icon.
 * Only visible in single model mode. Supports tooltip and copy functionality.
 *
 * **Architecture:**
 * - Reads model name from modelsStore or prop
 * - Checks server mode from serverStore
 * - Uses BadgeInfo for consistent styling
 *
 * **Features:**
 * - Optional copy to clipboard button
 * - Optional tooltip with model details
 * - Click handler for model info dialog
 * - Only renders in model mode (not router)
 *
 * @example
 * ```svelte
 * <ModelBadge
 *   onclick={() => showModelInfo = true}
 *   showTooltip
 *   showCopyIcon
 * />
 * ```
 */
export { default as ModelBadge } from './ModelBadge.svelte';
@@ -8,7 +8,7 @@
import { serverStore, serverLoading } from '$lib/stores/server.svelte';
import { config, settingsStore } from '$lib/stores/settings.svelte';
import { fade, fly, scale } from 'svelte/transition';
import { KeyboardKey } from '$lib/enums/keyboard';
import { KeyboardKey } from '$lib/enums';

interface Props {
	class?: string;

@@ -1,8 +1,4 @@
export interface BinaryDetectionOptions {
	prefixLength: number;
	suspiciousCharThresholdRatio: number;
	maxAbsoluteNullBytes: number;
}
import type { BinaryDetectionOptions } from '$lib/types';

export const DEFAULT_BINARY_DETECTION_OPTIONS: BinaryDetectionOptions = {
	prefixLength: 1024 * 10, // Check the first 10KB of the string
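The hunk above only moves `BinaryDetectionOptions` into `$lib/types`; the detector that consumes it is not shown in this diff. A hedged sketch of how such options are typically applied — null bytes and control characters in the first `prefixLength` characters mark a string as binary (the function name and heuristics are assumptions):

```ts
import type { BinaryDetectionOptions } from '$lib/types';

// Illustrative detector (not from this diff): flag a string as binary when
// its prefix contains null bytes or too many non-printable characters
// relative to the configured thresholds.
function looksBinary(text: string, opts: BinaryDetectionOptions): boolean {
	const prefix = text.slice(0, opts.prefixLength);
	let nullBytes = 0;
	let suspicious = 0;

	for (const ch of prefix) {
		const code = ch.codePointAt(0) ?? 0;
		if (code === 0) nullBytes++;
		// Control characters other than tab/newline/carriage return.
		if (code < 32 && code !== 9 && code !== 10 && code !== 13) suspicious++;
	}

	return (
		nullBytes > opts.maxAbsoluteNullBytes ||
		suspicious / Math.max(prefix.length, 1) > opts.suspiciousCharThresholdRatio
	);
}
```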
tools/server/webui/src/lib/constants/cache.ts (new file, 33 lines)
@@ -0,0 +1,33 @@
/**
 * Cache configuration constants
 */

/**
 * Default TTL (Time-To-Live) for cache entries in milliseconds.
 */
export const DEFAULT_CACHE_TTL_MS = 5 * 60 * 1000;

/**
 * Default maximum number of entries in a cache.
 */
export const DEFAULT_CACHE_MAX_ENTRIES = 100;

/**
 * TTL for model props cache in milliseconds.
 */
export const MODEL_PROPS_CACHE_TTL_MS = 10 * 60 * 1000;

/**
 * Maximum number of model props to cache.
 */
export const MODEL_PROPS_CACHE_MAX_ENTRIES = 50;

/**
 * Maximum number of inactive conversation states to keep in memory.
 */
export const MAX_INACTIVE_CONVERSATION_STATES = 10;

/**
 * Maximum age (in ms) for inactive conversation states before cleanup.
 */
export const INACTIVE_CONVERSATION_STATE_MAX_AGE_MS = 30 * 60 * 1000;
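These constants parameterize TTL- and size-bounded caches; the cache class itself lives elsewhere in the stores. A minimal sketch of how constants like these are typically consumed (the `TtlCache` class is illustrative, not the store's actual implementation):

```ts
import { DEFAULT_CACHE_TTL_MS, DEFAULT_CACHE_MAX_ENTRIES } from '$lib/constants/cache';

// Illustrative TTL + max-entries cache; Map insertion order doubles as the
// eviction order, so the oldest entry is dropped when the cache is full.
class TtlCache<V> {
	private entries = new Map<string, { value: V; expiresAt: number }>();

	constructor(
		private ttlMs = DEFAULT_CACHE_TTL_MS,
		private maxEntries = DEFAULT_CACHE_MAX_ENTRIES
	) {}

	get(key: string): V | undefined {
		const entry = this.entries.get(key);
		if (!entry) return undefined;
		if (Date.now() > entry.expiresAt) {
			this.entries.delete(key);
			return undefined;
		}
		return entry.value;
	}

	set(key: string, value: V): void {
		if (this.entries.size >= this.maxEntries && !this.entries.has(key)) {
			const oldest = this.entries.keys().next().value;
			if (oldest !== undefined) this.entries.delete(oldest);
		}
		this.entries.set(key, { value, expiresAt: Date.now() + this.ttlMs });
	}
}
```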
@@ -1,6 +1 @@
export const INPUT_CLASSES = `
	bg-muted/70 dark:bg-muted/85
	border border-border/30 focus-within:border-border dark:border-border/20 dark:focus-within:border-border
	outline-none
	text-foreground
`;
export { INPUT_CLASSES } from './css-classes';
tools/server/webui/src/lib/constants/settings-sections.ts (new file, 14 lines)
@@ -0,0 +1,14 @@
/**
 * Settings section titles constants for ChatSettings component.
 */
export const SETTINGS_SECTION_TITLES = {
	GENERAL: 'General',
	DISPLAY: 'Display',
	SAMPLING: 'Sampling',
	PENALTIES: 'Penalties',
	IMPORT_EXPORT: 'Import/Export',
	DEVELOPER: 'Developer'
} as const;

export type SettingsSectionTitle =
	(typeof SETTINGS_SECTION_TITLES)[keyof typeof SETTINGS_SECTION_TITLES];

@@ -1,6 +1,13 @@
export { AttachmentType } from './attachment';

export { ChatMessageStatsView } from './chat';
export {
	ChatMessageStatsView,
	ReasoningFormat,
	MessageRole,
	MessageType,
	ContentPartType,
	ErrorDialogType
} from './chat';

export {
	FileTypeCategory,
@@ -21,3 +28,9 @@ export {
export { ModelModality } from './model';

export { ServerRole, ServerModelStatus } from './server';

export { ParameterSource, SyncableParameterType, SettingsFieldType } from './settings';

export { KeyboardKey } from './keyboard';

export { UrlPrefix } from './ui';
tools/server/webui/src/lib/enums/ui.ts (new file, 10 lines)
@@ -0,0 +1,10 @@
/**
 * URL prefixes for protocol detection.
 */
export enum UrlPrefix {
	DATA = 'data:',
	HTTP = 'http://',
	HTTPS = 'https://',
	WEBSOCKET = 'ws://',
	WEBSOCKET_SECURE = 'wss://'
}
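Because these are plain string prefixes, checks compose with `String.prototype.startsWith`. Note that `UrlPrefix.HTTP` (`'http://'`) does not match https URLs; a sketch of a helper covering both (assumed, not part of the diff):

```ts
import { UrlPrefix } from '$lib/enums';

// Illustrative helper: true for sources that should be left untouched
// (already-embedded data URLs or remote http/https URLs).
function isExternalSrc(src: string): boolean {
	return (
		src.startsWith(UrlPrefix.DATA) ||
		src.startsWith(UrlPrefix.HTTP) ||
		src.startsWith(UrlPrefix.HTTPS)
	);
}
```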
@@ -1,26 +1,21 @@
import { modelsStore } from '$lib/stores/models.svelte';
import { isRouterMode } from '$lib/stores/server.svelte';
import { toast } from 'svelte-sonner';
import type { ModelModalities } from '$lib/types';

interface UseModelChangeValidationOptions {
	/**
	 * Function to get required modalities for validation.
	 * For ChatForm: () => usedModalities() - all messages
	 * For ChatMessageAssistant: () => getModalitiesUpToMessage(messageId) - messages before
	 */
	getRequiredModalities: () => ModelModalities;

	/**
	 * Optional callback to execute after successful validation.
	 * For ChatForm: undefined - just select model
	 * For ChatMessageAssistant: (modelName) => onRegenerate(modelName)
	 */
	onSuccess?: (modelName: string) => void;

	/**
	 * Optional callback for rollback on validation failure.
	 * For ChatForm: (previousId) => selectModelById(previousId)
	 * For ChatMessageAssistant: undefined - no rollback needed
	 */
	onValidationFailure?: (previousModelId: string | null) => Promise<void>;
}

@@ -33,12 +28,10 @@ export function useModelChangeValidation(options: UseModelChangeValidationOption

	async function handleModelChange(modelId: string, modelName: string): Promise<boolean> {
		try {
			// Store previous selection for potential rollback
			if (onValidationFailure) {
				previousSelectedModelId = modelsStore.selectedModelId;
			}

			// Load model if not already loaded (router mode only)
			let hasLoadedModel = false;
			const isModelLoadedBefore = modelsStore.isModelLoaded(modelName);

@@ -52,13 +45,11 @@ export function useModelChangeValidation(options: UseModelChangeValidationOption
				}
			}

			// Fetch model props to validate modalities
			const props = await modelsStore.fetchModelProps(modelName);

			if (props?.modalities) {
				const requiredModalities = getRequiredModalities();

				// Check if model supports required modalities
				const missingModalities: string[] = [];
				if (requiredModalities.vision && !props.modalities.vision) {
					missingModalities.push('vision');
@@ -72,7 +63,6 @@ export function useModelChangeValidation(options: UseModelChangeValidationOption
						`Model "${modelName}" doesn't support required modalities: ${missingModalities.join(', ')}. Please select a different model.`
					);

					// Unload the model if we just loaded it
					if (isRouter && hasLoadedModel) {
						try {
							await modelsStore.unloadModel(modelName);
@@ -81,7 +71,6 @@ export function useModelChangeValidation(options: UseModelChangeValidationOption
						}
					}

					// Execute rollback callback if provided
					if (onValidationFailure && previousSelectedModelId) {
						await onValidationFailure(previousSelectedModelId);
					}
@@ -90,10 +79,8 @@ export function useModelChangeValidation(options: UseModelChangeValidationOption
				}
			}

			// Select the model (validation passed)
			await modelsStore.selectModelById(modelId);

			// Execute success callback if provided
			if (onSuccess) {
				onSuccess(modelName);
			}
@@ -103,7 +90,6 @@ export function useModelChangeValidation(options: UseModelChangeValidationOption
			console.error('Failed to change model:', error);
			toast.error('Failed to validate model capabilities');

			// Execute rollback callback on error if provided
			if (onValidationFailure && previousSelectedModelId) {
				await onValidationFailure(previousSelectedModelId);
			}
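How the hook is consumed is not part of these hunks; a sketch of the ChatForm-style call site described in the option comments above. The import path and the returned shape (exposing `handleModelChange`) are assumptions:

```ts
import { useModelChangeValidation } from '$lib/hooks/use-model-change-validation';
import { usedModalities } from '$lib/stores/conversations.svelte';
import { modelsStore } from '$lib/stores/models.svelte';

// ChatForm-style wiring: validate against all modalities used so far and
// roll the selection back if the new model cannot handle them.
const { handleModelChange } = useModelChangeValidation({
	getRequiredModalities: () => usedModalities(),
	onValidationFailure: async (previousModelId) => {
		if (previousModelId) await modelsStore.selectModelById(previousModelId);
	}
});

// Later, e.g. from the model selector:
// const ok = await handleModelChange(option.id, option.model);
```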
@@ -1,21 +1,7 @@
import { activeProcessingState } from '$lib/stores/chat.svelte';
import { config } from '$lib/stores/settings.svelte';
import { STATS_UNITS } from '$lib/constants/processing-info';
import type { ApiProcessingState } from '$lib/types';

interface LiveProcessingStats {
	tokensProcessed: number;
	totalTokens: number;
	timeMs: number;
	tokensPerSecond: number;
	etaSecs?: number;
}

interface LiveGenerationStats {
	tokensGenerated: number;
	timeMs: number;
	tokensPerSecond: number;
}
import type { ApiProcessingState, LiveProcessingStats, LiveGenerationStats } from '$lib/types';

export interface UseProcessingStateReturn {
	readonly processingState: ApiProcessingState | null;
@@ -0,0 +1,31 @@
import type { Root as HastRoot } from 'hast';
import { visit } from 'unist-util-visit';
import type { DatabaseMessageExtra, DatabaseMessageExtraImageFile } from '$lib/types/database';
import { AttachmentType, UrlPrefix } from '$lib/enums';

/**
 * Rehype plugin to resolve attachment image sources.
 * Converts attachment names to base64 data URLs.
 */
export function rehypeResolveAttachmentImages(options: { attachments?: DatabaseMessageExtra[] }) {
	return (tree: HastRoot) => {
		visit(tree, 'element', (node) => {
			if (node.tagName === 'img' && node.properties?.src) {
				const src = String(node.properties.src);

				if (src.startsWith(UrlPrefix.DATA) || src.startsWith(UrlPrefix.HTTP)) {
					return;
				}

				const attachment = options.attachments?.find(
					(a): a is DatabaseMessageExtraImageFile =>
						a.type === AttachmentType.IMAGE && a.name === src
				);

				if (attachment?.base64Url) {
					node.properties.src = attachment.base64Url;
				}
			}
		});
	};
}
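The plugin follows the standard rehype shape (a factory returning a tree transformer), so it slots into a unified pipeline. The surrounding pipeline is not shown in this diff; a hedged sketch of the wiring, with the plugin path and neighboring plugins assumed:

```ts
import { unified } from 'unified';
import remarkParse from 'remark-parse';
import remarkRehype from 'remark-rehype';
import rehypeStringify from 'rehype-stringify';
import type { DatabaseMessageExtra } from '$lib/types/database';
import { rehypeResolveAttachmentImages } from '$lib/markdown/rehype-resolve-attachment-images';

// Illustrative pipeline: attachment names in image sources are rewritten to
// base64 data URLs while the message markdown is rendered to HTML.
async function renderMarkdown(
	markdown: string,
	attachments?: DatabaseMessageExtra[]
): Promise<string> {
	const processor = unified()
		.use(remarkParse)
		.use(remarkRehype)
		.use(rehypeResolveAttachmentImages, { attachments })
		.use(rehypeStringify);

	return String(await processor.process(markdown));
}
```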
@@ -17,7 +17,7 @@ class LlamacppDatabase extends Dexie {

const db = new LlamacppDatabase();
import { v4 as uuid } from 'uuid';
import { MessageRole } from '$lib/enums/chat';
import { MessageRole } from '$lib/enums';

export class DatabaseService {
	/**
@@ -1,400 +0,0 @@
import Dexie, { type EntityTable } from 'dexie';
import { findDescendantMessages } from '$lib/utils';

class LlamacppDatabase extends Dexie {
	conversations!: EntityTable<DatabaseConversation, string>;
	messages!: EntityTable<DatabaseMessage, string>;

	constructor() {
		super('LlamacppWebui');

		this.version(1).stores({
			conversations: 'id, lastModified, currNode, name',
			messages: 'id, convId, type, role, timestamp, parent, children'
		});
	}
}

const db = new LlamacppDatabase();
import { v4 as uuid } from 'uuid';

/**
 * DatabaseService - Stateless IndexedDB communication layer
 *
 * **Terminology - Chat vs Conversation:**
 * - **Chat**: The active interaction space with the Chat Completions API (ephemeral, runtime).
 * - **Conversation**: The persistent database entity storing all messages and metadata.
 * This service handles raw database operations for conversations - the lowest layer
 * in the persistence stack.
 *
 * This service provides a stateless data access layer built on IndexedDB using Dexie ORM.
 * It handles all low-level storage operations for conversations and messages with support
 * for complex branching and message threading. All methods are static - no instance state.
 *
 * **Architecture & Relationships (bottom to top):**
 * - **DatabaseService** (this class): Stateless IndexedDB operations
 *   - Lowest layer - direct Dexie/IndexedDB communication
 *   - Pure CRUD operations without business logic
 *   - Handles branching tree structure (parent-child relationships)
 *   - Provides transaction safety for multi-table operations
 *
 * - **ConversationsService**: Stateless business logic layer
 *   - Uses DatabaseService for all persistence operations
 *   - Adds import/export, navigation, and higher-level operations
 *
 * - **conversationsStore**: Reactive state management for conversations
 *   - Uses ConversationsService for database operations
 *   - Manages conversation list, active conversation, and messages in memory
 *
 * - **chatStore**: Active AI interaction management
 *   - Uses conversationsStore for conversation context
 *   - Directly uses DatabaseService for message CRUD during streaming
 *
 * **Key Features:**
 * - **Conversation CRUD**: Create, read, update, delete conversations
 * - **Message CRUD**: Add, update, delete messages with branching support
 * - **Branch Operations**: Create branches, find descendants, cascade deletions
 * - **Transaction Safety**: Atomic operations for data consistency
 *
 * **Database Schema:**
 * - `conversations`: id, lastModified, currNode, name
 * - `messages`: id, convId, type, role, timestamp, parent, children
 *
 * **Branching Model:**
 * Messages form a tree structure where each message can have multiple children,
 * enabling conversation branching and alternative response paths. The conversation's
 * `currNode` tracks the currently active branch endpoint.
 */
export class DatabaseService {
	// ─────────────────────────────────────────────────────────────────────────────
	// Conversations
	// ─────────────────────────────────────────────────────────────────────────────

	/**
	 * Creates a new conversation.
	 *
	 * @param name - Name of the conversation
	 * @returns The created conversation
	 */
	static async createConversation(name: string): Promise<DatabaseConversation> {
		const conversation: DatabaseConversation = {
			id: uuid(),
			name,
			lastModified: Date.now(),
			currNode: ''
		};

		await db.conversations.add(conversation);
		return conversation;
	}

	// ─────────────────────────────────────────────────────────────────────────────
	// Messages
	// ─────────────────────────────────────────────────────────────────────────────

	/**
	 * Creates a new message branch by adding a message and updating parent/child relationships.
	 * Also updates the conversation's currNode to point to the new message.
	 *
	 * @param message - Message to add (without id)
	 * @param parentId - Parent message ID to attach to
	 * @returns The created message
	 */
	static async createMessageBranch(
		message: Omit<DatabaseMessage, 'id'>,
		parentId: string | null
	): Promise<DatabaseMessage> {
		return await db.transaction('rw', [db.conversations, db.messages], async () => {
			// Handle null parent (root message case)
			if (parentId !== null) {
				const parentMessage = await db.messages.get(parentId);
				if (!parentMessage) {
					throw new Error(`Parent message ${parentId} not found`);
				}
			}

			const newMessage: DatabaseMessage = {
				...message,
				id: uuid(),
				parent: parentId,
				toolCalls: message.toolCalls ?? '',
				children: []
			};

			await db.messages.add(newMessage);

			// Update parent's children array if parent exists
			if (parentId !== null) {
				const parentMessage = await db.messages.get(parentId);
				if (parentMessage) {
					await db.messages.update(parentId, {
						children: [...parentMessage.children, newMessage.id]
					});
				}
			}

			await this.updateConversation(message.convId, {
				currNode: newMessage.id
			});

			return newMessage;
		});
	}

	/**
	 * Creates a root message for a new conversation.
	 * Root messages are not displayed but serve as the tree root for branching.
	 *
	 * @param convId - Conversation ID
	 * @returns The created root message
	 */
	static async createRootMessage(convId: string): Promise<string> {
		const rootMessage: DatabaseMessage = {
			id: uuid(),
			convId,
			type: 'root',
			timestamp: Date.now(),
			role: 'system',
			content: '',
			parent: null,
			thinking: '',
			toolCalls: '',
			children: []
		};

		await db.messages.add(rootMessage);
		return rootMessage.id;
	}

	/**
	 * Creates a system prompt message for a conversation.
	 *
	 * @param convId - Conversation ID
	 * @param systemPrompt - The system prompt content (must be non-empty)
	 * @param parentId - Parent message ID (typically the root message)
	 * @returns The created system message
	 * @throws Error if systemPrompt is empty
	 */
	static async createSystemMessage(
		convId: string,
		systemPrompt: string,
		parentId: string
	): Promise<DatabaseMessage> {
		const trimmedPrompt = systemPrompt.trim();
		if (!trimmedPrompt) {
			throw new Error('Cannot create system message with empty content');
		}

		const systemMessage: DatabaseMessage = {
			id: uuid(),
			convId,
			type: 'system',
			timestamp: Date.now(),
			role: 'system',
			content: trimmedPrompt,
			parent: parentId,
			thinking: '',
			children: []
		};

		await db.messages.add(systemMessage);

		const parentMessage = await db.messages.get(parentId);
		if (parentMessage) {
			await db.messages.update(parentId, {
				children: [...parentMessage.children, systemMessage.id]
			});
		}

		return systemMessage;
	}

	/**
	 * Deletes a conversation and all its messages.
	 *
	 * @param id - Conversation ID
	 */
	static async deleteConversation(id: string): Promise<void> {
		await db.transaction('rw', [db.conversations, db.messages], async () => {
			await db.conversations.delete(id);
			await db.messages.where('convId').equals(id).delete();
		});
	}

	/**
	 * Deletes a message and removes it from its parent's children array.
	 *
	 * @param messageId - ID of the message to delete
	 */
	static async deleteMessage(messageId: string): Promise<void> {
		await db.transaction('rw', db.messages, async () => {
			const message = await db.messages.get(messageId);
			if (!message) return;

			// Remove this message from its parent's children array
			if (message.parent) {
				const parent = await db.messages.get(message.parent);
				if (parent) {
					parent.children = parent.children.filter((childId: string) => childId !== messageId);
					await db.messages.put(parent);
				}
			}

			// Delete the message
			await db.messages.delete(messageId);
		});
	}

	/**
	 * Deletes a message and all its descendant messages (cascading deletion).
	 * This removes the entire branch starting from the specified message.
	 *
	 * @param conversationId - ID of the conversation containing the message
	 * @param messageId - ID of the root message to delete (along with all descendants)
	 * @returns Array of all deleted message IDs
	 */
	static async deleteMessageCascading(
		conversationId: string,
		messageId: string
	): Promise<string[]> {
		return await db.transaction('rw', db.messages, async () => {
			// Get all messages in the conversation to find descendants
			const allMessages = await db.messages.where('convId').equals(conversationId).toArray();

			// Find all descendant messages
			const descendants = findDescendantMessages(allMessages, messageId);
			const allToDelete = [messageId, ...descendants];

			// Get the message to delete for parent cleanup
			const message = await db.messages.get(messageId);
			if (message && message.parent) {
				const parent = await db.messages.get(message.parent);
				if (parent) {
					parent.children = parent.children.filter((childId: string) => childId !== messageId);
					await db.messages.put(parent);
				}
			}

			// Delete all messages in the branch
			await db.messages.bulkDelete(allToDelete);

			return allToDelete;
		});
	}

	/**
	 * Gets all conversations, sorted by last modified time (newest first).
	 *
	 * @returns Array of conversations
	 */
	static async getAllConversations(): Promise<DatabaseConversation[]> {
		return await db.conversations.orderBy('lastModified').reverse().toArray();
	}

	/**
	 * Gets a conversation by ID.
	 *
	 * @param id - Conversation ID
	 * @returns The conversation if found, otherwise undefined
	 */
	static async getConversation(id: string): Promise<DatabaseConversation | undefined> {
		return await db.conversations.get(id);
	}

	/**
	 * Gets all messages in a conversation, sorted by timestamp (oldest first).
	 *
	 * @param convId - Conversation ID
	 * @returns Array of messages in the conversation
	 */
	static async getConversationMessages(convId: string): Promise<DatabaseMessage[]> {
		return await db.messages.where('convId').equals(convId).sortBy('timestamp');
	}

	/**
	 * Updates a conversation.
	 *
	 * @param id - Conversation ID
	 * @param updates - Partial updates to apply
	 * @returns Promise that resolves when the conversation is updated
	 */
	static async updateConversation(
		id: string,
		updates: Partial<Omit<DatabaseConversation, 'id'>>
	): Promise<void> {
		await db.conversations.update(id, {
			...updates,
			lastModified: Date.now()
		});
	}

	// ─────────────────────────────────────────────────────────────────────────────
	// Navigation
	// ─────────────────────────────────────────────────────────────────────────────

	/**
	 * Updates the conversation's current node (active branch).
	 * This determines which conversation path is currently being viewed.
	 *
	 * @param convId - Conversation ID
	 * @param nodeId - Message ID to set as current node
	 */
	static async updateCurrentNode(convId: string, nodeId: string): Promise<void> {
		await this.updateConversation(convId, {
			currNode: nodeId
		});
	}

	/**
	 * Updates a message.
	 *
	 * @param id - Message ID
	 * @param updates - Partial updates to apply
	 * @returns Promise that resolves when the message is updated
	 */
	static async updateMessage(
		id: string,
		updates: Partial<Omit<DatabaseMessage, 'id'>>
	): Promise<void> {
		await db.messages.update(id, updates);
	}

	// ─────────────────────────────────────────────────────────────────────────────
	// Import
	// ─────────────────────────────────────────────────────────────────────────────

	/**
	 * Imports multiple conversations and their messages.
	 * Skips conversations that already exist.
	 *
	 * @param data - Array of { conv, messages } objects
	 */
	static async importConversations(
		data: { conv: DatabaseConversation; messages: DatabaseMessage[] }[]
	): Promise<{ imported: number; skipped: number }> {
		let importedCount = 0;
		let skippedCount = 0;

		return await db.transaction('rw', [db.conversations, db.messages], async () => {
			for (const item of data) {
				const { conv, messages } = item;

				const existing = await db.conversations.get(conv.id);
				if (existing) {
					console.warn(`Conversation "${conv.name}" already exists, skipping...`);
					skippedCount++;
					continue;
				}

				await db.conversations.add(conv);
				for (const msg of messages) {
					await db.messages.put(msg);
				}

				importedCount++;
			}

			return { imported: importedCount, skipped: skippedCount };
		});
	}
}
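The branching model documented above (parent/children links plus a `currNode` pointer) implies that the visible thread is recovered by walking from `currNode` back to the root. That walk is not part of this file; a sketch under those assumptions:

```ts
// Illustrative: reconstruct the active conversation path by following
// `parent` links from the conversation's currNode up to the root message.
function getActivePath(messages: DatabaseMessage[], currNode: string): DatabaseMessage[] {
	const byId = new Map(messages.map((m) => [m.id, m]));
	const path: DatabaseMessage[] = [];

	let node = byId.get(currNode);
	while (node) {
		path.unshift(node);
		node = node.parent ? byId.get(node.parent) : undefined;
	}

	return path; // root first, currNode last
}
```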
@@ -1,5 +1,5 @@
export { ChatService } from './chat';
export { DatabaseService } from './database';
export { ModelsService } from './models';
export { PropsService } from './props';
export { ParameterSyncService } from './parameter-sync';
export { DatabaseService } from './database.service';
export { ModelsService } from './models.service';
export { PropsService } from './props.service';
export { ParameterSyncService, SYNCABLE_PARAMETERS } from './parameter-sync.service';

@@ -1,5 +1,5 @@
import { ServerModelStatus } from '$lib/enums';
import { apiFetch, apiPost } from '$lib/utils/api-fetch';
import { apiFetch, apiPost } from '$lib/utils';

export class ModelsService {
	/**
@@ -1,124 +0,0 @@
import { base } from '$app/paths';
import { ServerModelStatus } from '$lib/enums';
import { getJsonHeaders } from '$lib/utils';

/**
 * ModelsService - Stateless service for model management API communication
 *
 * This service handles communication with model-related endpoints:
 * - `/v1/models` - OpenAI-compatible model list (MODEL + ROUTER mode)
 * - `/models/load`, `/models/unload` - Router-specific model management (ROUTER mode only)
 *
 * **Responsibilities:**
 * - List available models
 * - Load/unload models (ROUTER mode)
 * - Check model status (ROUTER mode)
 *
 * **Used by:**
 * - modelsStore: Primary consumer for model state management
 */
export class ModelsService {
	// ─────────────────────────────────────────────────────────────────────────────
	// Listing
	// ─────────────────────────────────────────────────────────────────────────────

	/**
	 * Fetch list of models from OpenAI-compatible endpoint
	 * Works in both MODEL and ROUTER modes
	 */
	static async list(): Promise<ApiModelListResponse> {
		const response = await fetch(`${base}/v1/models`, {
			headers: getJsonHeaders()
		});

		if (!response.ok) {
			throw new Error(`Failed to fetch model list (status ${response.status})`);
		}

		return response.json() as Promise<ApiModelListResponse>;
	}

	/**
	 * Fetch list of all models with detailed metadata (ROUTER mode)
	 * Returns models with load status, paths, and other metadata
	 */
	static async listRouter(): Promise<ApiRouterModelsListResponse> {
		const response = await fetch(`${base}/v1/models`, {
			headers: getJsonHeaders()
		});

		if (!response.ok) {
			throw new Error(`Failed to fetch router models list (status ${response.status})`);
		}

		return response.json() as Promise<ApiRouterModelsListResponse>;
	}

	// ─────────────────────────────────────────────────────────────────────────────
	// Load/Unload
	// ─────────────────────────────────────────────────────────────────────────────

	/**
	 * Load a model (ROUTER mode)
	 * POST /models/load
	 * @param modelId - Model identifier to load
	 * @param extraArgs - Optional additional arguments to pass to the model instance
	 */
	static async load(modelId: string, extraArgs?: string[]): Promise<ApiRouterModelsLoadResponse> {
		const payload: { model: string; extra_args?: string[] } = { model: modelId };
		if (extraArgs && extraArgs.length > 0) {
			payload.extra_args = extraArgs;
		}

		const response = await fetch(`${base}/models/load`, {
			method: 'POST',
			headers: getJsonHeaders(),
			body: JSON.stringify(payload)
		});

		if (!response.ok) {
			const errorData = await response.json().catch(() => ({}));
			throw new Error(errorData.error || `Failed to load model (status ${response.status})`);
		}

		return response.json() as Promise<ApiRouterModelsLoadResponse>;
	}

	/**
	 * Unload a model (ROUTER mode)
	 * POST /models/unload
	 * @param modelId - Model identifier to unload
	 */
	static async unload(modelId: string): Promise<ApiRouterModelsUnloadResponse> {
		const response = await fetch(`${base}/models/unload`, {
			method: 'POST',
			headers: getJsonHeaders(),
			body: JSON.stringify({ model: modelId })
		});

		if (!response.ok) {
			const errorData = await response.json().catch(() => ({}));
			throw new Error(errorData.error || `Failed to unload model (status ${response.status})`);
		}

		return response.json() as Promise<ApiRouterModelsUnloadResponse>;
	}

	// ─────────────────────────────────────────────────────────────────────────────
	// Status
	// ─────────────────────────────────────────────────────────────────────────────

	/**
	 * Check if a model is loaded based on its metadata
	 */
	static isModelLoaded(model: ApiModelDataEntry): boolean {
		return model.status.value === ServerModelStatus.LOADED;
	}

	/**
	 * Check if a model is currently loading
	 */
	static isModelLoading(model: ApiModelDataEntry): boolean {
		return model.status.value === ServerModelStatus.LOADING;
	}
}
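Putting the removed service's load/status pieces together, a router-mode call site would look roughly like this. This is a sketch: the polling loop and the response shape (a `data` array of entries with `id` fields, as in the OpenAI-compatible endpoint) are assumptions, not code from this diff:

```ts
// Illustrative router-mode flow using the removed ModelsService API:
// load a model, then poll the model list until it reports LOADED.
async function ensureLoaded(modelId: string): Promise<void> {
	const { data } = await ModelsService.listRouter();
	const entry = data.find((m) => m.id === modelId);

	if (entry && ModelsService.isModelLoaded(entry)) return;

	await ModelsService.load(modelId);

	for (let i = 0; i < 30; i++) {
		const refreshed = (await ModelsService.listRouter()).data.find((m) => m.id === modelId);
		if (refreshed && ModelsService.isModelLoaded(refreshed)) return;
		await new Promise((resolve) => setTimeout(resolve, 1000));
	}

	throw new Error(`Model ${modelId} did not finish loading`);
}
```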
@@ -1,22 +1,6 @@
import { normalizeFloatingPoint } from '$lib/utils';
import { SyncableParameterType, ParameterSource } from '$lib/enums/settings';

type ParameterValue = string | number | boolean;
type ParameterRecord = Record<string, ParameterValue>;

interface ParameterInfo {
	value: string | number | boolean;
	source: ParameterSource;
	serverDefault?: string | number | boolean;
	userOverride?: string | number | boolean;
}

interface SyncableParameter {
	key: string;
	serverKey: string;
	type: SyncableParameterType;
	canSync: boolean;
}
import type { SyncableParameter, ParameterRecord, ParameterInfo, ParameterValue } from '$lib/types';
import { SyncableParameterType, ParameterSource } from '$lib/enums';

/**
 * Mapping of webui setting keys to server parameter keys.
@@ -1,148 +0,0 @@
import { describe, it, expect } from 'vitest';
import { ParameterSyncService } from './parameter-sync';

describe('ParameterSyncService', () => {
	describe('roundFloatingPoint', () => {
		it('should fix JavaScript floating-point precision issues', () => {
			// Test the specific values from the screenshot
			const mockServerParams = {
				top_p: 0.949999988079071,
				min_p: 0.009999999776482582,
				temperature: 0.800000011920929,
				top_k: 40,
				samplers: ['top_k', 'typ_p', 'top_p', 'min_p', 'temperature']
			};

			const result = ParameterSyncService.extractServerDefaults({
				...mockServerParams,
				// Add other required fields to match the API type
				n_predict: 512,
				seed: -1,
				dynatemp_range: 0.0,
				dynatemp_exponent: 1.0,
				xtc_probability: 0.0,
				xtc_threshold: 0.1,
				typ_p: 1.0,
				repeat_last_n: 64,
				repeat_penalty: 1.0,
				presence_penalty: 0.0,
				frequency_penalty: 0.0,
				dry_multiplier: 0.0,
				dry_base: 1.75,
				dry_allowed_length: 2,
				dry_penalty_last_n: -1,
				mirostat: 0,
				mirostat_tau: 5.0,
				mirostat_eta: 0.1,
				stop: [],
				max_tokens: -1,
				n_keep: 0,
				n_discard: 0,
				ignore_eos: false,
				stream: true,
				logit_bias: [],
				n_probs: 0,
				min_keep: 0,
				grammar: '',
				grammar_lazy: false,
				grammar_triggers: [],
				preserved_tokens: [],
				chat_format: '',
				reasoning_format: '',
				reasoning_in_content: false,
				thinking_forced_open: false,
				'speculative.n_max': 0,
				'speculative.n_min': 0,
				'speculative.p_min': 0.0,
				timings_per_token: false,
				post_sampling_probs: false,
				lora: [],
				top_n_sigma: 0.0,
				dry_sequence_breakers: []
			} as ApiLlamaCppServerProps['default_generation_settings']['params']);

			// Check that the problematic floating-point values are rounded correctly
			expect(result.top_p).toBe(0.95);
			expect(result.min_p).toBe(0.01);
			expect(result.temperature).toBe(0.8);
			expect(result.top_k).toBe(40); // Integer should remain unchanged
			expect(result.samplers).toBe('top_k;typ_p;top_p;min_p;temperature');
		});

		it('should preserve non-numeric values', () => {
			const mockServerParams = {
				samplers: ['top_k', 'temperature'],
				max_tokens: -1,
				temperature: 0.7
			};

			const result = ParameterSyncService.extractServerDefaults({
				...mockServerParams,
				// Minimal required fields
				n_predict: 512,
				seed: -1,
				dynatemp_range: 0.0,
				dynatemp_exponent: 1.0,
				top_k: 40,
				top_p: 0.95,
				min_p: 0.05,
				xtc_probability: 0.0,
				xtc_threshold: 0.1,
				typ_p: 1.0,
				repeat_last_n: 64,
				repeat_penalty: 1.0,
				presence_penalty: 0.0,
				frequency_penalty: 0.0,
				dry_multiplier: 0.0,
				dry_base: 1.75,
				dry_allowed_length: 2,
				dry_penalty_last_n: -1,
				mirostat: 0,
				mirostat_tau: 5.0,
				mirostat_eta: 0.1,
				stop: [],
				n_keep: 0,
				n_discard: 0,
				ignore_eos: false,
				stream: true,
				logit_bias: [],
				n_probs: 0,
				min_keep: 0,
				grammar: '',
				grammar_lazy: false,
				grammar_triggers: [],
				preserved_tokens: [],
				chat_format: '',
				reasoning_format: '',
				reasoning_in_content: false,
				thinking_forced_open: false,
				'speculative.n_max': 0,
				'speculative.n_min': 0,
				'speculative.p_min': 0.0,
				timings_per_token: false,
				post_sampling_probs: false,
				lora: [],
				top_n_sigma: 0.0,
				dry_sequence_breakers: []
			} as ApiLlamaCppServerProps['default_generation_settings']['params']);

			expect(result.samplers).toBe('top_k;temperature');
			expect(result.max_tokens).toBe(-1);
			expect(result.temperature).toBe(0.7);
		});

		it('should merge webui settings from props when provided', () => {
			const result = ParameterSyncService.extractServerDefaults(null, {
				pasteLongTextToFileLen: 0,
				pdfAsImage: true,
				renderUserContentAsMarkdown: false,
				theme: 'dark'
			});

			expect(result.pasteLongTextToFileLen).toBe(0);
			expect(result.pdfAsImage).toBe(true);
			expect(result.renderUserContentAsMarkdown).toBe(false);
			expect(result.theme).toBeUndefined();
		});
	});
});
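The values pinned in the deleted test above are the classic float32 artifact this code path exists for: the server stores `top_p = 0.95` in 32-bit floats, and the nearest float32 serializes to JSON as `0.949999988079071`. A hypothetical normalizer with the behavior the test expects follows; the real `normalizeFloatingPoint` in `$lib/utils` may be implemented differently.

// Hypothetical re-implementation for illustration only; the webui's actual
// normalizeFloatingPoint lives in $lib/utils and may use another strategy.
function normalizeFloatingPoint(value: unknown): unknown {
	if (typeof value !== 'number' || !Number.isFinite(value)) return value;
	if (Number.isInteger(value)) return value; // e.g. top_k: 40 stays 40
	// Rounding to 6 decimal places collapses the float32 noise:
	// 0.949999988079071 -> 0.95, 0.009999999776482582 -> 0.01
	return Number(value.toFixed(6));
}

console.log(normalizeFloatingPoint(0.949999988079071)); // 0.95
console.log(normalizeFloatingPoint(0.800000011920929)); // 0.8
console.log(normalizeFloatingPoint(40)); // 40

Six decimal places suffice here because float32 carries roughly seven significant digits, so for sampling parameters of this magnitude the conversion noise only appears past the sixth place.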
@@ -1,273 +0,0 @@
/**
 * ParameterSyncService - Handles synchronization between server defaults and user settings
 *
 * This service manages the complex logic of merging server-provided default parameters
 * with user-configured overrides, ensuring the UI reflects the actual server state
 * while preserving user customizations.
 *
 * **Key Responsibilities:**
 * - Extract syncable parameters from server props
 * - Merge server defaults with user overrides
 * - Track parameter sources (server, user, default)
 * - Provide sync utilities for settings store integration
 */

import { normalizeFloatingPoint } from '$lib/utils';

export type ParameterSource = 'default' | 'custom';
export type ParameterValue = string | number | boolean;
export type ParameterRecord = Record<string, ParameterValue>;

export interface ParameterInfo {
	value: string | number | boolean;
	source: ParameterSource;
	serverDefault?: string | number | boolean;
	userOverride?: string | number | boolean;
}

export interface SyncableParameter {
	key: string;
	serverKey: string;
	type: 'number' | 'string' | 'boolean';
	canSync: boolean;
}

/**
 * Mapping of webui setting keys to server parameter keys
 * Only parameters that should be synced from server are included
 */
export const SYNCABLE_PARAMETERS: SyncableParameter[] = [
	{ key: 'temperature', serverKey: 'temperature', type: 'number', canSync: true },
	{ key: 'top_k', serverKey: 'top_k', type: 'number', canSync: true },
	{ key: 'top_p', serverKey: 'top_p', type: 'number', canSync: true },
	{ key: 'min_p', serverKey: 'min_p', type: 'number', canSync: true },
	{ key: 'dynatemp_range', serverKey: 'dynatemp_range', type: 'number', canSync: true },
	{ key: 'dynatemp_exponent', serverKey: 'dynatemp_exponent', type: 'number', canSync: true },
	{ key: 'xtc_probability', serverKey: 'xtc_probability', type: 'number', canSync: true },
	{ key: 'xtc_threshold', serverKey: 'xtc_threshold', type: 'number', canSync: true },
	{ key: 'typ_p', serverKey: 'typ_p', type: 'number', canSync: true },
	{ key: 'repeat_last_n', serverKey: 'repeat_last_n', type: 'number', canSync: true },
	{ key: 'repeat_penalty', serverKey: 'repeat_penalty', type: 'number', canSync: true },
	{ key: 'presence_penalty', serverKey: 'presence_penalty', type: 'number', canSync: true },
	{ key: 'frequency_penalty', serverKey: 'frequency_penalty', type: 'number', canSync: true },
	{ key: 'dry_multiplier', serverKey: 'dry_multiplier', type: 'number', canSync: true },
	{ key: 'dry_base', serverKey: 'dry_base', type: 'number', canSync: true },
	{ key: 'dry_allowed_length', serverKey: 'dry_allowed_length', type: 'number', canSync: true },
	{ key: 'dry_penalty_last_n', serverKey: 'dry_penalty_last_n', type: 'number', canSync: true },
	{ key: 'max_tokens', serverKey: 'max_tokens', type: 'number', canSync: true },
	{ key: 'samplers', serverKey: 'samplers', type: 'string', canSync: true },
	{
		key: 'pasteLongTextToFileLen',
		serverKey: 'pasteLongTextToFileLen',
		type: 'number',
		canSync: true
	},
	{ key: 'pdfAsImage', serverKey: 'pdfAsImage', type: 'boolean', canSync: true },
	{
		key: 'showThoughtInProgress',
		serverKey: 'showThoughtInProgress',
		type: 'boolean',
		canSync: true
	},
	{ key: 'showToolCalls', serverKey: 'showToolCalls', type: 'boolean', canSync: true },
	{ key: 'keepStatsVisible', serverKey: 'keepStatsVisible', type: 'boolean', canSync: true },
	{ key: 'showMessageStats', serverKey: 'showMessageStats', type: 'boolean', canSync: true },
	{
		key: 'askForTitleConfirmation',
		serverKey: 'askForTitleConfirmation',
		type: 'boolean',
		canSync: true
	},
	{ key: 'disableAutoScroll', serverKey: 'disableAutoScroll', type: 'boolean', canSync: true },
	{
		key: 'renderUserContentAsMarkdown',
		serverKey: 'renderUserContentAsMarkdown',
		type: 'boolean',
		canSync: true
	},
	{ key: 'autoMicOnEmpty', serverKey: 'autoMicOnEmpty', type: 'boolean', canSync: true },
	{
		key: 'pyInterpreterEnabled',
		serverKey: 'pyInterpreterEnabled',
		type: 'boolean',
		canSync: true
	},
	{
		key: 'enableContinueGeneration',
		serverKey: 'enableContinueGeneration',
		type: 'boolean',
		canSync: true
	}
];

export class ParameterSyncService {
	// ─────────────────────────────────────────────────────────────────────────────
	// Extraction
	// ─────────────────────────────────────────────────────────────────────────────

	/**
	 * Round floating-point numbers to avoid JavaScript precision issues
	 */
	private static roundFloatingPoint(value: ParameterValue): ParameterValue {
		return normalizeFloatingPoint(value) as ParameterValue;
	}

	/**
	 * Extract server default parameters that can be synced
	 */
	static extractServerDefaults(
		serverParams: ApiLlamaCppServerProps['default_generation_settings']['params'] | null,
		webuiSettings?: Record<string, string | number | boolean>
	): ParameterRecord {
		const extracted: ParameterRecord = {};

		if (serverParams) {
			for (const param of SYNCABLE_PARAMETERS) {
				if (param.canSync && param.serverKey in serverParams) {
					const value = (serverParams as unknown as Record<string, ParameterValue>)[
						param.serverKey
					];
					if (value !== undefined) {
						// Apply precision rounding to avoid JavaScript floating-point issues
						extracted[param.key] = this.roundFloatingPoint(value);
					}
				}
			}

			// Handle samplers array conversion to string
			if (serverParams.samplers && Array.isArray(serverParams.samplers)) {
				extracted.samplers = serverParams.samplers.join(';');
			}
		}

		if (webuiSettings) {
			for (const param of SYNCABLE_PARAMETERS) {
				if (param.canSync && param.serverKey in webuiSettings) {
					const value = webuiSettings[param.serverKey];
					if (value !== undefined) {
						extracted[param.key] = this.roundFloatingPoint(value);
					}
				}
			}
		}

		return extracted;
	}

	// ─────────────────────────────────────────────────────────────────────────────
	// Merging
	// ─────────────────────────────────────────────────────────────────────────────

	/**
	 * Merge server defaults with current user settings.
	 * Returns updated settings that respect user overrides while using server defaults.
	 */
	static mergeWithServerDefaults(
		currentSettings: ParameterRecord,
		serverDefaults: ParameterRecord,
		userOverrides: Set<string> = new Set()
	): ParameterRecord {
		const merged = { ...currentSettings };

		for (const [key, serverValue] of Object.entries(serverDefaults)) {
			// Only update if user hasn't explicitly overridden this parameter
			if (!userOverrides.has(key)) {
				merged[key] = this.roundFloatingPoint(serverValue);
			}
		}

		return merged;
	}

	// ─────────────────────────────────────────────────────────────────────────────
	// Info
	// ─────────────────────────────────────────────────────────────────────────────

	/**
	 * Get parameter information including source and values
	 */
	static getParameterInfo(
		key: string,
		currentValue: ParameterValue,
		propsDefaults: ParameterRecord,
		userOverrides: Set<string>
	): ParameterInfo {
		const hasPropsDefault = propsDefaults[key] !== undefined;
		const isUserOverride = userOverrides.has(key);

		// Simple logic: either using default (from props) or custom (user override)
		const source: ParameterSource = isUserOverride ? 'custom' : 'default';

		return {
			value: currentValue,
			source,
			serverDefault: hasPropsDefault ? propsDefaults[key] : undefined, // Keep same field name for compatibility
			userOverride: isUserOverride ? currentValue : undefined
		};
	}

	/**
	 * Check if a parameter can be synced from server
	 */
	static canSyncParameter(key: string): boolean {
		return SYNCABLE_PARAMETERS.some((param) => param.key === key && param.canSync);
	}

	/**
	 * Get all syncable parameter keys
	 */
	static getSyncableParameterKeys(): string[] {
		return SYNCABLE_PARAMETERS.filter((param) => param.canSync).map((param) => param.key);
	}

	/**
	 * Validate server parameter value
	 */
	static validateServerParameter(key: string, value: ParameterValue): boolean {
		const param = SYNCABLE_PARAMETERS.find((p) => p.key === key);
		if (!param) return false;

		switch (param.type) {
			case 'number':
				return typeof value === 'number' && !isNaN(value);
			case 'string':
				return typeof value === 'string';
			case 'boolean':
				return typeof value === 'boolean';
			default:
				return false;
		}
	}

	// ─────────────────────────────────────────────────────────────────────────────
	// Diff
	// ─────────────────────────────────────────────────────────────────────────────

	/**
	 * Create a diff between current settings and server defaults
	 */
	static createParameterDiff(
		currentSettings: ParameterRecord,
		serverDefaults: ParameterRecord
	): Record<string, { current: ParameterValue; server: ParameterValue; differs: boolean }> {
		const diff: Record<
			string,
			{ current: ParameterValue; server: ParameterValue; differs: boolean }
		> = {};

		for (const key of this.getSyncableParameterKeys()) {
			const currentValue = currentSettings[key];
			const serverValue = serverDefaults[key];

			if (serverValue !== undefined) {
				diff[key] = {
					current: currentValue,
					server: serverValue,
					differs: currentValue !== serverValue
				};
			}
		}

		return diff;
	}
}
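Although this file is deleted in the commit (its types moved to `$lib/types`, and the logic presumably lives elsewhere now), the intended flow is easiest to see end to end. A minimal sketch, assuming a `serverProps` object already fetched from `/props` and a settings store that tracks which keys the user explicitly changed:

// Illustrative only: serverProps and the override set are assumed inputs.
const serverDefaults = ParameterSyncService.extractServerDefaults(
	serverProps.default_generation_settings.params
);

const userOverrides = new Set(['temperature']); // user pinned temperature

const settings = ParameterSyncService.mergeWithServerDefaults(
	{ temperature: 1.2, top_p: 1.0 }, // current store state
	serverDefaults,
	userOverrides
);
// settings.temperature stays 1.2 (user override);
// settings.top_p follows the server default (e.g. 0.95).

const info = ParameterSyncService.getParameterInfo(
	'temperature',
	settings.temperature,
	serverDefaults,
	userOverrides
);
// info.source === 'custom'; info.serverDefault holds the /props value.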
@@ -1,4 +1,4 @@
-import { apiFetchWithParams } from '$lib/utils/api-fetch';
+import { apiFetchWithParams } from '$lib/utils';
 
 export class PropsService {
 	/**
@@ -1,77 +0,0 @@
import { getAuthHeaders } from '$lib/utils';

/**
 * PropsService - Server properties management
 *
 * This service handles communication with the /props endpoint to retrieve
 * server configuration, model information, and capabilities.
 *
 * **Responsibilities:**
 * - Fetch server properties from /props endpoint
 * - Handle API authentication
 * - Parse and validate server response
 *
 * **Used by:**
 * - serverStore: Primary consumer for server state management
 */
export class PropsService {
	// ─────────────────────────────────────────────────────────────────────────────
	// Fetching
	// ─────────────────────────────────────────────────────────────────────────────

	/**
	 * Fetches server properties from the /props endpoint
	 *
	 * @param autoload - If false, prevents automatic model loading (default: false)
	 * @returns {Promise<ApiLlamaCppServerProps>} Server properties
	 * @throws {Error} If the request fails or returns invalid data
	 */
	static async fetch(autoload = false): Promise<ApiLlamaCppServerProps> {
		const url = new URL('./props', window.location.href);
		if (!autoload) {
			url.searchParams.set('autoload', 'false');
		}

		const response = await fetch(url.toString(), {
			headers: getAuthHeaders()
		});

		if (!response.ok) {
			throw new Error(
				`Failed to fetch server properties: ${response.status} ${response.statusText}`
			);
		}

		const data = await response.json();
		return data as ApiLlamaCppServerProps;
	}

	/**
	 * Fetches server properties for a specific model (ROUTER mode)
	 *
	 * @param modelId - The model ID to fetch properties for
	 * @param autoload - If false, prevents automatic model loading (default: false)
	 * @returns {Promise<ApiLlamaCppServerProps>} Server properties for the model
	 * @throws {Error} If the request fails or returns invalid data
	 */
	static async fetchForModel(modelId: string, autoload = false): Promise<ApiLlamaCppServerProps> {
		const url = new URL('./props', window.location.href);
		url.searchParams.set('model', modelId);
		if (!autoload) {
			url.searchParams.set('autoload', 'false');
		}

		const response = await fetch(url.toString(), {
			headers: getAuthHeaders()
		});

		if (!response.ok) {
			throw new Error(
				`Failed to fetch model properties: ${response.status} ${response.statusText}`
			);
		}

		const data = await response.json();
		return data as ApiLlamaCppServerProps;
	}
}
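A minimal consumer sketch for the two fetchers above; the surrounding store wiring is hypothetical, and only `PropsService.fetch` / `fetchForModel` come from the deleted file itself:

// Illustrative only: refresh server state without triggering a model load.
async function refreshServerProps(modelId?: string): Promise<ApiLlamaCppServerProps> {
	// autoload defaults to false, so the request carries ?autoload=false
	// and probing /props will not cause the router to load a model.
	return modelId
		? PropsService.fetchForModel(modelId) // adds ?model=<id>
		: PropsService.fetch();
}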
Some files were not shown because too many files have changed in this diff.