Compare commits

...

16 Commits
b8563 ... b8579

Author SHA1 Message Date
Gaurav Garg
ec16a072f0 Optimize MOE GEMV kernel for BS > 1. (#20905)
* Optimize MOE GEMV kernel for BS > 1.

The previous MoE kernel for BS > 1 launched too many thread blocks (nrows_x, nchannels_dst, ncols_dst) with very little work per block: each block of (32, 4) threads computed the inner dot product for a single row.

The new mul_mat_vec_q_moe kernel is dedicated to the MoE multi-token path, launched with grid (ceil(nrows_x/rpb), nchannels_dst) and block (warp_size, ncols_dst). Each warp handles two rows independently, using warp-level reduction only (no shared-memory sync). A worked example of the resulting grid sizes is sketched after this commit entry.

This change doesn't increase compilation time, as only a single template instance is needed per type. It also simplifies the original GEMV kernel and gets rid of the `is_multi_token_id` specialization.

* Remove em-dashes

* Cherry-pick changes from @am17an's PR https://github.com/ggml-org/llama.cpp/pull/20885 to enable the small_k optimization only for cases where it benefits

Increase max batch size for MMVQ kernels for MUL_MAT_ID to 8

* Make the max batch size for the MoE GEMV kernel configurable based on GPU arch and data type

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
2026-03-29 18:35:18 +02:00
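
For illustration, here is a minimal sketch (plain C++, with hypothetical shapes) of the launch-geometry change described in the commit above; the values are made up and only show how the block count shrinks while each block gets more work:

// Hypothetical MoE GEMV shapes, chosen only to illustrate the grid-size change.
#include <cstdio>

int main() {
    const int nrows_x        = 4096; // rows of one expert's quantized weight matrix
    const int nchannels_dst  = 8;    // destination channels (selected experts)
    const int ncols_dst      = 4;    // tokens per MUL_MAT_ID call (BS > 1)
    const int warp_size      = 32;
    const int rows_per_block = 2;    // "rpb" from the commit message

    // Old kernel: one (32, 4)-thread block per (row, channel, token).
    const long long old_blocks = 1LL * nrows_x * nchannels_dst * ncols_dst;          // 131072

    // New mul_mat_vec_q_moe: grid (ceil(nrows_x / rpb), nchannels_dst),
    // block (warp_size, ncols_dst); each warp reduces its own rows, no shared memory.
    const long long new_blocks = 1LL * ((nrows_x + rows_per_block - 1) / rows_per_block) * nchannels_dst; // 16384
    const int threads_per_block = warp_size * ncols_dst;                              // 128

    printf("old grid: %lld blocks\n", old_blocks);
    printf("new grid: %lld blocks x %d threads\n", new_blocks, threads_per_block);
    return 0;
}
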
Max Krasnyansky
f5d1c4179f hexagon: dma optimizations (mostly fixing regressions) (#21137)
* hex-fa: add simple dma cache for Mask

I noticed that we were refetching the mask rows over and over.
This simple cache avoids that; the idea is sketched after this commit entry.

* hex-dma: unset in-order desc bit which caused significant perf regression

We don't rely on true in-order processing of the DMA descriptors anywhere.
It turns out this mode caused a significant regression of around 3-4 TPS during token generation.

* hex-rope: update comment to clarify that we don't need in-order DMA completions
2026-03-29 06:40:13 -07:00
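
A minimal, standalone sketch of the find-or-evict-oldest idea behind that mask cache (the real implementation is dma_cache / dma_cache_push in the Hexagon DMA header further below; the capacity and names here are illustrative):

#include <cstdint>

// Illustrative only: a tiny cache keyed by source pointer with oldest-entry
// eviction, mirroring the idea of reusing already-fetched mask rows.
struct line_cache {
    static const unsigned capacity = 8;   // the real cache allows up to 64 lines
    const void * src[capacity] = {};
    uint16_t     age[capacity] = {};

    // Returns the slot holding 'p'; on a miss, the oldest slot is reused and
    // 'hit' is false so the caller issues the real DMA into that slot.
    unsigned lookup(const void * p, bool & hit) {
        unsigned found = 0, oldest = 0;
        hit = false;
        for (unsigned i = 0; i < capacity; ++i) {
            if (src[i] == p) {
                hit = true; found = i; age[i] = 0;   // reuse resident data ("dummy" DMA)
            } else {
                ++age[i];
                if (age[i] > age[oldest]) oldest = i;
            }
        }
        if (hit) return found;
        src[oldest] = p;   // evict the least recently used line
        age[oldest] = 0;
        return oldest;
    }
};
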
Davi Henrique Linhares
2405d59cb6 devops: including compute-runtime for intel.Dockerfile (#21076) 2026-03-29 13:34:03 +08:00
Neo Zhang
afe65aa282 [SYCL] Enhance build script to use half cores to build, avoid OS hang (#21093)
* use half of the cores to build, to avoid OS hang

* reduce the amount of output text to shorten test time

* avoid returning 0
2026-03-29 09:02:45 +08:00
Sigbjørn Skjæret
65097181e4 fix **/x glob matching (#21129) 2026-03-28 22:27:38 +01:00
Piotr Wilkin (ilintar)
98ae0a0d36 common/parser: fix handling of tool definition with missing properties key (#21128) 2026-03-28 20:41:32 +01:00
Sigbjørn Skjæret
3a14a542f5 common : add character class support to glob_match (#21111)
* add character class support to glob_match

* remove pointless reference
2026-03-28 19:57:37 +01:00
BlueMöhre
968189729f WebUI: Replace illegal nested button elements (#21026)
* remove/replace nested button elements

* map rest props to outer element

* solve TODO

* chore: update webui build output
2026-03-28 17:57:59 +01:00
Adrien
e397d3885c common/json-schema: fix: handle non-capturing groups (?:...) in JSON schema pattern converter (#21124)
The regex-to-grammar converter in _visit_pattern() crashes with SIGSEGV
when a JSON schema "pattern" field contains a non-capturing group (?:...).

Root cause: when the parser sees '(' followed by '?', it pushes a warning
but does not advance past '?:'. The recursive transform() call then
interprets '?' as a quantifier and calls seq.back() on an empty vector,
causing undefined behavior.

This commonly occurs when serving OpenAI-compatible tool calls from
clients that include complex regex patterns in their JSON schemas (e.g.,
date validation patterns like ^(?:(?:\d\d[2468][048]|...)-02-29|...)$).

The fix:
- Skip '?:' after '(' to treat non-capturing groups as regular groups
- For unsupported syntax (?=, ?!, etc.), skip to matching ')' safely,
  handling escaped characters to avoid miscounting parenthesis depth
  (see the sketch after this commit entry)
- Adjust the ')' unbalanced-parentheses check using direct char
  comparisons instead of substr
- Add test cases for non-capturing groups (C++ only, as the JS/Python
  implementations do not yet support this syntax)
2026-03-28 17:55:38 +01:00
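
As an illustration of the escape-aware skip described in that commit, here is a standalone sketch; the actual change is in the _visit_pattern() transform shown in the diff below, and this helper is hypothetical:

#include <cstddef>
#include <string>

// Hypothetical helper: 'i' points just past the "(?" of an unsupported group
// (lookahead/lookbehind). Advance past the matching ')' while honoring
// escaped characters so parenthesis depth is not miscounted.
static size_t skip_unsupported_group(const std::string & pat, size_t i) {
    int depth = 1; // the '(' that opened the unsupported group
    while (i < pat.size() && depth > 0) {
        if (pat[i] == '\\' && i + 1 < pat.size()) {
            i += 2;                         // skip an escaped char such as "\)"
        } else {
            if (pat[i] == '(')      depth++;
            else if (pat[i] == ')') depth--;
            i++;
        }
    }
    return i; // index just past the group's ')' (or end of pattern)
}

// e.g. skip_unsupported_group("(?=ab\\)c)x", 2) returns 9, the index of 'x'.
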
Aldehir Rojas
e6f2ec01ff common : add reasoning_format = none support to gpt-oss (#21094) 2026-03-28 09:33:39 -05:00
Georgi Gerganov
edfb440a2f server : fix processing of multiple back-to-back mtmd chunks (#21107) 2026-03-28 16:27:36 +02:00
Adrien Gallouët
3d66da1809 ci : gracefully shut down the server (#21110)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-03-28 14:49:57 +01:00
Woof Dog
82b703f8bc Document custom default webui preferences in server README (#19771) 2026-03-28 14:19:16 +01:00
Aleksander Grygier
51a84efc53 webui: Conversation forking + branching improvements (#21021)
* refactor: Make `DialogConfirmation` extensible with children slot

* feat: Add conversation forking logic

* feat: Conversation forking UI

* feat: Update delete/edit dialogs and logic for forks

* refactor: Improve Chat Sidebar UX and add MCP Servers entry

* refactor: Cleanup

* feat: Update message in place when editing leaf nodes

* chore: Cleanup

* chore: Cleanup

* chore: Cleanup

* chore: Cleanup

* chore: Cleanup

* chore: Cleanup

* refactor: Post-review improvements

* chore: update webui build output

* test: Update Storybook test

* chore: update webui build output

* chore: update webui build output
2026-03-28 13:38:15 +01:00
Adrien Gallouët
b0f0dd3e51 vendor : update cpp-httplib to 0.40.0 (#21100)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-03-28 08:59:44 +01:00
Ruben Ortlam
0eb4764182 vulkan: add noncontiguous GLU support (#21081)
* vulkan: add noncontiguous GLU support

* fix compile issue
2026-03-28 08:44:56 +01:00
53 changed files with 1784 additions and 332 deletions

View File

@@ -33,6 +33,23 @@ RUN mkdir -p /app/full \
FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS base
ARG IGC_VERSION=v2.30.1
ARG IGC_VERSION_FULL=2_2.30.1+20950
ARG COMPUTE_RUNTIME_VERSION=26.09.37435.1
ARG COMPUTE_RUNTIME_VERSION_FULL=26.09.37435.1-0
ARG IGDGMM_VERSION=22.9.0
RUN mkdir /tmp/neo/ && cd /tmp/neo/ \
&& wget https://github.com/intel/intel-graphics-compiler/releases/download/$IGC_VERSION/intel-igc-core-${IGC_VERSION_FULL}_amd64.deb \
&& wget https://github.com/intel/intel-graphics-compiler/releases/download/$IGC_VERSION/intel-igc-opencl-${IGC_VERSION_FULL}_amd64.deb \
&& wget https://github.com/intel/compute-runtime/releases/download/$COMPUTE_RUNTIME_VERSION/intel-ocloc-dbgsym_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.ddeb \
&& wget https://github.com/intel/compute-runtime/releases/download/$COMPUTE_RUNTIME_VERSION/intel-ocloc_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.deb \
&& wget https://github.com/intel/compute-runtime/releases/download/$COMPUTE_RUNTIME_VERSION/intel-opencl-icd-dbgsym_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.ddeb \
&& wget https://github.com/intel/compute-runtime/releases/download/$COMPUTE_RUNTIME_VERSION/intel-opencl-icd_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.deb \
&& wget https://github.com/intel/compute-runtime/releases/download/$COMPUTE_RUNTIME_VERSION/libigdgmm12_${IGDGMM_VERSION}_amd64.deb \
&& wget https://github.com/intel/compute-runtime/releases/download/$COMPUTE_RUNTIME_VERSION/libze-intel-gpu1-dbgsym_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.ddeb \
&& wget https://github.com/intel/compute-runtime/releases/download/$COMPUTE_RUNTIME_VERSION/libze-intel-gpu1_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.deb \
&& dpkg --install *.deb
RUN apt-get update \
&& apt-get install -y libgomp1 curl\
&& apt autoremove -y \

View File

@@ -65,7 +65,7 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
auto schema = function.at("parameters");
auto schema = function.contains("parameters") ? function.at("parameters") : json::object();
builder.resolve_refs(schema);
});
parser.build_grammar(builder, data.grammar_lazy);
@@ -221,7 +221,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context
foreach_function(inputs.tools, [&](const json & tool) {
const auto & func = tool.at("function");
std::string name = func.at("name");
const auto & schema = func.at("parameters");
const auto & schema = func.contains("parameters") ? func.at("parameters") : json::object();
// Build call_id parser based on position (if supported)
common_peg_parser call_id_section = p.eps();
@@ -282,19 +282,11 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
common_peg_parser tool_choice = p.choice();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & func = tool.at("function");
std::string name = func.at("name");
const auto & params = func.at("parameters");
if (!params.contains("properties") || !params.at("properties").is_object()) {
return;
}
const auto & properties = params.at("properties");
const auto & func = tool.at("function");
std::string name = func.at("name");
const auto & params = func.contains("parameters") ? func.at("parameters") : json::object();
const auto & properties = params.contains("properties") ? params.at("properties") : json::object();
std::set<std::string> required;
if (params.contains("required") && params.at("required").is_array()) {
params.at("required").get_to(required);
}
// Build parser for each argument, separating required and optional
std::vector<common_peg_parser> required_parsers;
@@ -311,17 +303,18 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
}
}
auto arg = p.tool_arg(
p.tool_arg_open(arguments.name_prefix + p.tool_arg_name(p.literal(param_name)) +
arguments.name_suffix) +
arguments.value_prefix +
(type == "string" ? p.tool_arg_string_value(p.schema(p.until(arguments.value_suffix),
"tool-" + name + "-arg-" + param_name + "-schema",
param_schema, true)) :
p.tool_arg_json_value(p.schema(
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
p.space()) +
p.tool_arg_close(p.literal(arguments.value_suffix)));
auto arg =
p.tool_arg(p.tool_arg_open(arguments.name_prefix + p.tool_arg_name(p.literal(param_name)) +
arguments.name_suffix) +
arguments.value_prefix +
(type == "string" ?
p.tool_arg_string_value(p.schema(p.until(arguments.value_suffix),
"tool-" + name + "-arg-" + param_name + "-schema",
param_schema, true)) :
p.tool_arg_json_value(p.schema(
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
p.space()) +
p.tool_arg_close(p.literal(arguments.value_suffix)));
auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg);
if (is_required) {

View File

@@ -971,6 +971,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto start = p.rule("start", p.literal("<|start|>assistant"));
@@ -979,7 +980,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
auto channel = p.literal("<|channel|>") + (p.literal("commentary") | p.literal("analysis"));
auto constrain_type = p.chars("[A-Za-z0-9_-]", 1, -1);
auto analysis = p.rule("analysis", p.literal("<|channel|>analysis<|message|>") + p.reasoning(content) + end);
if (extract_reasoning) {
p.rule("analysis", p.literal("<|channel|>analysis<|message|>") + p.reasoning(content) + end);
} else {
p.rule("analysis", p.content(p.literal("<|channel|>analysis<|message|>") + content + end));
}
auto analysis = p.ref("analysis");
auto preamble = p.rule("preamble", p.literal("<|channel|>commentary<|message|>") + p.content(content) + end);
auto final_msg = p.rule("final", p.literal("<|channel|>final<|message|>") + p.content(content));
auto any = p.rule("any", preamble | analysis);

View File

@@ -656,14 +656,53 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
return true;
}
// simple glob: * matches non-/ chars, ** matches anything including /
static inline bool glob_class_match(const char c, const char * pattern, const char * class_end) {
const char * class_start = pattern;
bool negated = false;
if (*class_start == '!') {
negated = true;
class_start++;
}
// If first character after negation is ']' or '-', treat it as literal
if (*class_start == ']' || *class_start == '-') {
if (class_start < class_end && *class_start == c) {
return !negated;
}
class_start++;
}
bool matched = false;
while (class_start < class_end) {
if (class_start + 2 < class_end && class_start[1] == '-' && class_start[2] != ']') {
char start_char = *class_start;
char end_char = class_start[2];
if (c >= start_char && c <= end_char) {
matched = true;
break;
}
class_start += 3;
} else {
if (*class_start == c) {
matched = true;
break;
}
class_start++;
}
}
return negated ? !matched : matched;
}
// simple glob: * matches non-/ chars, ** matches anything including /, [] matches character class
static inline bool glob_match(const char * pattern, const char * str) {
if (*pattern == '\0') {
return *str == '\0';
}
if (pattern[0] == '*' && pattern[1] == '*') {
const char * p = pattern + 2;
if (*p == '/') p++;
if (glob_match(p, str)) return true;
if (*str != '\0') return glob_match(pattern, str + 1);
return false;
@@ -678,6 +717,26 @@ static inline bool glob_match(const char * pattern, const char * str) {
if (*pattern == '?' && *str != '\0' && *str != '/') {
return glob_match(pattern + 1, str + 1);
}
if (*pattern == '[') {
const char * class_end = pattern + 1;
// If first character after '[' is ']' or '-', treat it as literal
if (*class_end == ']' || *class_end == '-') {
class_end++;
}
while (*class_end != '\0' && *class_end != ']') {
class_end++;
}
if (*class_end == ']') {
if (*str == '\0') return false;
bool matched = glob_class_match(*str, pattern + 1, class_end);
return matched && glob_match(class_end + 1, str + 1);
} else {
if (*str == '[') {
return glob_match(pattern + 1, str + 1);
}
return false;
}
}
if (*pattern == *str) {
return glob_match(pattern + 1, str + 1);
}
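
To make the intended semantics concrete, a few hypothetical calls against the glob_match above ('*' stays within a path component, '**' may cross '/', '[...]' and '[!...]' are character classes; the pattern and path strings are made up):

#include <cassert>

// Hypothetical usage; assumes the glob_match shown above is in scope.
int main() {
    assert( glob_match("*.gguf",           "model.gguf"));
    assert(!glob_match("*.gguf",           "models/model.gguf"));   // '*' does not cross '/'
    assert( glob_match("**/*.gguf",        "models/7b/model.gguf"));
    assert( glob_match("blk.[0-9].attn_*", "blk.3.attn_q.weight")); // character class
    assert(!glob_match("blk.[!0-9].attn_*","blk.3.attn_q.weight")); // negated class
    return 0;
}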

View File

@@ -416,15 +416,30 @@ private:
i++;
} else if (c == '(') {
i++;
if (i < length) {
if (sub_pattern[i] == '?') {
if (i < length && sub_pattern[i] == '?') {
if (i + 1 < length && sub_pattern[i + 1] == ':') {
i += 2; // skip "?:" for non-capturing group, treat as regular group
} else {
// lookahead/lookbehind (?=, ?!, ?<=, ?<!) - not supported
_warnings.push_back("Unsupported pattern syntax");
// skip to matching ')' to avoid UB on empty seq
int depth = 1;
while (i < length && depth > 0) {
if (sub_pattern[i] == '\\' && i + 1 < length) {
i += 2; // skip escaped character
} else {
if (sub_pattern[i] == '(') depth++;
else if (sub_pattern[i] == ')') depth--;
i++;
}
}
continue;
}
}
seq.emplace_back("(" + to_rule(transform()) + ")", false);
} else if (c == ')') {
i++;
if (start > 0 && sub_pattern[start - 1] != '(') {
if (start > 0 && sub_pattern[start - 1] != '(' && (start < 2 || sub_pattern[start - 2] != '?' || sub_pattern[start - 1] != ':')) {
_errors.push_back("Unbalanced parentheses");
}
return join_seq();

View File

@@ -20,4 +20,4 @@ cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA
#cmake --build . --config Release --target llama-bench
#build all binary
cmake --build . --config Release -j -v
cmake --build . --config Release -j$((($(nproc)+1)/2)) -v

View File

@@ -23,9 +23,9 @@ if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1
echo "use $GGML_SYCL_DEVICE as main GPU"
#use signle GPU only
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none ${LOAD_MODE}
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 200 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none ${LOAD_MODE}
else
#use multiple GPUs with same max compute units
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} ${LOAD_MODE}
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 200 -e -ngl ${NGL} -s 0 -c ${CONTEXT} ${LOAD_MODE}
fi

View File

@@ -2343,7 +2343,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
if (ggml_is_quantized(src0->type)) {
if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
const int mmvq_mmid_max = get_mmvq_mmid_max_batch(src0->type, cc);
if (ne2 <= mmvq_mmid_max) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
return;
}
@@ -2946,14 +2947,18 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
}
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {
// under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
// TODO: figure out a way to enable for larger batch sizes, without hurting performance
// ref: https://github.com/ggml-org/llama.cpp/pull/18958
use_cuda_graph = false;
if (node->op == GGML_OP_MUL_MAT_ID) {
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int mmvq_mmid_max = get_mmvq_mmid_max_batch(node->src[0]->type, cc);
if (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > mmvq_mmid_max) {
// under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
// TODO: figure out a way to enable for larger batch sizes, without hurting performance
// ref: https://github.com/ggml-org/llama.cpp/pull/18958
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
#endif
}
}
if (!use_cuda_graph) {

View File

@@ -97,6 +97,194 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
return MMVQ_PARAMETERS_GENERIC;
}
// Per-architecture maximum batch size for which MMVQ should be used for MUL_MAT_ID.
// Returns a value <= MMVQ_MAX_BATCH_SIZE. Default is MMVQ_MAX_BATCH_SIZE.
// Check https://github.com/ggml-org/llama.cpp/pull/20905#issuecomment-4145835627 for details
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_pascal_older(ggml_type type) {
switch (type) {
case GGML_TYPE_IQ1_S: return 6;
case GGML_TYPE_IQ1_M: return 6;
case GGML_TYPE_IQ2_S: return 4;
case GGML_TYPE_IQ2_XS: return 5;
case GGML_TYPE_IQ2_XXS: return 5;
case GGML_TYPE_IQ3_S: return 4;
case GGML_TYPE_IQ3_XXS: return 4;
case GGML_TYPE_IQ4_NL: return 6;
case GGML_TYPE_IQ4_XS: return 5;
case GGML_TYPE_MXFP4: return 4;
case GGML_TYPE_Q2_K: return 4;
case GGML_TYPE_Q3_K: return 4;
case GGML_TYPE_Q4_0: return 6;
case GGML_TYPE_Q4_1: return 6;
case GGML_TYPE_Q4_K: return 5;
case GGML_TYPE_Q5_0: return 6;
case GGML_TYPE_Q5_1: return 6;
case GGML_TYPE_Q5_K: return 5;
case GGML_TYPE_Q6_K: return 4;
case GGML_TYPE_Q8_0: return 4;
default: return MMVQ_MAX_BATCH_SIZE;
}
}
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_turing_plus(ggml_type type) {
switch (type) {
case GGML_TYPE_IQ2_S: return 7;
case GGML_TYPE_IQ3_S: return 6;
case GGML_TYPE_IQ3_XXS: return 7;
case GGML_TYPE_MXFP4: return 7;
case GGML_TYPE_Q2_K: return 7;
case GGML_TYPE_Q3_K: return 5;
default: return MMVQ_MAX_BATCH_SIZE;
}
}
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_gcn(ggml_type type) {
switch (type) {
case GGML_TYPE_IQ1_S: return 5;
case GGML_TYPE_IQ1_M: return 5;
case GGML_TYPE_IQ2_S: return 4;
case GGML_TYPE_IQ2_XS: return 4;
case GGML_TYPE_IQ2_XXS: return 4;
case GGML_TYPE_IQ3_S: return 4;
case GGML_TYPE_IQ3_XXS: return 4;
case GGML_TYPE_IQ4_NL: return 6;
case GGML_TYPE_IQ4_XS: return 4;
case GGML_TYPE_Q2_K: return 4;
case GGML_TYPE_Q3_K: return 4;
case GGML_TYPE_Q4_0: return 5;
case GGML_TYPE_Q4_1: return 5;
case GGML_TYPE_Q4_K: return 4;
case GGML_TYPE_Q5_K: return 4;
case GGML_TYPE_Q6_K: return 4;
case GGML_TYPE_Q8_0: return 4;
default: return MMVQ_MAX_BATCH_SIZE;
}
}
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_cdna(ggml_type type) {
switch (type) {
case GGML_TYPE_IQ2_S: return 5;
case GGML_TYPE_IQ2_XS: return 5;
case GGML_TYPE_IQ2_XXS: return 5;
case GGML_TYPE_IQ3_S: return 4;
case GGML_TYPE_IQ3_XXS: return 5;
default: return MMVQ_MAX_BATCH_SIZE;
}
}
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna1_rdna2(ggml_type type) {
switch (type) {
case GGML_TYPE_IQ2_S: return 4;
case GGML_TYPE_IQ2_XS: return 4;
case GGML_TYPE_IQ2_XXS: return 4;
case GGML_TYPE_IQ3_S: return 4;
case GGML_TYPE_IQ3_XXS: return 4;
case GGML_TYPE_Q2_K: return 7;
case GGML_TYPE_Q3_K: return 4;
case GGML_TYPE_Q4_K: return 5;
case GGML_TYPE_Q5_K: return 6;
case GGML_TYPE_Q6_K: return 5;
default: return MMVQ_MAX_BATCH_SIZE;
}
}
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna3(ggml_type type) {
switch (type) {
case GGML_TYPE_IQ1_S: return 6;
case GGML_TYPE_IQ1_M: return 6;
case GGML_TYPE_IQ2_S: return 4;
case GGML_TYPE_IQ2_XS: return 4;
case GGML_TYPE_IQ2_XXS: return 4;
case GGML_TYPE_IQ3_S: return 4;
case GGML_TYPE_IQ3_XXS: return 4;
case GGML_TYPE_IQ4_NL: return 6;
case GGML_TYPE_IQ4_XS: return 6;
case GGML_TYPE_Q4_K: return 4;
case GGML_TYPE_Q5_K: return 4;
case GGML_TYPE_Q6_K: return 4;
default: return MMVQ_MAX_BATCH_SIZE;
}
}
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type type) {
switch (type) {
case GGML_TYPE_IQ1_S: return 7;
case GGML_TYPE_IQ1_M: return 7;
case GGML_TYPE_IQ2_S: return 4;
case GGML_TYPE_IQ2_XS: return 4;
case GGML_TYPE_IQ2_XXS: return 4;
case GGML_TYPE_IQ3_S: return 4;
case GGML_TYPE_IQ3_XXS: return 4;
case GGML_TYPE_IQ4_NL: return 7;
case GGML_TYPE_IQ4_XS: return 5;
case GGML_TYPE_MXFP4: return 5;
case GGML_TYPE_Q3_K: return 4;
case GGML_TYPE_Q4_0: return 7;
case GGML_TYPE_Q4_1: return 7;
case GGML_TYPE_Q4_K: return 4;
case GGML_TYPE_Q5_0: return 7;
case GGML_TYPE_Q5_1: return 7;
case GGML_TYPE_Q5_K: return 5;
case GGML_TYPE_Q6_K: return 5;
case GGML_TYPE_Q8_0: return 7;
default: return MMVQ_MAX_BATCH_SIZE;
}
}
// Host function: returns the max batch size for the current arch+type at runtime.
int get_mmvq_mmid_max_batch(ggml_type type, int cc) {
// NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID.
if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return MMVQ_MAX_BATCH_SIZE;
}
if (cc >= GGML_CUDA_CC_TURING) {
return get_mmvq_mmid_max_batch_turing_plus(type);
}
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
return get_mmvq_mmid_max_batch_pascal_older(type);
}
// AMD
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return get_mmvq_mmid_max_batch_rdna4(type);
}
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
return get_mmvq_mmid_max_batch_rdna3(type);
}
if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) {
return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
}
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return get_mmvq_mmid_max_batch_cdna(type);
}
if (GGML_CUDA_CC_IS_GCN(cc)) {
return get_mmvq_mmid_max_batch_gcn(type);
}
return MMVQ_MAX_BATCH_SIZE;
}
// Device constexpr: returns the max batch size for the current arch+type at compile time.
template <ggml_type type>
static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() {
#if defined(RDNA4)
return get_mmvq_mmid_max_batch_rdna4(type);
#elif defined(RDNA3)
return get_mmvq_mmid_max_batch_rdna3(type);
#elif defined(RDNA2) || defined(RDNA1)
return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
#elif defined(CDNA)
return get_mmvq_mmid_max_batch_cdna(type);
#elif defined(GCN)
return get_mmvq_mmid_max_batch_gcn(type);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || __CUDA_ARCH__ >= GGML_CUDA_CC_ADA_LOVELACE)
return MMVQ_MAX_BATCH_SIZE;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
return get_mmvq_mmid_max_batch_turing_plus(type);
#else
return get_mmvq_mmid_max_batch_pascal_older(type);
#endif
}
static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
if (table_id == MMVQ_PARAMETERS_GENERIC) {
switch (ncols_dst) {
@@ -195,7 +383,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
return 1;
}
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false, bool small_k = false>
template <ggml_type type, int ncols_dst, bool has_fusion, bool small_k = false>
__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
@@ -222,22 +410,13 @@ static __global__ void mul_mat_vec_q(
const uint32_t channel_dst = blockIdx.y;
uint32_t token_idx = 0;
uint32_t channel_x;
uint32_t channel_y;
uint32_t sample_dst;
if constexpr (is_multi_token_id) {
// Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case
token_idx = blockIdx.z;
channel_x = ids[channel_dst + token_idx * ids_stride];
channel_y = fastmodulo(channel_dst, nchannels_y);
sample_dst = 0;
} else {
channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
sample_dst = blockIdx.z;
}
channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
sample_dst = blockIdx.z;
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
const uint32_t sample_y = sample_dst;
@@ -294,9 +473,6 @@ static __global__ void mul_mat_vec_q(
float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
if constexpr (is_multi_token_id) {
y += token_idx*stride_col_y;
}
const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
@@ -350,10 +526,6 @@ static __global__ void mul_mat_vec_q(
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
if constexpr (is_multi_token_id) {
dst += token_idx*stride_col_dst;
}
// sum up partial sums and write back result
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
@@ -413,6 +585,69 @@ static __global__ void mul_mat_vec_q(
}
}
// Dedicated MoE multi-token kernel.
// Grid: (ceil(nrows_x / c_rows_per_block), nchannels_dst)
// Block: (warp_size, ncols_dst) - each warp handles one token independently.
// No shared memory reduction needed since each warp works alone.
template <ggml_type type, int c_rows_per_block>
__launch_bounds__(get_mmvq_mmid_max_batch_for_device<type>()*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q_moe(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids,
float * __restrict__ dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x,
const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst,
const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst,
const uint32_t ncols_dst, const uint32_t ids_stride) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
const uint32_t token_idx = threadIdx.y;
const int row0 = c_rows_per_block*blockIdx.x;
const int blocks_per_row_x = ncols_x / qk;
constexpr int blocks_per_iter = vdr * warp_size / qi;
const uint32_t channel_dst = blockIdx.y;
if (token_idx >= ncols_dst) {
return;
}
const uint32_t channel_x = ids[channel_dst + token_idx * ids_stride];
const uint32_t channel_y = fastmodulo(channel_dst, nchannels_y);
const block_q8_1 * y = ((const block_q8_1 *) vy) + channel_y*stride_channel_y + token_idx*stride_col_y;
const int kbx_offset = channel_x*stride_channel_x + row0*stride_row_x;
// partial sum for each thread
float tmp[c_rows_per_block] = {0.0f};
for (int kbx = threadIdx.x / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
const int kby = kbx * (qk/QK8_1);
const int kqs = vdr * (threadIdx.x % (qi/vdr));
#pragma unroll
for (int i = 0; i < c_rows_per_block; ++i) {
tmp[i] += vec_dot_q_cuda(vx, &y[kby], kbx_offset + i*stride_row_x + kbx, kqs);
}
}
// Warp-level reduction only - no shared memory needed
#pragma unroll
for (int i = 0; i < c_rows_per_block; ++i) {
tmp[i] = warp_reduce_sum<warp_size>(tmp[i]);
}
// Write results
if (threadIdx.x < c_rows_per_block && (c_rows_per_block == 1 || uint32_t(row0 + threadIdx.x) < nrows_x)) {
dst[channel_dst*stride_channel_dst + token_idx*stride_col_dst + row0 + threadIdx.x] = tmp[threadIdx.x];
}
}
template<ggml_type type>
static std::pair<dim3, dim3> calc_launch_params(
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
@@ -425,7 +660,7 @@ static std::pair<dim3, dim3> calc_launch_params(
return {block_nums, block_dims};
}
template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false, bool small_k = false>
template<ggml_type type, int c_ncols_dst, bool small_k = false>
static void mul_mat_vec_q_switch_fusion(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
@@ -438,7 +673,7 @@ static void mul_mat_vec_q_switch_fusion(
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (c_ncols_dst == 1) {
if (has_fusion) {
mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
mul_mat_vec_q<type, c_ncols_dst, true, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
@@ -448,12 +683,33 @@ static void mul_mat_vec_q_switch_fusion(
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
mul_mat_vec_q<type, c_ncols_dst, false, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
}
template <ggml_type type>
static void mul_mat_vec_q_moe_launch(
const void * vx, const void * vy, const int32_t * ids, float * dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x,
const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst,
const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst,
const uint32_t ncols_dst, const uint32_t ids_stride,
const int warp_size, const int nchannels_dst, cudaStream_t stream) {
constexpr int rows_per_block = 2; // 2 gives best perf based on tuning
const int64_t nblocks_rows = (nrows_x + rows_per_block - 1) / rows_per_block;
const dim3 block_nums(nblocks_rows, nchannels_dst);
const dim3 block_dims(warp_size, ncols_dst);
mul_mat_vec_q_moe<type, rows_per_block><<<block_nums, block_dims, 0, stream>>>(
vx, vy, ids, dst, ncols_x, nchannels_y, nrows_x,
stride_row_x, stride_col_y, stride_col_dst,
stride_channel_x, stride_channel_y, stride_channel_dst,
ncols_dst, ids_stride);
}
template <ggml_type type>
static void mul_mat_vec_q_switch_ncols_dst(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
@@ -472,20 +728,62 @@ static void mul_mat_vec_q_switch_ncols_dst(
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
const int device = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[device].cc;
const int warp_size = ggml_cuda_info().devices[device].warp_size;
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
const mmvq_parameter_table_id table_id = get_device_table_id(cc);
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
const bool has_ids = ids != nullptr;
const auto should_use_small_k = [&](int c_ncols_dst) {
// When K is small, increase rows_per_block to match nwarps so each warp has more work to do
// Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle.
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_iter_1warp = vdr * warp_size / qi;
const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
bool use = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
constexpr std::array<ggml_type, 2> iq_slow_turing = {
GGML_TYPE_IQ3_XXS,
GGML_TYPE_IQ3_S,
};
constexpr std::array<ggml_type, 8> iq_slow_other = {
GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS,
GGML_TYPE_IQ2_S, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
};
constexpr std::array<ggml_type, 3> slow_pascal = {
GGML_TYPE_IQ3_S,
GGML_TYPE_Q2_K,
GGML_TYPE_Q3_K,
};
const bool is_nvidia_turing_plus = GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_TURING;
const bool is_nvidia_pascal_older = GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA;
if (is_nvidia_turing_plus) {
if (ncols_dst == 1 &&
std::find(iq_slow_turing.begin(), iq_slow_turing.end(), type) != iq_slow_turing.end()) {
use = false;
}
} else if ((ncols_dst == 1 && std::find(iq_slow_other.begin(), iq_slow_other.end(), type) != iq_slow_other.end()) ||
(is_nvidia_pascal_older && std::find(slow_pascal.begin(), slow_pascal.end(), type) != slow_pascal.end()) ||
GGML_CUDA_CC_IS_RDNA(cc)) {
use = false;
}
return use;
};
if (has_ids && ncols_dst > 1) {
// Multi-token MUL_MAT_ID path only - single-token goes through regular path below
constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
// Multi-token MUL_MAT_ID path - dedicated MoE kernel
mul_mat_vec_q_moe_launch<type>(
vx, vy, ids, dst, ncols_x, nchannels_y_fd, nrows_x,
stride_row_x, stride_col_y, stride_col_dst,
stride_channel_x, stride_channel_y, stride_channel_dst,
ncols_dst, ids_stride, warp_size, nchannels_dst, stream);
return;
}
@@ -493,31 +791,24 @@ static void mul_mat_vec_q_switch_ncols_dst(
case 1: {
constexpr int c_ncols_dst = 1;
// When K is small, increase rows_per_block to match nwarps so each warp has more work to do
// Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle.
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_iter_1warp = vdr * warp_size / qi;
const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
bool use_small_k = should_use_small_k(c_ncols_dst);
if (use_small_k) {
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
warp_size, table_id, true);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, false, true>(
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
nsamples_dst, warp_size, table_id, true);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride,
stream);
} else {
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride,
stream);
}
} break;
case 2: {

View File

@@ -1,7 +1,10 @@
#include "common.cuh"
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID
// Returns the maximum batch size for which MMVQ should be used for MUL_MAT_ID,
// based on the quantization type and GPU architecture (compute capability).
int get_mmvq_mmid_max_batch(ggml_type type, int cc);
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);

View File

@@ -346,6 +346,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);
dma_cache m_cache;
dma_cache_init(&m_cache, spad_m, factx->size_m_block, DMA_CACHE_MAX_SIZE);
for (uint32_t ir = ir0; ir < ir1; ++ir) {
const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);
@@ -389,9 +392,8 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
// Mask
if (mask) {
const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block;
// Mask is 1D contiguous for this row
dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
dma_cache_push(dma, &m_cache, m_src, current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
}
// FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
@@ -554,7 +556,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
// Mask
if (mask) {
const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
dma_cache_push(dma, &m_cache, m_src, next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
}
// FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
@@ -684,7 +686,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
octx->src0_spad.size_per_thread = size_q_block * 1;
octx->src1_spad.size_per_thread = factx.size_k_block * 2;
octx->src2_spad.size_per_thread = factx.size_v_block * 2;
octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0;
octx->src3_spad.size_per_thread = mask ? factx.size_m_block * DMA_CACHE_MAX_SIZE : 0;
octx->dst_spad.size_per_thread = size_vkq_acc;
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
@@ -705,6 +707,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
// FARF(ERROR, "fa: qrows-per-thread %u", factx.qrows_per_thread);
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
}

View File

@@ -143,7 +143,7 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
desc->desc_size = 0; // 1D mode
desc->src_bypass = dma_src_l2_bypass_on;
desc->dst_bypass = dma_dst_l2_bypass_on;
desc->order = 1;
desc->order = 0;
desc->done = 0;
desc->src = (void *) dptr.src;
desc->dst = (void *) dptr.dst;
@@ -151,8 +151,12 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
q->dptr[q->push_idx] = dptr;
dmlink(q->tail, desc);
q->tail = (dma_descriptor_2d *) desc;
if (size) {
dmlink(q->tail, desc);
q->tail = (dma_descriptor_2d *) desc;
} else {
desc->done = 1;
}
// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
q->push_idx = (q->push_idx + 1) & q->idx_mask;
@@ -175,7 +179,7 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t
desc->dst_bypass = dma_dst_l2_bypass_on;
desc->src_comp = 0;
desc->dst_comp = 0;
desc->order = 1;
desc->order = 0;
desc->done = 0;
desc->src_stride = src_stride;
desc->dst_stride = dst_stride;
@@ -197,8 +201,12 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t
q->dptr[q->push_idx] = dptr;
dmlink(q->tail, desc);
q->tail = desc;
if (nrows) {
dmlink(q->tail, desc);
q->tail = desc;
} else {
desc->done = 1;
}
// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
q->push_idx = (q->push_idx + 1) & q->idx_mask;
@@ -215,12 +223,9 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
dma_descriptor_2d * desc = &q->desc[q->pop_idx];
// Wait for desc to complete
while (1) {
dmpoll();
if (desc->done) {
break;
}
while (!desc->done) {
// FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
dmpoll();
}
dptr = q->dptr[q->pop_idx];
@@ -312,6 +317,54 @@ static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_
return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
}
#define DMA_CACHE_MAX_SIZE 64U
typedef struct {
uint8_t *base;
uint32_t line_size;
uint32_t capacity;
uint32_t src[DMA_CACHE_MAX_SIZE];
uint16_t age[DMA_CACHE_MAX_SIZE];
} dma_cache;
static inline void dma_cache_init(dma_cache *c, uint8_t *base, uint32_t line_size, uint32_t capacity)
{
c->capacity = (capacity > DMA_CACHE_MAX_SIZE) ? DMA_CACHE_MAX_SIZE : capacity;
c->base = base;
c->line_size = line_size;
for (unsigned i=0; i < c->capacity; i++) {
c->src[i] = 0;
c->age[i] = 0;
}
}
static inline bool dma_cache_push(dma_queue *q, dma_cache *c, const uint8_t * src, uint32_t dst_stride, uint32_t src_stride, uint32_t row_size, uint32_t nrows)
{
uint32_t o_idx = 0;
uint16_t o_age = 0;
uint8_t * dst = 0;
for (unsigned i=0; i < c->capacity; i++) {
if (c->src[i] == (uint32_t) src) {
c->age[i] = 0;
dst = c->base + (i * c->line_size); nrows = 0; // dummy dma
// FARF(ERROR, "dma-cache: found %p", src);
} else {
c->age[i]++;
if (c->age[i] > o_age) { o_age = c->age[i]; o_idx = i; }
}
}
if (!dst) {
// FARF(ERROR, "dma-cache: replacing #%u : age %u %p -> %p", o_idx, c->age[o_idx], (void *) c->src[o_idx], src);
c->age[o_idx] = 0;
c->src[o_idx] = (uint32_t) src;
dst = c->base + o_idx * c->line_size; // normal nrows dma
}
return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows);
}
#ifdef __cplusplus
} // extern "C"
#endif

View File

@@ -333,8 +333,8 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
}
// Skip DMA transactions from prev block (if any)
// No need to wait for these since the DMA is setup for in-order processing
// Skip output DMA transactions from prev block (if any)
// No need to wait for those here since we're explicitly waiting for the latest prefetches below.
for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); }
// Compute loop

View File

@@ -1112,6 +1112,16 @@ struct vk_op_glu_push_constants {
uint32_t mode; // 0: default, 1: swapped, 2: split
float alpha; // for swiglu_oai
float limit;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t ne01;
uint32_t ne02;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t ne11;
uint32_t ne12;
};
struct vk_op_unary_push_constants {
@@ -5044,7 +5054,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
} else {
device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
}
vk::DeviceCreateInfo device_create_info;
vk::DeviceCreateInfo device_create_info{};
std::vector<const char *> device_extensions;
vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures();
@@ -5413,12 +5423,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
#endif
device->name = GGML_VK_NAME + std::to_string(idx);
device_create_info = {
vk::DeviceCreateFlags(),
device_queue_create_infos,
{},
device_extensions
};
device_create_info
.setFlags(vk::DeviceCreateFlags())
.setQueueCreateInfos(device_queue_create_infos)
.setPEnabledExtensionNames(device_extensions);
device_create_info.setPNext(&device_features2);
device->device = device->physical_device.createDevice(device_create_info);
@@ -11048,8 +11056,6 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
const float alpha = op_params_f[2];
const float limit = op_params_f[3];
GGML_ASSERT(ggml_is_contiguous(src0));
if (!split) {
GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
} else {
@@ -11067,7 +11073,17 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
(uint32_t)dst->ne[0],
mode,
alpha,
limit
limit,
(uint32_t)(src0->nb[1] / src0->nb[0]),
(uint32_t)(src0->nb[2] / src0->nb[0]),
(uint32_t)(src0->nb[3] / src0->nb[0]),
(uint32_t)src0->ne[1],
(uint32_t)src0->ne[2],
(uint32_t)(dst->nb[1] / dst->nb[0]),
(uint32_t)(dst->nb[2] / dst->nb[0]),
(uint32_t)(dst->nb[3] / dst->nb[0]),
(uint32_t)dst->ne[1],
(uint32_t)dst->ne[2]
});
}
@@ -15217,8 +15233,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous(op->src[0]) &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
(op->src[0]->type == op->type);
default:

View File

@@ -16,4 +16,14 @@ layout (push_constant) uniform parameter
uint mode;
float alpha;
float limit;
uint nb01;
uint nb02;
uint nb03;
uint ne01;
uint ne02;
uint nb11;
uint nb12;
uint nb13;
uint ne11;
uint ne12;
} p;

View File

@@ -8,22 +8,32 @@ void main() {
const uint row = i / p.ne20;
const uint col = i - row * p.ne20;
const uint i3 = row / (p.ne01 * p.ne02);
const uint i2 = (row % (p.ne01 * p.ne02)) / p.ne01;
const uint i1 = row % p.ne01;
const uint src_idx = i3 * p.nb03 + i2 * p.nb02 + i1 * p.nb01 + col;
const uint dst_i3 = row / (p.ne11 * p.ne12);
const uint dst_i2 = (row % (p.ne11 * p.ne12)) / p.ne11;
const uint dst_i1 = row % p.ne11;
const uint dst_idx = dst_i3 * p.nb13 + dst_i2 * p.nb12 + dst_i1 * p.nb11 + col;
if (p.mode == 0) {
// Default
const uint offset = p.ne00 / 2;
const uint idx = row * p.ne00 + col;
const uint idx = src_idx;
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
} else if (p.mode == 1) {
// Swapped
const uint offset = p.ne00 / 2;
const uint idx = row * p.ne00 + col;
const uint idx = src_idx;
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
data_d[dst_idx] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
} else {
// Split
const uint idx = row * p.ne00 + col;
const uint idx = src_idx;
data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
}
}

View File

@@ -0,0 +1,154 @@
{%- set image_count = namespace(value=0) %}
{%- set video_count = namespace(value=0) %}
{%- macro render_content(content, do_vision_count, is_system_content=false) %}
{%- if content is string %}
{{- content }}
{%- elif content is iterable and content is not mapping %}
{%- for item in content %}
{%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
{%- if is_system_content %}
{{- raise_exception('System message cannot contain images.') }}
{%- endif %}
{%- if do_vision_count %}
{%- set image_count.value = image_count.value + 1 %}
{%- endif %}
{%- if add_vision_id %}
{{- 'Picture ' ~ image_count.value ~ ': ' }}
{%- endif %}
{{- '<|vision_start|><|image_pad|><|vision_end|>' }}
{%- elif 'video' in item or item.type == 'video' %}
{%- if is_system_content %}
{{- raise_exception('System message cannot contain videos.') }}
{%- endif %}
{%- if do_vision_count %}
{%- set video_count.value = video_count.value + 1 %}
{%- endif %}
{%- if add_vision_id %}
{{- 'Video ' ~ video_count.value ~ ': ' }}
{%- endif %}
{{- '<|vision_start|><|video_pad|><|vision_end|>' }}
{%- elif 'text' in item %}
{{- item.text }}
{%- else %}
{{- raise_exception('Unexpected item type in content.') }}
{%- endif %}
{%- endfor %}
{%- elif content is none or content is undefined %}
{{- '' }}
{%- else %}
{{- raise_exception('Unexpected content type.') }}
{%- endif %}
{%- endmacro %}
{%- if not messages %}
{{- raise_exception('No messages provided.') }}
{%- endif %}
{%- if tools and tools is iterable and tools is not mapping %}
{{- '<|im_start|>system\n' }}
{{- "# Tools\n\nYou have access to the following functions:\n\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>" }}
{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
{%- if messages[0].role == 'system' %}
{%- set content = render_content(messages[0].content, false, true)|trim %}
{%- if content %}
{{- '\n\n' + content }}
{%- endif %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- else %}
{%- if messages[0].role == 'system' %}
{%- set content = render_content(messages[0].content, false, true)|trim %}
{{- '<|im_start|>system\n' + content + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" %}
{%- set content = render_content(message.content, false)|trim %}
{%- if not(content.startswith('<tool_response>') and content.endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if ns.multi_step_tool %}
{{- raise_exception('No user query found in messages.') }}
{%- endif %}
{%- for message in messages %}
{%- set content = render_content(message.content, true)|trim %}
{%- if message.role == "system" %}
{%- if not loop.first %}
{{- raise_exception('System message must be at the beginning.') }}
{%- endif %}
{%- elif message.role == "user" %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{%- set reasoning_content = '' %}
{%- if message.reasoning_content is string %}
{%- set reasoning_content = message.reasoning_content %}
{%- else %}
{%- if '</think>' in content %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- set reasoning_content = reasoning_content|trim %}
{%- if loop.index0 > ns.last_query_index %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{%- if loop.first %}
{%- if content|trim %}
{{- '\n\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
{%- else %}
{{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
{%- endif %}
{%- else %}
{{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
{%- endif %}
{%- if tool_call.arguments is defined %}
{%- for args_name, args_value in tool_call.arguments|items %}
{{- '<parameter=' + args_name + '>\n' }}
{%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
{{- args_value }}
{{- '\n</parameter>\n' }}
{%- endfor %}
{%- endif %}
{{- '</function>\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.previtem and loop.previtem.role != "tool" %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- content }}
{{- '\n</tool_response>' }}
{%- if not loop.last and loop.nextitem.role != "tool" %}
{{- '<|im_end|>\n' }}
{%- elif loop.last %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- else %}
{{- raise_exception('Unexpected message role.') }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined and enable_thinking is false %}
{{- '<think>\n\n</think>\n\n' }}
{%- else %}
{{- '<think>\n' }}
{%- endif %}
{%- endif %}

View File

@@ -5,7 +5,7 @@ import os
import sys
import subprocess
HTTPLIB_VERSION = "refs/tags/v0.39.0"
HTTPLIB_VERSION = "refs/tags/v0.40.0"
vendor = {
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",

View File

@@ -425,6 +425,7 @@ static common_chat_tool special_function_tool_with_optional_param{
"required": ["arg1"]
})",
};
static common_chat_tool empty_args_tool{
/* .name = */ "empty_args",
/* .description = */ "A tool that takes no arguments",
@@ -433,6 +434,15 @@ static common_chat_tool empty_args_tool{
"properties": {}
})",
};
static common_chat_tool empty_args_tool_no_properties{
/* .name = */ "empty_args_no_props",
/* .description = */ "A tool that takes no arguments and has no properties",
/* .parameters = */ R"({
"type": "object"
})",
};
static common_chat_tool python_tool{
/* .name = */ "python",
/* .description = */ "an ipython interpreter",
@@ -1410,6 +1420,176 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
}
})";
{
// Qwen3.5 (basically same as Nemotron, but keeping separate tests just in case)
auto tst = peg_tester("models/templates/Qwen3.5-4B.jinja", detailed_debug);
tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
.expect(message_assist_thoughts)
.run();
tst.test("I'm\nthinking\n</think>\nHello, world!\nWhat's up?")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_NONE)
.expect_content("<think>\nI'm\nthinking\n</think>\nHello, world!\nWhat's up?")
.run();
tst.test("I'm\nthinking\n</think>\nHello, world!\nWhat's up?")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.expect(message_assist_thoughts)
.run();
tst.test(
"<tool_call>\n"
"<function=special_function>\n"
"<parameter=arg1>\n1\n</parameter>\n"
"</function>\n"
"</tool_call>")
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ special_function_tool })
.expect(message_assist_call)
.run();
tst.test(
"I'm\nthinking\n</think>\n"
"<tool_call>\n"
"<function=special_function>\n"
"<parameter=arg1>\n1\n</parameter>\n"
"</function>\n"
"</tool_call>")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ special_function_tool })
.expect(message_assist_call_thoughts)
.run();
tst.test(
"<tool_call>\n"
"<function=special_function>\n"
"<parameter=arg1>\n1\n</parameter>\n"
"</function>\n"
"</tool_call>\n"
"<tool_call>\n"
"<function=special_function_with_opt>\n"
"<parameter=arg1>\n1\n</parameter>\n"
"<parameter=arg2>\n2\n</parameter>\n"
"</function>\n"
"</tool_call>")
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.parallel_tool_calls(true)
.tools({
special_function_tool, special_function_tool_with_optional_param
})
.expect_tool_calls({
{ "special_function", R"({"arg1": 1})", {} },
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
})
.run();
tst.test(
"<tool_call>\n"
"<function=python>\n"
"<parameter=code>\n"
"def hello():\n"
" print(\"Hello, world!\")\n"
"\n"
"hello()\n"
"</parameter>\n"
"</function>\n"
"</tool_call>")
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({
python_tool
})
.expect_tool_calls({
{ "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} },
})
.run();
tst.test(
"I need to output the invoice details in JSON\n"
"</think>\n"
R"({"amount": 123.45, "date": "2025-12-03"})")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
.json_schema(invoice_schema)
.expect_reasoning("I need to output the invoice details in JSON")
.expect_content(R"({"amount": 123.45, "date": "2025-12-03"})")
.run();
// tool call segment in reasoning
tst.test(
"Let's call a tool: <tool_call>\n"
"<function=python>\n"
"<parameter=code>\n"
"def hello():\n"
" print(\"Not the real call!\")\n"
"\n"
"hello()\n"
"</parameter>\n"
"</function>\n"
"</tool_call></think>\n"
"<tool_call>\n"
"<function=python>\n"
"<parameter=code>\n"
"def hello():\n"
" print(\"Hello, world!\")\n"
"\n"
"hello()\n"
"</parameter>\n"
"</function>\n"
"</tool_call>"
)
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({
python_tool
})
.expect_reasoning("Let's call a tool: <tool_call>\n"
"<function=python>\n"
"<parameter=code>\n"
"def hello():\n"
" print(\"Not the real call!\")\n"
"\n"
"hello()\n"
"</parameter>\n"
"</function>\n"
"</tool_call>")
.expect_tool_calls({
{ "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} },
})
.run();
// No args tool
tst.test(
"<tool_call>\n"
"<function=empty_args>\n"
"</function>\n"
"</tool_call>")
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ empty_args_tool })
.expect(message_with_tool_calls("empty_args", "{}"))
.run();
// No args tool with no properties defined
tst.test(
"<tool_call>\n"
"<function=empty_args_no_props>\n"
"</function>\n"
"</tool_call>")
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ empty_args_tool_no_properties })
.expect(message_with_tool_calls("empty_args_no_props", "{}"))
.run();
}
{
// Ministral-3-14B-Reasoning-2512
auto tst = peg_tester("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja", detailed_debug);
@@ -2796,6 +2976,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.expect(message_assist_thoughts)
.run();
// Analysis channel (reasoning) with final channel (content) with reasoning_format = none
tst.test(
"<|channel|>analysis<|message|>I'm\nthinking<|end|><|start|>assistant<|channel|>final<|message|>Hello, world!\nWhat's "
"up?")
.reasoning_format(COMMON_REASONING_FORMAT_NONE)
.expect_content("<|channel|>analysis<|message|>I'm\nthinking<|end|>Hello, world!\nWhat's up?")
.run();
// Analysis channel only (partial) - still works when reasoning format is set
tst.test("<|channel|>analysis<|message|>I'm\nthinking")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
@@ -2805,24 +2993,28 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// Tool call with recipient in role header: " to=functions.NAME<|channel|>analysis<|message|>JSON"
tst.test(" to=functions.special_function<|channel|>analysis<|message|>{\"arg1\": 1}")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ special_function_tool })
.expect(message_assist_call)
.run();
// Tool call with recipient in channel header: "<|channel|>analysis to=functions.NAME<|message|>JSON"
tst.test("<|channel|>analysis to=functions.special_function<|message|>{\"arg1\": 1}")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ special_function_tool })
.expect(message_assist_call)
.run();
// Tool call with constraint: " to=functions.NAME<|channel|>analysis <|constrain|>json<|message|>JSON"
tst.test(" to=functions.special_function<|channel|>analysis <|constrain|>json<|message|>{\"arg1\": 1}")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ special_function_tool })
.expect(message_assist_call)
.run();
// Tool call in commentary channel (channel header variant)
tst.test("<|channel|>commentary to=functions.special_function<|message|>{\"arg1\": 1}")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ special_function_tool })
.expect(message_assist_call)
.run();

View File

@@ -1525,6 +1525,47 @@ int main() {
}
});
// C++ only tests (features not yet supported in JS/Python implementations)
{
fprintf(stderr, "#\n# Testing C++ only features\n#\n");
auto run = [](const TestCase & tc) {
fprintf(stderr, "- %s\n", tc.name.c_str());
try {
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema), true));
tc.verify_status(SUCCESS);
} catch (const std::invalid_argument & ex) {
fprintf(stderr, "Error: %s\n", ex.what());
tc.verify_status(FAILURE);
}
};
run({
SUCCESS,
"regexp with non-capturing group",
R"""({
"type": "string",
"pattern": "^(?:foo|bar)baz$"
})""",
R"""(
root ::= "\"" (("foo" | "bar") "baz") "\"" space
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""",
});
run({
SUCCESS,
"regexp with nested non-capturing groups",
R"""({
"type": "string",
"pattern": "^(?:(?:ab)+c)?d$"
})""",
R"""(
root ::= "\"" ((("ab")+ "c")? "d") "\"" space
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""",
});
}
if (getenv("LLAMA_SKIP_TESTS_SLOW_ON_EMULATOR")) {
fprintf(stderr, "\033[33mWARNING: Skipping slow tests on emulator.\n\033[0m");
} else {

View File

@@ -1775,6 +1775,16 @@ Apart from error types supported by OAI, we also have custom types that are spec
}
```
### Custom default Web UI preferences
You can specify default preferences for the web UI using `--webui-config <JSON config>` or `--webui-config-file <path to JSON config>`. For example, you can disable pasting long text as attachments and enable rendering Markdown in user messages with this command:
```bash
./llama-server -m model.gguf --webui-config '{"pasteLongTextToFileLen": 0, "renderUserContentAsMarkdown": true}'
```
You may find available preferences in [settings-config.ts](webui/src/lib/constants/settings-config.ts).
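The same preferences can also be kept in a file and passed with `--webui-config-file`; a minimal sketch (the file name `webui-config.json` is arbitrary):
```bash
cat > webui-config.json <<'EOF'
{ "pasteLongTextToFileLen": 0, "renderUserContentAsMarkdown": true }
EOF
./llama-server -m model.gguf --webui-config-file webui-config.json
```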
### Legacy completion web UI
A chat-based UI has replaced the old completion-based one since [this PR](https://github.com/ggml-org/llama.cpp/pull/10175). To use the old completion UI, start the server with `--path ./tools/server/public_legacy`

Binary file not shown.

View File

@@ -2493,7 +2493,7 @@ private:
bool has_mtmd = false;
// check if we should process the image
if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
while (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
// process the image
size_t n_tokens_out = 0;
int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);

View File

@@ -113,16 +113,10 @@ bool server_http_context::init(const common_params & params) {
srv->set_read_timeout (params.timeout_read);
srv->set_write_timeout(params.timeout_write);
srv->set_socket_options([reuse_port = params.reuse_port](socket_t sock) {
int opt = 1;
#ifdef _WIN32
const char * optval = (const char *)&opt;
#else
const void * optval = &opt;
#endif
setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, optval, sizeof(opt));
httplib::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 1);
if (reuse_port) {
#ifdef SO_REUSEPORT
setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, optval, sizeof(opt));
httplib::set_socket_opt(sock, SOL_SOCKET, SO_REUSEPORT, 1);
#else
LOG_WRN("%s: SO_REUSEPORT is not supported\n", __func__);
#endif

View File

@@ -288,7 +288,15 @@ class ServerProcess:
server_instances.remove(self)
if self.process:
print(f"Stopping server with pid={self.process.pid}")
self.process.kill()
self.process.terminate()
try:
self.process.wait(timeout=5)
except subprocess.TimeoutExpired:
print(f"Server pid={self.process.pid} did not terminate in time, killing")
self.process.kill()
self.process.wait(timeout=5)
except Exception as e:
print(f"Error waiting for server: {e}")
self.process = None
def make_request(

View File

@@ -10,9 +10,9 @@
ModelsSelector,
ModelsSelectorSheet
} from '$lib/components/app';
import { DialogChatSettings } from '$lib/components/app/dialogs';
import { SETTINGS_SECTION_TITLES } from '$lib/constants';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { getChatSettingsDialogContext } from '$lib/contexts';
import { FileTypeCategory } from '$lib/enums';
import { getFileTypeCategory } from '$lib/utils';
import { config } from '$lib/stores/settings.svelte';
@@ -169,7 +169,7 @@
selectorModelRef?.open();
}
let showChatSettingsDialogWithMcpSection = $state(false);
const chatSettingsDialog = getChatSettingsDialogContext();
let hasMcpPromptsSupport = $derived.by(() => {
const perChatOverrides = conversationsStore.getAllMcpServerOverrides();
@@ -197,7 +197,7 @@
{onSystemPromptClick}
{onMcpPromptClick}
{onMcpResourcesClick}
onMcpSettingsClick={() => (showChatSettingsDialogWithMcpSection = true)}
onMcpSettingsClick={() => chatSettingsDialog.open(SETTINGS_SECTION_TITLES.MCP)}
/>
{:else}
<ChatFormActionAttachmentsDropdown
@@ -210,13 +210,13 @@
{onSystemPromptClick}
{onMcpPromptClick}
{onMcpResourcesClick}
onMcpSettingsClick={() => (showChatSettingsDialogWithMcpSection = true)}
onMcpSettingsClick={() => chatSettingsDialog.open(SETTINGS_SECTION_TITLES.MCP)}
/>
{/if}
<McpServersSelector
{disabled}
onSettingsClick={() => (showChatSettingsDialogWithMcpSection = true)}
onSettingsClick={() => chatSettingsDialog.open(SETTINGS_SECTION_TITLES.MCP)}
/>
</div>
@@ -265,9 +265,3 @@
/>
{/if}
</div>
<DialogChatSettings
open={showChatSettingsDialogWithMcpSection}
onOpenChange={(open) => (showChatSettingsDialogWithMcpSection = open)}
initialSection={SETTINGS_SECTION_TITLES.MCP}
/>

View File

@@ -180,6 +180,10 @@
chatActions.continueAssistantMessage(message);
}
function handleForkConversation(options: { name: string; includeAttachments: boolean }) {
chatActions.forkConversation(message, options);
}
function handleNavigateToSibling(siblingId: string) {
chatActions.navigateToSibling(siblingId);
}
@@ -285,6 +289,7 @@
onCopy={handleCopy}
onDelete={handleDelete}
onEdit={handleEdit}
onForkConversation={handleForkConversation}
onNavigateToSibling={handleNavigateToSibling}
onShowDeleteDialogChange={handleShowDeleteDialogChange}
{showDeleteDialog}
@@ -303,6 +308,7 @@
onCopy={handleCopy}
onDelete={handleDelete}
onEdit={handleEdit}
onForkConversation={handleForkConversation}
onNavigateToSibling={handleNavigateToSibling}
onRegenerate={handleRegenerate}
onShowDeleteDialogChange={handleShowDeleteDialogChange}

View File

@@ -1,12 +1,16 @@
<script lang="ts">
import { Edit, Copy, RefreshCw, Trash2, ArrowRight } from '@lucide/svelte';
import { Edit, Copy, RefreshCw, Trash2, ArrowRight, GitBranch } from '@lucide/svelte';
import {
ActionIcon,
ChatMessageBranchingControls,
DialogConfirmation
} from '$lib/components/app';
import { Switch } from '$lib/components/ui/switch';
import { Checkbox } from '$lib/components/ui/checkbox';
import Input from '$lib/components/ui/input/input.svelte';
import Label from '$lib/components/ui/label/label.svelte';
import { MessageRole } from '$lib/enums';
import { activeConversation } from '$lib/stores/conversations.svelte';
interface Props {
role: MessageRole.USER | MessageRole.ASSISTANT;
@@ -24,6 +28,7 @@
onEdit?: () => void;
onRegenerate?: () => void;
onContinue?: () => void;
onForkConversation?: (options: { name: string; includeAttachments: boolean }) => void;
onDelete: () => void;
onConfirmDelete: () => void;
onNavigateToSibling?: (siblingId: string) => void;
@@ -42,6 +47,7 @@
onConfirmDelete,
onContinue,
onDelete,
onForkConversation,
onNavigateToSibling,
onShowDeleteDialogChange,
onRegenerate,
@@ -53,10 +59,27 @@
onRawOutputToggle
}: Props = $props();
let showForkDialog = $state(false);
let forkName = $state('');
let forkIncludeAttachments = $state(true);
function handleConfirmDelete() {
onConfirmDelete();
onShowDeleteDialogChange(false);
}
function handleOpenForkDialog() {
const conv = activeConversation();
forkName = `Fork of ${conv?.name ?? 'Conversation'}`;
forkIncludeAttachments = true;
showForkDialog = true;
}
function handleConfirmFork() {
onForkConversation?.({ name: forkName.trim(), includeAttachments: forkIncludeAttachments });
showForkDialog = false;
}
</script>
<div class="relative {justify === 'start' ? 'mt-2' : ''} flex h-6 items-center justify-between">
@@ -86,6 +109,10 @@
<ActionIcon icon={ArrowRight} tooltip="Continue" onclick={onContinue} />
{/if}
{#if onForkConversation}
<ActionIcon icon={GitBranch} tooltip="Fork conversation" onclick={handleOpenForkDialog} />
{/if}
<ActionIcon icon={Trash2} tooltip="Delete" onclick={onDelete} />
</div>
</div>
@@ -116,3 +143,42 @@
onConfirm={handleConfirmDelete}
onCancel={() => onShowDeleteDialogChange(false)}
/>
<DialogConfirmation
bind:open={showForkDialog}
title="Fork Conversation"
description="Create a new conversation branching from this message."
confirmText="Fork"
cancelText="Cancel"
icon={GitBranch}
onConfirm={handleConfirmFork}
onCancel={() => (showForkDialog = false)}
>
<div class="flex flex-col gap-4 py-2">
<div class="flex flex-col gap-2">
<Label for="fork-name">Title</Label>
<Input
id="fork-name"
class="text-foreground"
placeholder="Enter fork name"
type="text"
bind:value={forkName}
/>
</div>
<div class="flex items-center gap-2">
<Checkbox
id="fork-attachments"
checked={forkIncludeAttachments}
onCheckedChange={(checked) => {
forkIncludeAttachments = checked === true;
}}
/>
<Label for="fork-attachments" class="cursor-pointer text-sm font-normal">
Include all attachments
</Label>
</div>
</div>
</DialogConfirmation>

View File

@@ -39,6 +39,7 @@
onContinue?: () => void;
onDelete: () => void;
onEdit?: () => void;
onForkConversation?: (options: { name: string; includeAttachments: boolean }) => void;
onNavigateToSibling?: (siblingId: string) => void;
onRegenerate: (modelOverride?: string) => void;
onShowDeleteDialogChange: (show: boolean) => void;
@@ -58,6 +59,7 @@
onCopy,
onDelete,
onEdit,
onForkConversation,
onNavigateToSibling,
onRegenerate,
onShowDeleteDialogChange,
@@ -345,6 +347,7 @@
onContinue={currentConfig.enableContinueGeneration && !hasReasoningMarkers
? onContinue
: undefined}
{onForkConversation}
{onDelete}
{onConfirmDelete}
{onNavigateToSibling}

View File

@@ -21,6 +21,7 @@
onEdit: () => void;
onDelete: () => void;
onConfirmDelete: () => void;
onForkConversation?: (options: { name: string; includeAttachments: boolean }) => void;
onShowDeleteDialogChange: (show: boolean) => void;
onNavigateToSibling?: (siblingId: string) => void;
onCopy: () => void;
@@ -35,6 +36,7 @@
onEdit,
onDelete,
onConfirmDelete,
onForkConversation,
onShowDeleteDialogChange,
onNavigateToSibling,
onCopy
@@ -114,6 +116,7 @@
{onCopy}
{onDelete}
{onEdit}
{onForkConversation}
{onNavigateToSibling}
{onShowDeleteDialogChange}
{siblingInfo}

View File

@@ -79,6 +79,13 @@
onUserAction?.();
await chatStore.continueAssistantMessage(message.id);
refreshAllMessages();
},
forkConversation: async (
message: DatabaseMessage,
options: { name: string; includeAttachments: boolean }
) => {
await conversationsStore.forkConversation(message.id, options);
}
});

View File

@@ -1,16 +1,11 @@
<script lang="ts">
import { Settings } from '@lucide/svelte';
import { DialogChatSettings } from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import { useSidebar } from '$lib/components/ui/sidebar';
import { getChatSettingsDialogContext } from '$lib/contexts';
const sidebar = useSidebar();
let settingsOpen = $state(false);
function toggleSettings() {
settingsOpen = true;
}
const chatSettingsDialog = getChatSettingsDialogContext();
</script>
<header
@@ -22,12 +17,10 @@
<Button
variant="ghost"
size="icon-lg"
onclick={toggleSettings}
onclick={() => chatSettingsDialog.open()}
class="rounded-full backdrop-blur-lg"
>
<Settings class="h-4 w-4" />
</Button>
</div>
</header>
<DialogChatSettings open={settingsOpen} onOpenChange={(open) => (settingsOpen = open)} />

View File

@@ -1,13 +1,18 @@
<script lang="ts">
import { goto } from '$app/navigation';
import { page } from '$app/state';
import { Trash2 } from '@lucide/svelte';
import { Trash2, Pencil } from '@lucide/svelte';
import { ChatSidebarConversationItem, DialogConfirmation } from '$lib/components/app';
import { Checkbox } from '$lib/components/ui/checkbox';
import Label from '$lib/components/ui/label/label.svelte';
import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte';
import * as Sidebar from '$lib/components/ui/sidebar';
import * as AlertDialog from '$lib/components/ui/alert-dialog';
import Input from '$lib/components/ui/input/input.svelte';
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
import {
conversationsStore,
conversations,
buildConversationTree
} from '$lib/stores/conversations.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import { getPreviewText } from '$lib/utils';
import ChatSidebarActions from './ChatSidebarActions.svelte';
@@ -18,6 +23,7 @@
let isSearchModeActive = $state(false);
let searchQuery = $state('');
let showDeleteDialog = $state(false);
let deleteWithForks = $state(false);
let showEditDialog = $state(false);
let selectedConversation = $state<DatabaseConversation | null>(null);
let editedName = $state('');
@@ -35,10 +41,30 @@
return conversations();
});
let conversationTree = $derived(buildConversationTree(filteredConversations));
let selectedConversationHasDescendants = $derived.by(() => {
if (!selectedConversation) return false;
const allConvs = conversations();
const queue = [selectedConversation.id];
while (queue.length > 0) {
const parentId = queue.pop()!;
for (const c of allConvs) {
if (c.forkedFromConversationId === parentId) return true;
}
}
return false;
});
async function handleDeleteConversation(id: string) {
const conversation = conversations().find((conv) => conv.id === id);
if (conversation) {
selectedConversation = conversation;
deleteWithForks = false;
showDeleteDialog = true;
}
}
@@ -54,11 +80,14 @@
function handleConfirmDelete() {
if (selectedConversation) {
const convId = selectedConversation.id;
const withForks = deleteWithForks;
showDeleteDialog = false;
setTimeout(() => {
conversationsStore.deleteConversation(selectedConversation.id);
selectedConversation = null;
conversationsStore.deleteConversation(convId, {
deleteWithForks: withForks
});
}, 100); // Wait for animation to finish
}
}
@@ -110,7 +139,7 @@
</script>
<ScrollArea class="h-[100vh]">
<Sidebar.Header class=" top-0 z-10 gap-6 bg-sidebar/50 px-4 py-4 pb-2 backdrop-blur-lg md:sticky">
<Sidebar.Header class=" top-0 z-10 gap-4 bg-sidebar/50 p-4 pb-2 backdrop-blur-lg md:sticky">
<a href="#/" onclick={handleMobileSidebarItemClick}>
<h1 class="inline-flex items-center gap-1 px-2 text-xl font-semibold">llama.cpp</h1>
</a>
@@ -118,7 +147,7 @@
<ChatSidebarActions {handleMobileSidebarItemClick} bind:isSearchModeActive bind:searchQuery />
</Sidebar.Header>
<Sidebar.Group class="mt-4 space-y-2 p-0 px-4">
<Sidebar.Group class="mt-2 space-y-2 p-0 px-4">
{#if (filteredConversations.length > 0 && isSearchModeActive) || !isSearchModeActive}
<Sidebar.GroupLabel>
{isSearchModeActive ? 'Search results' : 'Conversations'}
@@ -127,15 +156,17 @@
<Sidebar.GroupContent>
<Sidebar.Menu>
{#each filteredConversations as conversation (conversation.id)}
<Sidebar.MenuItem class="mb-1">
{#each conversationTree as { conversation, depth } (conversation.id)}
<Sidebar.MenuItem class="mb-1 p-0">
<ChatSidebarConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId
}}
{depth}
{handleMobileSidebarItemClick}
isActive={currentChatId === conversation.id}
onSelect={selectConversation}
@@ -146,7 +177,7 @@
</Sidebar.MenuItem>
{/each}
{#if filteredConversations.length === 0}
{#if conversationTree.length === 0}
<div class="px-2 py-4 text-center">
<p class="mb-4 p-4 text-sm text-muted-foreground">
{searchQuery.length > 0
@@ -177,35 +208,40 @@
showDeleteDialog = false;
selectedConversation = null;
}}
/>
>
{#if selectedConversationHasDescendants}
<div class="flex items-center gap-2 py-2">
<Checkbox id="delete-with-forks" bind:checked={deleteWithForks} />
<AlertDialog.Root bind:open={showEditDialog}>
<AlertDialog.Content>
<AlertDialog.Header>
<AlertDialog.Title>Edit Conversation Name</AlertDialog.Title>
<AlertDialog.Description>
<Input
class="mt-4 text-foreground"
onkeydown={(e) => {
if (e.key === 'Enter') {
e.preventDefault();
handleConfirmEdit();
}
}}
placeholder="Enter a new name"
type="text"
bind:value={editedName}
/>
</AlertDialog.Description>
</AlertDialog.Header>
<AlertDialog.Footer>
<AlertDialog.Cancel
onclick={() => {
showEditDialog = false;
selectedConversation = null;
}}>Cancel</AlertDialog.Cancel
>
<AlertDialog.Action onclick={handleConfirmEdit}>Save</AlertDialog.Action>
</AlertDialog.Footer>
</AlertDialog.Content>
</AlertDialog.Root>
<Label for="delete-with-forks" class="text-sm">Also delete all forked conversations</Label>
</div>
{/if}
</DialogConfirmation>
<DialogConfirmation
bind:open={showEditDialog}
title="Edit Conversation Name"
description=""
confirmText="Save"
cancelText="Cancel"
icon={Pencil}
onConfirm={handleConfirmEdit}
onCancel={() => {
showEditDialog = false;
selectedConversation = null;
}}
onKeydown={(e) => {
if (e.key === 'Enter') {
e.preventDefault();
e.stopImmediatePropagation();
handleConfirmEdit();
}
}}
>
<Input
class="text-foreground"
placeholder="Enter a new name"
type="text"
bind:value={editedName}
/>
</DialogConfirmation>

View File

@@ -3,6 +3,9 @@
import { KeyboardShortcutInfo } from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import { Input } from '$lib/components/ui/input';
import { McpLogo } from '$lib/components/app';
import { SETTINGS_SECTION_TITLES } from '$lib/constants';
import { getChatSettingsDialogContext } from '$lib/contexts';
interface Props {
handleMobileSidebarItemClick: () => void;
@@ -18,6 +21,8 @@
let searchInput: HTMLInputElement | null = $state(null);
const chatSettingsDialog = getChatSettingsDialogContext();
function handleSearchModeDeactivate() {
isSearchModeActive = false;
searchQuery = '';
@@ -30,7 +35,7 @@
});
</script>
<div class="space-y-0.5">
<div class="my-1 space-y-1">
{#if isSearchModeActive}
<div class="relative">
<Search class="absolute top-2.5 left-2 h-4 w-4 text-muted-foreground" />
@@ -50,13 +55,14 @@
</div>
{:else}
<Button
class="w-full justify-between hover:[&>kbd]:opacity-100"
class="w-full justify-between backdrop-blur-none! hover:[&>kbd]:opacity-100"
href="?new_chat=true#/"
onclick={handleMobileSidebarItemClick}
variant="ghost"
>
<div class="flex items-center gap-2">
<SquarePen class="h-4 w-4" />
New chat
</div>
@@ -64,7 +70,7 @@
</Button>
<Button
class="w-full justify-between hover:[&>kbd]:opacity-100"
class="w-full justify-between backdrop-blur-none! hover:[&>kbd]:opacity-100"
onclick={() => {
isSearchModeActive = true;
}}
@@ -72,10 +78,25 @@
>
<div class="flex items-center gap-2">
<Search class="h-4 w-4" />
Search conversations
Search
</div>
<KeyboardShortcutInfo keys={['cmd', 'k']} />
</Button>
<Button
class="w-full justify-between backdrop-blur-none! hover:[&>kbd]:opacity-100"
onclick={() => {
chatSettingsDialog.open(SETTINGS_SECTION_TITLES.MCP);
}}
variant="ghost"
>
<div class="flex items-center gap-2">
<McpLogo class="h-4 w-4" />
MCP Servers
</div>
</Button>
{/if}
</div>

View File

@@ -1,13 +1,23 @@
<script lang="ts">
import { Trash2, Pencil, MoreHorizontal, Download, Loader2, Square } from '@lucide/svelte';
import {
Trash2,
Pencil,
MoreHorizontal,
Download,
Loader2,
Square,
GitBranch
} from '@lucide/svelte';
import { DropdownMenuActions } from '$lib/components/app';
import * as Tooltip from '$lib/components/ui/tooltip';
import { FORK_TREE_DEPTH_PADDING } from '$lib/constants';
import { getAllLoadingChats } from '$lib/stores/chat.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { onMount } from 'svelte';
interface Props {
isActive?: boolean;
depth?: number;
conversation: DatabaseConversation;
handleMobileSidebarItemClick?: () => void;
onDelete?: (id: string) => void;
@@ -23,7 +33,8 @@
onEdit,
onSelect,
onStop,
isActive = false
isActive = false,
depth = 0
}: Props = $props();
let renderActionsDropdown = $state(false);
@@ -88,14 +99,34 @@
<!-- svelte-ignore a11y_mouse_events_have_key_events -->
<button
class="group flex min-h-9 w-full cursor-pointer items-center justify-between space-x-3 rounded-lg px-3 py-1.5 text-left transition-colors hover:bg-foreground/10 {isActive
class="group flex min-h-9 w-full cursor-pointer items-center justify-between space-x-3 rounded-lg py-1.5 text-left transition-colors hover:bg-foreground/10 {isActive
? 'bg-foreground/5 text-accent-foreground'
: ''}"
: ''} px-3"
onclick={handleSelect}
onmouseover={handleMouseOver}
onmouseleave={handleMouseLeave}
>
<div class="flex min-w-0 flex-1 items-center gap-2">
<div
class="flex min-w-0 flex-1 items-center gap-2"
style:padding-left="{depth * FORK_TREE_DEPTH_PADDING}px"
>
{#if depth > 0}
<Tooltip.Root>
<Tooltip.Trigger>
<a
href="#/chat/{conversation.forkedFromConversationId}"
class="flex shrink-0 items-center text-muted-foreground transition-colors hover:text-foreground"
>
<GitBranch class="h-3.5 w-3.5" />
</a>
</Tooltip.Trigger>
<Tooltip.Content>
<p>See parent conversation</p>
</Tooltip.Content>
</Tooltip.Root>
{/if}
{#if isLoading}
<Tooltip.Root>
<Tooltip.Trigger>

View File

@@ -1,6 +1,6 @@
<script lang="ts">
import * as AlertDialog from '$lib/components/ui/alert-dialog';
import type { Component } from 'svelte';
import type { Component, Snippet } from 'svelte';
import { KeyboardKey } from '$lib/enums';
interface Props {
@@ -14,6 +14,7 @@
onConfirm: () => void;
onCancel: () => void;
onKeydown?: (event: KeyboardEvent) => void;
children?: Snippet;
}
let {
@@ -26,7 +27,8 @@
icon,
onConfirm,
onCancel,
onKeydown
onKeydown,
children
}: Props = $props();
function handleKeydown(event: KeyboardEvent) {
@@ -60,6 +62,10 @@
</AlertDialog.Description>
</AlertDialog.Header>
{#if children}
{@render children()}
{/if}
<AlertDialog.Footer>
<AlertDialog.Cancel onclick={onCancel}>{cancelText}</AlertDialog.Cancel>
<AlertDialog.Action

View File

@@ -18,7 +18,8 @@
showRaw = undefined,
aliases,
tags,
class: className = ''
class: className = '',
...rest
}: Props = $props();
const badgeClass =
@@ -36,9 +37,9 @@
</script>
{#if resolvedShowRaw}
<TruncatedText class="font-medium {className}" showTooltip={false} text={modelId} />
<TruncatedText class="font-medium {className}" showTooltip={false} text={modelId} {...rest} />
{:else}
<span class="flex min-w-0 flex-wrap items-center gap-1 {className}">
<span class="flex min-w-0 flex-wrap items-center gap-1 {className}" {...rest}>
<span class="min-w-0 truncate font-medium">
{#if showOrgName && parsed.orgName && !(aliases && aliases.length > 0)}{parsed.orgName}/{/if}{displayName}
</span>

View File

@@ -271,50 +271,49 @@
{#if isRouter}
<DropdownMenu.Root bind:open={isOpen} onOpenChange={handleOpenChange}>
<DropdownMenu.Trigger
disabled={disabled || updating}
onclick={(e) => {
e.preventDefault();
e.stopPropagation();
}}
>
<button
type="button"
class={cn(
`inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-muted-foreground/10 px-1.5 py-1 text-xs transition hover:text-foreground focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60`,
!isCurrentModelInCache
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
: forceForegroundText
class={cn(
`inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-muted-foreground/10 px-1.5 py-1 text-xs transition hover:text-foreground focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60`,
!isCurrentModelInCache
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
: forceForegroundText
? 'text-foreground'
: isHighlightedCurrentModelActive
? 'text-foreground'
: isHighlightedCurrentModelActive
? 'text-foreground'
: 'text-muted-foreground',
isOpen ? 'text-foreground' : ''
)}
style="max-width: min(calc(100cqw - 9rem), 20rem)"
disabled={disabled || updating}
>
<Package class="h-3.5 w-3.5" />
: 'text-muted-foreground',
isOpen ? 'text-foreground' : ''
)}
style="max-width: min(calc(100cqw - 9rem), 20rem)"
disabled={disabled || updating}
>
<Package class="h-3.5 w-3.5" />
{#if selectedOption}
<Tooltip.Root>
<Tooltip.Trigger class="min-w-0 overflow-hidden">
<ModelId modelId={selectedOption.model} class="min-w-0" showOrgName />
</Tooltip.Trigger>
{#if selectedOption}
<Tooltip.Root>
<Tooltip.Trigger>
<!-- prevent another nested button element -->
{#snippet child({ props })}
<ModelId
modelId={selectedOption.model}
class="min-w-0 overflow-hidden"
showOrgName
{...props}
/>
{/snippet}
</Tooltip.Trigger>
<Tooltip.Content>
<p class="font-mono">{selectedOption.model}</p>
</Tooltip.Content>
</Tooltip.Root>
{:else}
<span class="min-w-0 font-medium">Select model</span>
{/if}
<Tooltip.Content>
<p class="font-mono">{selectedOption.model}</p>
</Tooltip.Content>
</Tooltip.Root>
{:else}
<span class="min-w-0 font-medium">Select model</span>
{/if}
{#if updating || isLoadingModel}
<Loader2 class="h-3 w-3.5 animate-spin" />
{:else}
<ChevronDown class="h-3 w-3.5" />
{/if}
</button>
{#if updating || isLoadingModel}
<Loader2 class="h-3 w-3.5 animate-spin" />
{:else}
<ChevronDown class="h-3 w-3.5" />
{/if}
</DropdownMenu.Trigger>
<DropdownMenu.Content
@@ -407,8 +406,16 @@
{#if selectedOption}
<Tooltip.Root>
<Tooltip.Trigger class="min-w-0 overflow-hidden">
<ModelId modelId={selectedOption.model} class="min-w-0" showOrgName />
<Tooltip.Trigger>
<!-- prevent another nested button element -->
{#snippet child({ props })}
<ModelId
modelId={selectedOption.model}
class="min-w-0 overflow-hidden"
showOrgName
{...props}
/>
{/snippet}
</Tooltip.Trigger>
<Tooltip.Content>

View File

@@ -0,0 +1,3 @@
export const CONTEXT_KEY_MESSAGE_EDIT = 'chat-message-edit';
export const CONTEXT_KEY_CHAT_ACTIONS = 'chat-actions';
export const CONTEXT_KEY_CHAT_SETTINGS_DIALOG = 'chat-settings-dialog';

View File

@@ -10,6 +10,7 @@ export * from './cache';
export * from './chat-form';
export * from './code-blocks';
export * from './code';
export * from './context-keys';
export * from './css-classes';
export * from './favicon';
export * from './floating-ui-constraints';

View File

@@ -1 +1,2 @@
export const FORK_TREE_DEPTH_PADDING = 8;
export const SYSTEM_MESSAGE_PLACEHOLDER = 'System message';

View File

@@ -1,4 +1,5 @@
import { getContext, setContext } from 'svelte';
import { CONTEXT_KEY_CHAT_ACTIONS } from '$lib/constants';
export interface ChatActionsContext {
copy: (message: DatabaseMessage) => void;
@@ -21,9 +22,13 @@ export interface ChatActionsContext {
) => void;
regenerateWithBranching: (message: DatabaseMessage, modelOverride?: string) => void;
continueAssistantMessage: (message: DatabaseMessage) => void;
forkConversation: (
message: DatabaseMessage,
options: { name: string; includeAttachments: boolean }
) => void;
}
const CHAT_ACTIONS_KEY = Symbol.for('chat-actions');
const CHAT_ACTIONS_KEY = Symbol.for(CONTEXT_KEY_CHAT_ACTIONS);
export function setChatActionsContext(ctx: ChatActionsContext): ChatActionsContext {
return setContext(CHAT_ACTIONS_KEY, ctx);

View File

@@ -0,0 +1,19 @@
import { getContext, setContext } from 'svelte';
import type { SettingsSectionTitle } from '$lib/constants';
import { CONTEXT_KEY_CHAT_SETTINGS_DIALOG } from '$lib/constants';
export interface ChatSettingsDialogContext {
open: (initialSection?: SettingsSectionTitle) => void;
}
const CHAT_SETTINGS_DIALOG_KEY = Symbol.for(CONTEXT_KEY_CHAT_SETTINGS_DIALOG);
export function setChatSettingsDialogContext(
ctx: ChatSettingsDialogContext
): ChatSettingsDialogContext {
return setContext(CHAT_SETTINGS_DIALOG_KEY, ctx);
}
export function getChatSettingsDialogContext(): ChatSettingsDialogContext {
return getContext(CHAT_SETTINGS_DIALOG_KEY);
}

View File

@@ -11,3 +11,9 @@ export {
setChatActionsContext,
type ChatActionsContext
} from './chat-actions.context';
export {
getChatSettingsDialogContext,
setChatSettingsDialogContext,
type ChatSettingsDialogContext
} from './chat-settings-dialog.context';

View File

@@ -1,4 +1,5 @@
import { getContext, setContext } from 'svelte';
import { CONTEXT_KEY_MESSAGE_EDIT } from '$lib/constants';
export interface MessageEditState {
readonly isEditing: boolean;
@@ -22,7 +23,7 @@ export interface MessageEditActions {
export type MessageEditContext = MessageEditState & MessageEditActions;
const MESSAGE_EDIT_KEY = Symbol.for('chat-message-edit');
const MESSAGE_EDIT_KEY = Symbol.for(CONTEXT_KEY_MESSAGE_EDIT);
/**
* Sets the message edit context. Call this in the parent component (ChatMessage.svelte).

View File

@@ -1,5 +1,6 @@
import Dexie, { type EntityTable } from 'dexie';
import { findDescendantMessages, uuid } from '$lib/utils';
import { findDescendantMessages, uuid, filterByLeafNodeId } from '$lib/utils';
import type { McpServerOverride } from '$lib/types/database';
class LlamacppDatabase extends Dexie {
conversations!: EntityTable<DatabaseConversation, string>;
@@ -173,8 +174,47 @@ export class DatabaseService {
*
* @param id - Conversation ID
*/
static async deleteConversation(id: string): Promise<void> {
static async deleteConversation(
id: string,
options?: { deleteWithForks?: boolean }
): Promise<void> {
await db.transaction('rw', [db.conversations, db.messages], async () => {
if (options?.deleteWithForks) {
// Recursively collect all descendant IDs
const idsToDelete: string[] = [];
const queue = [id];
while (queue.length > 0) {
const parentId = queue.pop()!;
const children = await db.conversations
.filter((c) => c.forkedFromConversationId === parentId)
.toArray();
for (const child of children) {
idsToDelete.push(child.id);
queue.push(child.id);
}
}
for (const forkId of idsToDelete) {
await db.conversations.delete(forkId);
await db.messages.where('convId').equals(forkId).delete();
}
} else {
// Reparent direct children to deleted conv's parent
const conv = await db.conversations.get(id);
const newParent = conv?.forkedFromConversationId;
const directChildren = await db.conversations
.filter((c) => c.forkedFromConversationId === id)
.toArray();
for (const child of directChildren) {
await db.conversations.update(child.id, {
forkedFromConversationId: newParent ?? undefined
});
}
}
await db.conversations.delete(id);
await db.messages.where('convId').equals(id).delete();
});
@@ -364,4 +404,88 @@ export class DatabaseService {
return { imported: importedCount, skipped: skippedCount };
});
}
/**
*
*
* Forking
*
*
*/
/**
* Forks a conversation at a specific message, creating a new conversation
* containing all messages from the root up to (and including) the target message.
*
* @param sourceConvId - The source conversation ID
* @param atMessageId - The message ID to fork at (the new conversation ends here)
* @param options - Fork options (name and whether to include attachments)
* @returns The newly created conversation
*/
static async forkConversation(
sourceConvId: string,
atMessageId: string,
options: { name: string; includeAttachments: boolean }
): Promise<DatabaseConversation> {
return await db.transaction('rw', [db.conversations, db.messages], async () => {
const sourceConv = await db.conversations.get(sourceConvId);
if (!sourceConv) {
throw new Error(`Source conversation ${sourceConvId} not found`);
}
const allMessages = await db.messages.where('convId').equals(sourceConvId).toArray();
const pathMessages = filterByLeafNodeId(allMessages, atMessageId, true) as DatabaseMessage[];
if (pathMessages.length === 0) {
throw new Error(`Could not resolve message path to ${atMessageId}`);
}
const idMap = new Map<string, string>();
for (const msg of pathMessages) {
idMap.set(msg.id, uuid());
}
const newConvId = uuid();
const clonedMessages: DatabaseMessage[] = pathMessages.map((msg) => {
const newId = idMap.get(msg.id)!;
const newParent = msg.parent ? (idMap.get(msg.parent) ?? null) : null;
const newChildren = msg.children
.filter((childId: string) => idMap.has(childId))
.map((childId: string) => idMap.get(childId)!);
return {
...msg,
id: newId,
convId: newConvId,
parent: newParent,
children: newChildren,
extra: options.includeAttachments ? msg.extra : undefined
};
});
const lastClonedMessage = clonedMessages[clonedMessages.length - 1];
const newConv: DatabaseConversation = {
id: newConvId,
name: options.name,
lastModified: Date.now(),
currNode: lastClonedMessage.id,
forkedFromConversationId: sourceConvId,
mcpServerOverrides: sourceConv.mcpServerOverrides
? sourceConv.mcpServerOverrides.map((o: McpServerOverride) => ({
serverId: o.serverId,
enabled: o.enabled
}))
: undefined
};
await db.conversations.add(newConv);
for (const msg of clonedMessages) {
await db.messages.add(msg);
}
return newConv;
});
}
}
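A minimal usage sketch for the new data-layer fork API; the call site, IDs, and option values below are illustrative and not part of the change (the import of `DatabaseService` is assumed from the surrounding webui code):
```ts
// Hypothetical call site: fork an existing conversation at a given message.
// `sourceConvId` and `messageId` must already exist in the Dexie database.
async function forkAt(sourceConvId: string, messageId: string) {
  const forked = await DatabaseService.forkConversation(sourceConvId, messageId, {
    name: 'Fork of My Conversation',
    includeAttachments: true
  });
  // The new conversation ends at the cloned target message and points back to its source.
  console.log(forked.id, forked.currNode, forked.forkedFromConversationId === sourceConvId);
  return forked;
}
```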

View File

@@ -1265,35 +1265,53 @@ class ChatStore {
let result = this.getMessageByIdWithRole(messageId, MessageRole.USER);
if (!result) result = this.getMessageByIdWithRole(messageId, MessageRole.SYSTEM);
if (!result) return;
const { message: msg } = result;
const { message: msg, index: idx } = result;
try {
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
const isFirstUserMessage =
msg.role === MessageRole.USER && rootMessage && msg.parent === rootMessage.id;
const parentId = msg.parent || rootMessage?.id;
if (!parentId) return;
const extrasToUse =
newExtras !== undefined
? JSON.parse(JSON.stringify(newExtras))
: msg.extra
? JSON.parse(JSON.stringify(msg.extra))
: undefined;
const newMessage = await DatabaseService.createMessageBranch(
{
convId: msg.convId,
type: msg.type,
timestamp: Date.now(),
role: msg.role,
let messageIdForResponse: string;
if (msg.children.length === 0) {
// No responses after this message — update in place instead of branching
const updates: Partial<DatabaseMessage> = {
content: newContent,
toolCalls: msg.toolCalls || '',
children: [],
extra: extrasToUse,
model: msg.model
},
parentId
);
await conversationsStore.updateCurrentNode(newMessage.id);
timestamp: Date.now(),
extra: extrasToUse
};
await DatabaseService.updateMessage(msg.id, updates);
conversationsStore.updateMessageAtIndex(idx, updates);
messageIdForResponse = msg.id;
} else {
// Has children — create a new branch as sibling
const parentId = msg.parent || rootMessage?.id;
if (!parentId) return;
const newMessage = await DatabaseService.createMessageBranch(
{
convId: msg.convId,
type: msg.type,
timestamp: Date.now(),
role: msg.role,
content: newContent,
toolCalls: msg.toolCalls || '',
children: [],
extra: extrasToUse,
model: msg.model
},
parentId
);
await conversationsStore.updateCurrentNode(newMessage.id);
messageIdForResponse = newMessage.id;
}
conversationsStore.updateConversationTimestamp();
if (isFirstUserMessage && newContent.trim())
await conversationsStore.updateConversationTitleWithConfirmation(
@@ -1301,7 +1319,8 @@ class ChatStore {
newContent.trim()
);
await conversationsStore.refreshActiveMessages();
if (msg.role === MessageRole.USER) await this.generateResponseForMessage(newMessage.id);
if (msg.role === MessageRole.USER)
await this.generateResponseForMessage(messageIdForResponse);
} catch (error) {
console.error('Failed to edit message with branching:', error);
}

View File

@@ -39,6 +39,12 @@ import {
MULTIPLE_UNDERSCORE_REGEX,
MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY
} from '$lib/constants';
import { SvelteMap, SvelteSet } from 'svelte/reactivity';
export interface ConversationTreeItem {
conversation: DatabaseConversation;
depth: number;
}
class ConversationsStore {
/**
@@ -300,15 +306,45 @@ class ConversationsStore {
* Deletes a conversation and all its messages
* @param convId - The conversation ID to delete
*/
async deleteConversation(convId: string): Promise<void> {
async deleteConversation(convId: string, options?: { deleteWithForks?: boolean }): Promise<void> {
try {
await DatabaseService.deleteConversation(convId);
await DatabaseService.deleteConversation(convId, options);
this.conversations = this.conversations.filter((c) => c.id !== convId);
if (options?.deleteWithForks) {
// Collect all descendants recursively
const idsToRemove = new SvelteSet([convId]);
const queue = [convId];
while (queue.length > 0) {
const parentId = queue.pop()!;
for (const c of this.conversations) {
if (c.forkedFromConversationId === parentId && !idsToRemove.has(c.id)) {
idsToRemove.add(c.id);
queue.push(c.id);
}
}
}
this.conversations = this.conversations.filter((c) => !idsToRemove.has(c.id));
if (this.activeConversation?.id === convId) {
this.clearActiveConversation();
await goto(`?new_chat=true#/`);
if (this.activeConversation && idsToRemove.has(this.activeConversation.id)) {
this.clearActiveConversation();
await goto(`?new_chat=true#/`);
}
} else {
// Reparent direct children to deleted conv's parent (or promote to top-level)
const deletedConv = this.conversations.find((c) => c.id === convId);
const newParent = deletedConv?.forkedFromConversationId;
this.conversations = this.conversations
.filter((c) => c.id !== convId)
.map((c) =>
c.forkedFromConversationId === convId
? { ...c, forkedFromConversationId: newParent }
: c
);
if (this.activeConversation?.id === convId) {
this.clearActiveConversation();
await goto(`?new_chat=true#/`);
}
}
} catch (error) {
console.error('Failed to delete conversation:', error);
@@ -658,6 +694,42 @@ class ConversationsStore {
this.saveMcpDefaults();
}
/**
* Forks a conversation at a specific message, creating a new conversation
* containing messages from root up to the target message, then navigates to it.
*
* @param messageId - The message ID to fork at
* @param options - Fork options (name and whether to include attachments)
* @returns The new conversation ID, or null if fork failed
*/
async forkConversation(
messageId: string,
options: { name: string; includeAttachments: boolean }
): Promise<string | null> {
if (!this.activeConversation) return null;
try {
const newConv = await DatabaseService.forkConversation(
this.activeConversation.id,
messageId,
options
);
this.conversations = [newConv, ...this.conversations];
await goto(`#/chat/${newConv.id}`);
toast.success('Conversation forked');
return newConv.id;
} catch (error) {
console.error('Failed to fork conversation:', error);
toast.error('Failed to fork conversation');
return null;
}
}
/**
*
*
@@ -830,3 +902,53 @@ export const conversations = () => conversationsStore.conversations;
export const activeConversation = () => conversationsStore.activeConversation;
export const activeMessages = () => conversationsStore.activeMessages;
export const isConversationsInitialized = () => conversationsStore.isInitialized;
/**
* Builds a flat tree of conversations with depth levels for nested forks.
* Accepts a pre-filtered list so search filtering stays in the component.
*/
export function buildConversationTree(convs: DatabaseConversation[]): ConversationTreeItem[] {
const childrenByParent = new SvelteMap<string, DatabaseConversation[]>();
const forkIds = new SvelteSet<string>();
for (const conv of convs) {
if (conv.forkedFromConversationId) {
forkIds.add(conv.id);
const siblings = childrenByParent.get(conv.forkedFromConversationId) || [];
siblings.push(conv);
childrenByParent.set(conv.forkedFromConversationId, siblings);
}
}
const result: ConversationTreeItem[] = [];
const visited = new SvelteSet<string>();
function walk(conv: DatabaseConversation, depth: number) {
visited.add(conv.id);
result.push({ conversation: conv, depth });
const children = childrenByParent.get(conv.id);
if (children) {
children.sort((a, b) => b.lastModified - a.lastModified);
for (const child of children) {
walk(child, depth + 1);
}
}
}
const roots = convs.filter((c) => !forkIds.has(c.id));
for (const root of roots) {
walk(root, 0);
}
for (const conv of convs) {
if (!visited.has(conv.id)) {
walk(conv, 1);
}
}
return result;
}
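For illustration, a small sketch of the ordering `buildConversationTree` produces; the conversation objects are made-up fixtures with only the fields the function reads populated, and the casts are for brevity:
```ts
import { buildConversationTree } from '$lib/stores/conversations.svelte';

// Hypothetical fixtures: B and C are forks of A, D is a fork of B.
const A = { id: 'a', name: 'A', lastModified: 3, currNode: '' } as DatabaseConversation;
const B = { id: 'b', name: 'B', lastModified: 2, currNode: '', forkedFromConversationId: 'a' } as DatabaseConversation;
const C = { id: 'c', name: 'C', lastModified: 1, currNode: '', forkedFromConversationId: 'a' } as DatabaseConversation;
const D = { id: 'd', name: 'D', lastModified: 4, currNode: '', forkedFromConversationId: 'b' } as DatabaseConversation;

buildConversationTree([A, B, C, D]);
// => [ { A, depth 0 }, { B, depth 1 }, { D, depth 2 }, { C, depth 1 } ]
// Siblings are sorted by lastModified (newest first); a fork whose parent was
// filtered out of the input list is still walked, at depth 1.
```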

View File

@@ -12,6 +12,7 @@ export interface DatabaseConversation {
lastModified: number;
name: string;
mcpServerOverrides?: McpServerOverride[];
forkedFromConversationId?: string;
}
export interface DatabaseMessageExtraAudioFile {

View File

@@ -4,7 +4,11 @@
import { browser } from '$app/environment';
import { page } from '$app/state';
import { untrack } from 'svelte';
import { ChatSidebar, DialogConversationTitleUpdate } from '$lib/components/app';
import {
ChatSidebar,
DialogConversationTitleUpdate,
DialogChatSettings
} from '$lib/components/app';
import { isLoading } from '$lib/stores/chat.svelte';
import { conversationsStore, activeMessages } from '$lib/stores/conversations.svelte';
import * as Sidebar from '$lib/components/ui/sidebar/index.js';
@@ -17,8 +21,10 @@
import { modelsStore } from '$lib/stores/models.svelte';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { TOOLTIP_DELAY_DURATION } from '$lib/constants';
import type { SettingsSectionTitle } from '$lib/constants';
import { KeyboardKey } from '$lib/enums';
import { IsMobile } from '$lib/hooks/is-mobile.svelte';
import { setChatSettingsDialogContext } from '$lib/contexts';
let { children } = $props();
@@ -42,6 +48,16 @@
let titleUpdateNewTitle = $state('');
let titleUpdateResolve: ((value: boolean) => void) | null = null;
let chatSettingsDialogOpen = $state(false);
let chatSettingsDialogInitialSection = $state<SettingsSectionTitle | undefined>(undefined);
setChatSettingsDialogContext({
open: (initialSection?: SettingsSectionTitle) => {
chatSettingsDialogInitialSection = initialSection;
chatSettingsDialogOpen = true;
}
});
// Global keyboard shortcuts
function handleKeydown(event: KeyboardEvent) {
const isCtrlOrCmd = event.ctrlKey || event.metaKey;
@@ -213,6 +229,12 @@
<Toaster richColors />
<DialogChatSettings
open={chatSettingsDialogOpen}
onOpenChange={(open) => (chatSettingsDialogOpen = open)}
initialSection={chatSettingsDialogInitialSection}
/>
<DialogConversationTitleUpdate
bind:open={titleUpdateDialogOpen}
currentTitle={titleUpdateCurrentTitle}

View File

@@ -73,7 +73,7 @@
conversationsStore.conversations = mockConversations;
}, 0));
const searchTrigger = screen.getByText('Search conversations');
const searchTrigger = screen.getByText('Search');
userEvent.click(searchTrigger);
}}
>

View File

@@ -467,10 +467,6 @@ bool set_socket_opt_impl(socket_t sock, int level, int optname,
optlen) == 0;
}
bool set_socket_opt(socket_t sock, int level, int optname, int optval) {
return set_socket_opt_impl(sock, level, optname, &optval, sizeof(optval));
}
bool set_socket_opt_time(socket_t sock, int level, int optname,
time_t sec, time_t usec) {
#ifdef _WIN32
@@ -2218,7 +2214,7 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port,
#ifdef _WIN32
// Setting SO_REUSEADDR seems not to work well with AF_UNIX on windows, so
// remove the option.
detail::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 0);
set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 0);
#endif
bool dummy;
@@ -4373,6 +4369,7 @@ make_multipart_content_provider(const UploadFormDataItems &items,
struct MultipartState {
std::vector<std::string> owned;
std::vector<MultipartSegment> segs;
std::vector<char> buf = std::vector<char>(CPPHTTPLIB_SEND_BUFSIZ);
};
auto state = std::make_shared<MultipartState>();
state->owned = std::move(owned);
@@ -4381,19 +4378,49 @@ make_multipart_content_provider(const UploadFormDataItems &items,
state->segs = std::move(segs);
return [state](size_t offset, size_t length, DataSink &sink) -> bool {
// Buffer multiple small segments into fewer, larger writes to avoid
// excessive TCP packets when there are many form data items (#2410)
auto &buf = state->buf;
auto buf_size = buf.size();
size_t buf_len = 0;
size_t remaining = length;
// Find the first segment containing 'offset'
size_t pos = 0;
for (const auto &seg : state->segs) {
// Loop invariant: pos <= offset (proven by advancing pos only when
// offset - pos >= seg.size, i.e., the segment doesn't contain offset)
if (seg.size > 0 && offset - pos < seg.size) {
size_t seg_offset = offset - pos;
size_t available = seg.size - seg_offset;
size_t to_write = (std::min)(available, length);
return sink.write(seg.data + seg_offset, to_write);
}
size_t seg_idx = 0;
for (; seg_idx < state->segs.size(); seg_idx++) {
const auto &seg = state->segs[seg_idx];
if (seg.size > 0 && offset - pos < seg.size) { break; }
pos += seg.size;
}
return true; // past end (shouldn't be reached when content_length is exact)
size_t seg_offset = (seg_idx < state->segs.size()) ? offset - pos : 0;
for (; seg_idx < state->segs.size() && remaining > 0; seg_idx++) {
const auto &seg = state->segs[seg_idx];
size_t available = seg.size - seg_offset;
size_t to_copy = (std::min)(available, remaining);
const char *src = seg.data + seg_offset;
seg_offset = 0; // only the first segment has a non-zero offset
while (to_copy > 0) {
size_t space = buf_size - buf_len;
size_t chunk = (std::min)(to_copy, space);
std::memcpy(buf.data() + buf_len, src, chunk);
buf_len += chunk;
src += chunk;
to_copy -= chunk;
remaining -= chunk;
if (buf_len == buf_size) {
if (!sink.write(buf.data(), buf_len)) { return false; }
buf_len = 0;
}
}
}
if (buf_len > 0) { return sink.write(buf.data(), buf_len); }
return true;
};
}
@@ -5264,13 +5291,18 @@ bool setup_client_tls_session(const std::string &host, tls::ctx_t &ctx,
*/
void default_socket_options(socket_t sock) {
detail::set_socket_opt(sock, SOL_SOCKET,
set_socket_opt(sock, SOL_SOCKET,
#ifdef SO_REUSEPORT
SO_REUSEPORT,
SO_REUSEPORT,
#else
SO_REUSEADDR,
SO_REUSEADDR,
#endif
1);
1);
}
bool set_socket_opt(socket_t sock, int level, int optname, int optval) {
return detail::set_socket_opt_impl(sock, level, optname, &optval,
sizeof(optval));
}
std::string get_bearer_token_auth(const Request &req) {
@@ -7418,6 +7450,8 @@ bool Server::read_content_core(
return false;
}
req.body_consumed_ = true;
if (req.is_multipart_form_data()) {
if (!multipart_form_data_parser.is_valid()) {
res.status = StatusCode::BadRequest_400;
@@ -7688,9 +7722,7 @@ bool Server::listen_internal() {
detail::set_socket_opt_time(sock, SOL_SOCKET, SO_SNDTIMEO,
write_timeout_sec_, write_timeout_usec_);
if (tcp_nodelay_) {
detail::set_socket_opt(sock, IPPROTO_TCP, TCP_NODELAY, 1);
}
if (tcp_nodelay_) { set_socket_opt(sock, IPPROTO_TCP, TCP_NODELAY, 1); }
if (!task_queue->enqueue(
[this, sock]() { process_and_close_socket(sock); })) {
@@ -8036,8 +8068,19 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
return write_response(strm, close_connection, req, res);
}
// RFC 9112 §6.3: Reject requests with both a non-zero Content-Length and
// any Transfer-Encoding to prevent request smuggling. Content-Length: 0 is
// tolerated for compatibility with existing clients.
if (req.get_header_value_u64("Content-Length") > 0 &&
req.has_header("Transfer-Encoding")) {
connection_closed = true;
res.status = StatusCode::BadRequest_400;
return write_response(strm, close_connection, req, res);
}
// Check if the request URI doesn't exceed the limit
if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) {
connection_closed = true;
res.status = StatusCode::UriTooLong_414;
output_error_log(Error::ExceedUriMaxLength, &req);
return write_response(strm, close_connection, req, res);
@@ -8066,6 +8109,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
if (req.has_header("Accept")) {
const auto &accept_header = req.get_header_value("Accept");
if (!detail::parse_accept_header(accept_header, req.accept_content_types)) {
connection_closed = true;
res.status = StatusCode::BadRequest_400;
output_error_log(Error::HTTPParsing, &req);
return write_response(strm, close_connection, req, res);
@@ -8075,6 +8119,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
if (req.has_header("Range")) {
const auto &range_header_value = req.get_header_value("Range");
if (!detail::parse_range_header(range_header_value, req.ranges)) {
connection_closed = true;
res.status = StatusCode::RangeNotSatisfiable_416;
output_error_log(Error::InvalidRangeHeader, &req);
return write_response(strm, close_connection, req, res);
@@ -8202,6 +8247,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
}
}
#endif
auto ret = false;
if (routed) {
if (res.status == -1) {
res.status = req.ranges.empty() ? StatusCode::OK_200
@@ -8209,6 +8255,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
}
// Serve file content by using a content provider
auto file_open_error = false;
if (!res.file_content_path_.empty()) {
const auto &path = res.file_content_path_;
auto mm = std::make_shared<detail::mmap>(path.c_str());
@@ -8218,37 +8265,53 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
res.content_provider_ = nullptr;
res.status = StatusCode::NotFound_404;
output_error_log(Error::OpenFile, &req);
return write_response(strm, close_connection, req, res);
}
file_open_error = true;
} else {
auto content_type = res.file_content_content_type_;
if (content_type.empty()) {
content_type = detail::find_content_type(
path, file_extension_and_mimetype_map_, default_file_mimetype_);
}
auto content_type = res.file_content_content_type_;
if (content_type.empty()) {
content_type = detail::find_content_type(
path, file_extension_and_mimetype_map_, default_file_mimetype_);
res.set_content_provider(
mm->size(), content_type,
[mm](size_t offset, size_t length, DataSink &sink) -> bool {
sink.write(mm->data() + offset, length);
return true;
});
}
res.set_content_provider(
mm->size(), content_type,
[mm](size_t offset, size_t length, DataSink &sink) -> bool {
sink.write(mm->data() + offset, length);
return true;
});
}
if (detail::range_error(req, res)) {
if (file_open_error) {
ret = write_response(strm, close_connection, req, res);
} else if (detail::range_error(req, res)) {
res.body.clear();
res.content_length_ = 0;
res.content_provider_ = nullptr;
res.status = StatusCode::RangeNotSatisfiable_416;
return write_response(strm, close_connection, req, res);
ret = write_response(strm, close_connection, req, res);
} else {
ret = write_response_with_content(strm, close_connection, req, res);
}
return write_response_with_content(strm, close_connection, req, res);
} else {
if (res.status == -1) { res.status = StatusCode::NotFound_404; }
return write_response(strm, close_connection, req, res);
ret = write_response(strm, close_connection, req, res);
}
// Drain any unconsumed request body to prevent request smuggling on
// keep-alive connections.
if (!req.body_consumed_ && detail::expect_content(req)) {
int drain_status = 200; // required by read_content signature
if (!detail::read_content(
strm, req, payload_max_length_, drain_status, nullptr,
[](const char *, size_t, size_t, size_t) { return true; }, false)) {
// Body exceeds payload limit or read error — close the connection
// to prevent leftover bytes from being misinterpreted.
connection_closed = true;
}
}
return ret;
}
bool Server::is_valid() const { return true; }

View File

@@ -8,8 +8,8 @@
#ifndef CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_VERSION "0.39.0"
#define CPPHTTPLIB_VERSION_NUM "0x002700"
#define CPPHTTPLIB_VERSION "0.40.0"
#define CPPHTTPLIB_VERSION_NUM "0x002800"
#ifdef _WIN32
#if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00
@@ -1266,6 +1266,7 @@ struct Request {
bool is_multipart_form_data() const;
// private members...
bool body_consumed_ = false;
size_t redirect_count_ = CPPHTTPLIB_REDIRECT_MAX_COUNT;
size_t content_length_ = 0;
ContentProvider content_provider_;
@@ -1475,6 +1476,8 @@ using SocketOptions = std::function<void(socket_t sock)>;
void default_socket_options(socket_t sock);
bool set_socket_opt(socket_t sock, int level, int optname, int optval);
const char *status_message(int status);
std::string to_string(Error error);
@@ -1564,6 +1567,13 @@ ssize_t write_headers(Stream &strm, const Headers &headers);
bool set_socket_opt_time(socket_t sock, int level, int optname, time_t sec,
time_t usec);
size_t get_multipart_content_length(const UploadFormDataItems &items,
const std::string &boundary);
ContentProvider
make_multipart_content_provider(const UploadFormDataItems &items,
const std::string &boundary);
} // namespace detail
class Server {