Compare commits

...

7 Commits
b8776 ... b8783

Author SHA1 Message Date
Aldehir Rojas
e21cdc11a0 common/gemma4 : handle parsing edge cases (#21760) 2026-04-13 18:18:18 -05:00
Xuan-Son Nguyen
e974923698 docs: listing qwen3-asr and qwen3-omni as supported (#21857)
* docs: listing qwen3-asr and qwen3-omni as supported

* nits
2026-04-13 22:28:17 +02:00
Piotr Wilkin (ilintar)
1c0d9081fd chat: dedicated DeepSeek v3.2 parser + "official" template (#21785) 2026-04-13 22:23:53 +02:00
Christian Kastner
a8bad3842e ci: Also exempt 'security' tag from auto-close (#21844) 2026-04-14 01:18:44 +08:00
Ruben Ortlam
75f3bc94e6 vulkan: Flash Attention DP4A shader for quantized KV cache (#20797)
* use integer dot product for quantized KV flash attention

* small improvements

* fix SHMEM_STAGING indexing

* add missing KV type quants

* fixes

* add supported quants to FA tests

* readd fast paths for <8bit quants

* fix mmq gate and shmem checks
2026-04-13 14:21:31 +02:00
Adrien Gallouët
aa00911d12 common : add download cancellation and temp file cleanup (#21813)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-04-13 11:18:23 +02:00
Gaspard Petit
ce8fd4b1a6 server: Expose build_info in router mode (#21835) 2026-04-13 11:14:42 +02:00
21 changed files with 1137 additions and 37 deletions

View File

@@ -17,7 +17,7 @@ jobs:
steps:
- uses: actions/stale@v10
with:
exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap"
exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap,security"
days-before-issue-stale: 30
days-before-issue-close: 14
stale-issue-label: "stale"

View File

@@ -1091,6 +1091,14 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
common_chat_params data;
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
if (inputs.add_generation_prompt && string_ends_with(data.prompt, "<turn|>\n")) {
// This may happen if the model generates content + tool_call: the
// template does not add the model's next turn, which confuses the model
// and keeps it from emitting its proper reasoning token sequence.
data.prompt += "<|turn>model\n";
}
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
data.supports_thinking = true;
data.thinking_start_tag = "<|channel>thought";
@@ -1118,7 +1126,8 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
p.rule("thought", p.content(p.literal("<|channel>thought") + p.space() + p.until("<channel|>") + p.literal("<channel|>")));
}
auto thought = (p.peek(p.literal("<|channel>")) + p.ref("thought")) | p.negate(p.literal("<|channel>"));
auto consume_empty_channels = p.gbnf(p.zero_or_more(p.literal("<|channel>") + p.negate(p.literal("thought"))), "");
auto thought = (p.peek(p.literal("<|channel>")) + consume_empty_channels + p.ref("thought")) | p.negate(p.literal("<|channel>"));
if (has_response_format) {
auto response_format = p.literal("```json") <<
@@ -1182,12 +1191,16 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
/* max = */ inputs.parallel_tool_calls ? -1 : 1
));
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<|tool_call>"})));
auto scan_to_toolcall = p.rule("scan-to-toolcall", p.until("<|tool_call>"));
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<channel|>", "<|tool_call>"})));
auto message = p.rule("message", thought + content);
return start + p.zero_or_more(message) + tool_call;
return start + p.zero_or_more(message) + scan_to_toolcall + tool_call;
}
auto content = p.rule("content", p.content(p.until("<|channel>")));
// Gemma 4 may emit an extra <|channel>thought\n<channel|> at the end of the content. It may
// also emit a single trailing <channel|> token. Consume all complete reasoning blocks and
// then stop at the first unmatched <channel|> token.
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<channel|>"})));
auto message = p.rule("message", thought + content);
return start + p.one_or_more(message);
});
@@ -1656,6 +1669,173 @@ static common_chat_params common_chat_params_init_gigachat_v3(
return data;
}
// Build chat params for the DeepSeek V3.2 template. The model wraps tool calls
// in DSML-style markup (function_calls / invoke / parameter tags built from the
// DSML token below) and emits reasoning between <think>...</think>. This
// constructs the PEG parser for the model output and, when constrained output
// is requested, a matching (possibly lazy) GBNF grammar.
static common_chat_params common_chat_params_init_deepseek_v3_2(const common_chat_template & tmpl,
                                                                const autoparser::generation_params & inputs) {
    common_chat_params data;
    data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
    data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
    data.supports_thinking = true;
    data.thinking_start_tag = "<think>";
    data.thinking_end_tag = "</think>";
    data.preserved_tokens = {
        "DSML",
        "<think>",
        "</think>",
    };

    auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
    auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
    auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
    // A sampling grammar is only needed when output must be constrained:
    // a JSON response schema, or tools that the model is allowed/required to call.
    auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);

    // DSML markup tokens. NOTE(review): FC_START expands to "<DSMLfunction_calls>"
    // with no separator between "DSML" and the tag name — confirm this matches the
    // literal token sequence the model actually emits.
    const std::string DSML = "DSML";
    const std::string THINK_START = "<think>";
    const std::string THINK_END = "</think>";
    const std::string FC_START = "<" + DSML + "function_calls>";
    const std::string FC_END = "</" + DSML + "function_calls>";
    const std::string INVOKE_START = "<" + DSML + "invoke";
    const std::string INVOKE_END = "</" + DSML + "invoke>";
    const std::string PARAM_START = "<" + DSML + "parameter";
    const std::string PARAM_END = "</" + DSML + "parameter>";

    auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
        auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
        auto end = p.end();

        auto reasoning = p.eps();
        if (extract_reasoning && inputs.enable_thinking) {
            // Capture the <think> body as reasoning content.
            reasoning = p.optional(THINK_START + p.reasoning(p.until(THINK_END)) + THINK_END);
        } else if (extract_reasoning) {
            // Thinking disabled but reasoning extraction requested: the generation prompt
            // contains an empty <think></think> pair that must still be consumed.
            reasoning = p.optional(p.literal(THINK_START) + p.until(THINK_END) + p.literal(THINK_END));
        }

        if (has_response_format) {
            // Response-format mode: expect a fenced ```json block matching the schema.
            auto response_format = p.rule("response-format",
                p.literal("```json") + p.space() +
                p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)) +
                p.space() + p.literal("```"));
            return generation_prompt + reasoning + response_format + end;
        }
        if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
            // No tools (or tools disabled): everything after reasoning is plain content.
            return generation_prompt + reasoning + p.content(p.rest()) + end;
        }

        // One parser alternative per declared tool function.
        auto tool_choice = p.choice();
        foreach_function(inputs.tools, [&](const json & tool) {
            const auto & function = tool.at("function");
            std::string name = function.at("name");
            auto params = function.contains("parameters") ? function.at("parameters") : json::object();
            const auto & props = params.contains("properties") ? params.at("properties") : json::object();
            std::set<std::string> required;
            if (params.contains("required")) {
                params.at("required").get_to(required);
            }
            auto schema_info = common_schema_info();
            schema_info.resolve_refs(params);

            // Build one parameter parser per property; string-valued parameters are
            // parsed as raw text up to the closing tag, everything else as JSON
            // constrained by the property schema.
            std::vector<common_peg_parser> required_parsers;
            std::vector<common_peg_parser> optional_parsers;
            for (const auto & [param_name, param_schema] : props.items()) {
                bool is_required = required.find(param_name) != required.end();
                bool is_string = schema_info.resolves_to_string(param_schema);
                auto arg = p.tool_arg(
                    p.tool_arg_open(
                        p.literal(PARAM_START + " name=\"") +
                        p.tool_arg_name(p.literal(param_name)) +
                        p.literal("\" string=\"" + std::string(is_string ? "true" : "false") + "\">")) +
                    (is_string
                        ? p.tool_arg_string_value(p.until(PARAM_END))
                        : p.tool_arg_json_value(p.schema(p.json(),
                            "tool-" + name + "-arg-" + param_name + "-schema",
                            param_schema, false))) +
                    p.tool_arg_close(p.literal(PARAM_END)));
                auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg);
                if (is_required) {
                    required_parsers.push_back(named_arg);
                } else {
                    optional_parsers.push_back(named_arg);
                }
            }

            // Required parameters are expected in declaration order; optional
            // parameters may then appear any number of times in any order.
            common_peg_parser args_seq = p.eps();
            for (size_t i = 0; i < required_parsers.size(); i++) {
                if (i > 0) {
                    args_seq = args_seq + p.space();
                }
                args_seq = args_seq + required_parsers[i];
            }
            if (!optional_parsers.empty()) {
                common_peg_parser any_opt = p.choice();
                for (const auto & opt : optional_parsers) {
                    any_opt |= opt;
                }
                args_seq = args_seq + p.repeat(p.space() + any_opt, 0, -1);
            }

            common_peg_parser invoke_body = args_seq;
            auto func_parser = p.tool(
                p.tool_open(p.literal(INVOKE_START + " name=\"") +
                    p.tool_name(p.literal(name)) + p.literal("\">\n")) +
                invoke_body + p.space() +
                p.tool_close(p.literal(INVOKE_END)));
            tool_choice |= p.rule("tool-" + name, func_parser);
        });

        auto require_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
        // A function_calls block holds one invoke, or several when parallel
        // tool calls are enabled.
        common_peg_parser tool_calls = p.eps();
        if (inputs.parallel_tool_calls) {
            tool_calls = p.trigger_rule("tool-call",
                p.literal(FC_START) + p.space() + tool_choice +
                p.zero_or_more(p.space() + tool_choice) + p.space() + p.literal(FC_END));
        } else {
            tool_calls = p.trigger_rule("tool-call",
                p.literal(FC_START) + p.space() + tool_choice + p.space() + p.literal(FC_END));
        }
        if (!require_tools) {
            tool_calls = p.optional(tool_calls);
        }

        // Free-form content may precede the tool-call block.
        auto content_before_tools = p.content(p.until(FC_START));
        return generation_prompt + reasoning + content_before_tools + tool_calls + end;
    });

    data.parser = parser.save();
    if (include_grammar) {
        // Lazy grammar: only activate when the FC_START trigger word is seen,
        // unless the output is forced (response format or required tool choice).
        data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
            foreach_function(inputs.tools, [&](const json & tool) {
                const auto & function = tool.at("function");
                auto schema = function.contains("parameters") ? function.at("parameters") : json::object();
                builder.resolve_refs(schema);
            });
            if (has_response_format) {
                auto schema = inputs.json_schema;
                builder.resolve_refs(schema);
            }
            parser.build_grammar(builder, data.grammar_lazy);
        });
        data.grammar_triggers = {
            { COMMON_GRAMMAR_TRIGGER_TYPE_WORD, FC_START },
        };
    }
    return data;
}
namespace workaround {
static void map_developer_role_to_system(json & messages) {
@@ -1927,6 +2107,15 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
return common_chat_params_init_gigachat_v3(tmpl, params);
}
// DeepSeek V3.2 format detection: template defines dsml_token and uses it for tool calls.
// The template source contains the token as a variable assignment, not as a literal in markup.
if (src.find("dsml_token") != std::string::npos &&
src.find("function_calls") != std::string::npos &&
src.find("DSML") != std::string::npos) {
LOG_DBG("Using specialized template: DeepSeek V3.2\n");
return common_chat_params_init_deepseek_v3_2(tmpl, params);
}
// Gemma4 format detection
if (src.find("'<|tool_call>call:'") != std::string::npos) {
if (src.find("{#- OpenAI Chat Completions:") == std::string::npos) {

View File

@@ -258,6 +258,9 @@ static bool common_pull_file(httplib::Client & cli,
if (progress_step >= p.total / 1000 || p.downloaded == p.total) {
if (callback) {
callback->on_update(p);
if (callback->is_cancelled()) {
return false;
}
}
progress_step = 0;
}
@@ -373,6 +376,9 @@ static int common_download_file_single_online(const std::string & url,
}
for (int i = 0; i < max_attempts; ++i) {
if (opts.callback && opts.callback->is_cancelled()) {
break;
}
if (i) {
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
std::this_thread::sleep_for(std::chrono::seconds(delay));
@@ -412,6 +418,12 @@ static int common_download_file_single_online(const std::string & url,
if (opts.callback) {
opts.callback->on_done(p, success);
}
if (opts.callback && opts.callback->is_cancelled() &&
std::filesystem::exists(path_temporary)) {
if (remove(path_temporary.c_str()) != 0) {
LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, path_temporary.c_str());
}
}
if (!success) {
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
return -1; // max attempts reached

View File

@@ -21,6 +21,7 @@ public:
virtual void on_start(const common_download_progress & p) = 0;
virtual void on_update(const common_download_progress & p) = 0;
virtual void on_done(const common_download_progress & p, bool ok) = 0;
// Polled by the download loop to support cooperative cancellation; the default
// implementation never cancels. Override to request an abort mid-download.
virtual bool is_cancelled() const { return false; }
};
struct common_remote_params {

View File

@@ -890,6 +890,10 @@ struct parser_executor {
}
return result;
}
// GBNF-override node: the custom grammar string only affects grammar emission,
// so parsing simply delegates to the wrapped child parser.
common_peg_parse_result operator()(const common_peg_gbnf_parser & p) {
    return arena.parse(p.child, ctx, start_pos);
}
};
common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const {
@@ -957,7 +961,8 @@ void common_peg_arena::resolve_refs() {
std::is_same_v<T, common_peg_and_parser> ||
std::is_same_v<T, common_peg_not_parser> ||
std::is_same_v<T, common_peg_tag_parser> ||
std::is_same_v<T, common_peg_atomic_parser>) {
std::is_same_v<T, common_peg_atomic_parser> ||
std::is_same_v<T, common_peg_gbnf_parser>) {
p.child = resolve_ref(p.child);
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
p.child = resolve_ref(p.child);
@@ -1036,6 +1041,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
return "Not(" + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_atomic_parser>) {
return "Atomic(" + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_any_parser>) {
return "Any";
} else if constexpr (std::is_same_v<T, common_peg_space_parser>) {
@@ -1565,6 +1572,7 @@ static std::unordered_set<std::string> collect_reachable_rules(
std::is_same_v<T, common_peg_not_parser> ||
std::is_same_v<T, common_peg_tag_parser> ||
std::is_same_v<T, common_peg_atomic_parser> ||
std::is_same_v<T, common_peg_gbnf_parser> ||
std::is_same_v<T, common_peg_schema_parser>) {
visit(p.child);
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
@@ -1651,10 +1659,13 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
} else if constexpr (std::is_same_v<T, common_peg_sequence_parser>) {
std::string s;
for (const auto & child : p.children) {
auto child_gbnf = to_gbnf(child);
if (child_gbnf.empty()) {
continue;
}
if (!s.empty()) {
s += " ";
}
auto child_gbnf = to_gbnf(child);
const auto & child_parser = effective_parser(child);
if (std::holds_alternative<common_peg_choice_parser>(child_parser) ||
std::holds_alternative<common_peg_sequence_parser>(child_parser)) {
@@ -1754,6 +1765,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
return to_gbnf(p.child);
} else if constexpr (std::is_same_v<T, common_peg_atomic_parser>) {
return to_gbnf(p.child);
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
return p.grammar;
} else {
static_assert(is_always_false_v<T>);
}
@@ -1888,6 +1901,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
{"child", p.child},
{"tag", p.tag}
};
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}};
}
}, variant);
}
@@ -2050,6 +2065,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
};
}
if (type == "gbnf") {
if (!j.contains("child") || !j.contains("grammar")) {
throw std::runtime_error("gbnf parser missing required fields");
}
return common_peg_gbnf_parser{
j["child"].get<common_peg_parser_id>(),
j["grammar"].get<std::string>(),
};
}
throw std::runtime_error("Unknown parser type: " + type);
}

View File

@@ -270,6 +270,11 @@ struct common_peg_tag_parser {
std::string tag;
};
// Wraps a child parser but emits a hand-written GBNF grammar string instead of
// the child's generated grammar; parsing still delegates entirely to the child.
// NOTE: member order matters — this struct is brace-initialized ({child, grammar})
// by the builder and the (de)serialization code.
struct common_peg_gbnf_parser {
    common_peg_parser_id child;   // parser actually used at parse time
    std::string grammar;          // literal GBNF fragment emitted for this node
};
// Variant holding all parser types
using common_peg_parser_variant = std::variant<
common_peg_epsilon_parser,
@@ -290,7 +295,8 @@ using common_peg_parser_variant = std::variant<
common_peg_rule_parser,
common_peg_ref_parser,
common_peg_atomic_parser,
common_peg_tag_parser
common_peg_tag_parser,
common_peg_gbnf_parser
>;
class common_peg_arena {
@@ -504,6 +510,10 @@ class common_peg_parser_builder {
// Unlike rules, you can tag multiple nodes with the same tag.
common_peg_parser tag(const std::string & tag, const common_peg_parser & p) { return add(common_peg_tag_parser{p.id(), tag}); }
// Wraps a child parser but emits a custom GBNF grammar string instead of
// the child's grammar. Parsing delegates entirely to the child.
common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); }
void set_root(const common_peg_parser & p);
common_peg_arena build();

View File

@@ -114,6 +114,10 @@ NOTE: some models may require large context window, for example: `-c 8192`
# Mistral's Voxtral
(tool_name) -hf ggml-org/Voxtral-Mini-3B-2507-GGUF
# Qwen3-ASR
(tool_name) -hf ggml-org/Qwen3-ASR-0.6B-GGUF
(tool_name) -hf ggml-org/Qwen3-ASR-1.7B-GGUF
```
**Mixed modalities**:
@@ -124,6 +128,11 @@ NOTE: some models may require large context window, for example: `-c 8192`
(tool_name) -hf ggml-org/Qwen2.5-Omni-3B-GGUF
(tool_name) -hf ggml-org/Qwen2.5-Omni-7B-GGUF
# Qwen3 Omni
# Capabilities: audio input, vision input
(tool_name) -hf ggml-org/Qwen3-Omni-30B-A3B-Instruct-GGUF
(tool_name) -hf ggml-org/Qwen3-Omni-30B-A3B-Thinking-GGUF
# Gemma 4
# Capabilities: audio input, vision input
(tool_name) -hf ggml-org/gemma-4-E2B-it-GGUF

View File

@@ -2858,11 +2858,10 @@ struct vk_fa_tuning_params {
}
};
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type);
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
GGML_UNUSED(kv_type);
vk_fa_tuning_params result{};
result.path = FA_SCALAR;
@@ -2914,7 +2913,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) {
result.block_rows /= 2;
}
@@ -3445,21 +3444,47 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->fp16) {
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product && device->subgroup_clustered) {
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8)
} else
#endif
{
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
}
} else {
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product && device->subgroup_clustered) {
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8)
} else
#endif
{
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
}
}
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat1_fa_support) {
@@ -8780,7 +8805,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
}
}
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
GGML_UNUSED(f32acc);
// Needs to be kept up to date on shader changes
const uint32_t wg_size = params.workgroup_size;
@@ -8789,21 +8814,51 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
const bool mmq = device->integer_dot_product && device->subgroup_clustered &&
(kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 ||
kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 ||
kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL);
// tmpsh is overestimated slightly
const uint32_t tmpsh = wg_size * sizeof(float);
const uint32_t tmpshv4 = wg_size * 4 * float_type_size;
const uint32_t masksh = Bc * (Br + 1) * float_type_size;
const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
uint32_t Qf, kvsh, kblocksh_size;
if (mmq) {
// block_b_cache: int32_t qs[8] + FLOAT_TYPEV2 ds
const uint32_t block_b_size = 8 * sizeof(int32_t) + 2 * float_type_size;
Qf = Br * (hsk / 32) * block_b_size;
const uint32_t D = std::max(hsk, hsv);
const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
// kvsh uses D = HSV (K goes through kblocksh instead)
kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
// block_a_cache size depends on quant type
uint32_t block_a_size;
switch (kv_type) {
case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break;
case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break;
case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break;
case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break;
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break;
default: block_a_size = 0; break;
}
kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size;
} else {
Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
const uint32_t D = std::max(hsk, hsv);
kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
kblocksh_size = 0;
}
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh + kblocksh_size;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", mmq=" << mmq << ", total_size=" << total_size << ", supported=" << supported);
return supported;
}

View File

@@ -10,6 +10,13 @@
#extension GL_EXT_shader_subgroup_extended_types_float16 : require
#endif
#ifdef MMQ
#extension GL_EXT_integer_dot_product : require
#extension GL_KHR_shader_subgroup_clustered : require
#include "mul_mmq_shmem_types.glsl"
#endif
#extension GL_KHR_shader_subgroup_shuffle : enable
#extension GL_KHR_shader_subgroup_vote : enable
@@ -41,15 +48,34 @@ shared FLOAT_TYPEV4 tmpshv4[tmpsh_size];
const uint32_t masksh_stride = Br + 1;
shared FLOAT_TYPE masksh[Bc * masksh_stride];
#ifndef MMQ
const uint32_t qf_stride = HSK / 4 + 1;
shared FLOAT_TYPEV4 Qf[Br * qf_stride];
#else
const uint32_t qf_stride = HSK / 32;
shared block_b_cache Qf[Br * qf_stride];
#endif
#ifndef MMQ
const uint32_t D = HSK > HSV ? HSK : HSV;
#else
const uint32_t D = HSV;
#endif
const uint32_t kvsh_stride = D / 4 + 1;
shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];
#ifdef MMQ
shared block_a_cache kblocksh[SHMEM_STAGING != 0 ? Bc * qf_stride : 1];
#endif
shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1];
#ifdef MMQ
#include "flash_attn_mmq_funcs.glsl"
#endif
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
@@ -82,10 +108,39 @@ void main() {
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < HSK / 4 &&
i * Br + r < N) {
const bool is_in_bounds = r < Br && d < HSK / 4 && i * Br + r < N;
#ifndef MMQ
if (is_in_bounds) {
Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
}
#else
const uint buf_ib = r * qf_stride + d / 8;
const uint buf_iqs = d % 8;
FLOAT_TYPEV4 vals = is_in_bounds ? FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale) : FLOAT_TYPEV4(0.0f);
const FLOAT_TYPEV4 abs_vals = abs(vals);
const FLOAT_TYPE thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
const FLOAT_TYPE amax = subgroupClusteredMax(thread_max, 8);
const FLOAT_TYPE qd = amax / FLOAT_TYPE(127.0);
const FLOAT_TYPE qd_inv = qd != FLOAT_TYPE(0.0) ? FLOAT_TYPE(1.0) / qd : FLOAT_TYPE(0.0);
vals = round(vals * qd_inv);
Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals));
#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
if (buf_iqs == 0) {
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
}
#else // Q4_0, Q4_1, Q5_0, Q5_1
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
if (buf_iqs == 0) {
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
}
#endif
#endif
}
barrier();
@@ -195,6 +250,7 @@ void main() {
if (SHMEM_STAGING != 0) {
barrier();
#ifndef MMQ
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
@@ -214,9 +270,29 @@ void main() {
kvsh[c * kvsh_stride + d] = K_Tf;
}
}
#else // MMQ
const uint ints_per_block = 8 / QUANT_R_MMQ;
const uint quant_iters = Bc * HSK / 32 * ints_per_block;
[[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) {
const uint32_t iqs = (idx + tid) % ints_per_block;
const uint32_t ib = (idx + tid) / ints_per_block;
const uint32_t c = ib / (HSK / 32);
const uint32_t block = ib % (HSK / 32);
if (idx + gl_WorkGroupSize.x <= quant_iters || c < Bc) {
const uint buf_ib = c * qf_stride + block;
if (!KV_bounds_check || j * Bc + c < KV) {
const uint global_ib = (j * Bc + c) * k_stride + block;
k_block_to_shmem(buf_ib, global_ib, iqs, k_offset);
} else {
k_block_to_shmem_zero(buf_ib, iqs);
}
}
}
#endif // MMQ
barrier();
}
#ifndef MMQ
// More d iterations means Q register caching becomes relevant
// Few iterations means the additional registers needed are worse than the speed-up from caching
if (HSK_per_thread / 4 > 4) {
@@ -275,6 +351,110 @@ void main() {
}
}
}
#else // MMQ
const uint hsk4 = HSK_per_thread / 4;
const uint d_per_step = (hsk4 % 8 == 0) ? 8 :
(hsk4 % 4 == 0) ? 4 :
(hsk4 % 2 == 0) ? 2 : 1;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
[[unroll]] for (uint32_t d_block = 0; d_block < HSK_per_thread / 4; d_block += d_per_step) {
int32_t k_quants[d_per_step];
ACC_TYPEV2 k_dm;
if (SHMEM_STAGING != 0) {
const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8;
const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx;
#if QUANT_AUXF == 1
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0);
#else
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm);
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
if (d_per_step == 8) {
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
uint vui = kblocksh[buf_ib].qs[d];
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
#endif
}
} else
#endif
{
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d);
}
}
} else {
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
const uint ib = coord / BLOCK_SIZE;
const uint iqs = (coord % BLOCK_SIZE);
#if QUANT_AUXF == 1
k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0);
#else
k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset));
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
if (d_per_step == 8) {
#if defined(DATA_A_Q5_0)
uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0],
k_packed.k_data_packed16[k_offset + ib].qh[1]));
#elif defined(DATA_A_Q5_1)
uint qh = k_packed.k_data_packed16[k_offset + ib].qh;
#endif
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
#if defined(A_TYPE_PACKED32)
uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d];
#else
uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0],
k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1]));
#endif
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
uint qh_lo = (qh >> (d * 4)) & 0xF;
uint qh_hi = (qh >> (d * 4 + 16)) & 0xF;
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
#endif
}
} else
#endif
{
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset);
}
}
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint qib = tile_row(r) * qf_stride + (d_tid * (HSK_per_thread / 4) + d_block) / 8;
const uint qiqs = (d_tid * (HSK_per_thread / 4) + d_block) % 8;
int32_t acc = 0;
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
acc += dotPacked4x8EXT(Qf[qib].qs[qiqs + d], k_quants[d]);
}
Sf[r][c] += ACC_TYPE(acc) * ACC_TYPE(Qf[qib].ds.x) * k_dm.x;
if ((d_tid * (HSK_per_thread / 4) + d_block) % 8 == 0) {
Sf[r][c] += k_dot_correction(qib, k_dm);
}
}
}
}
#endif // MMQ
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
// Compute sum across the D_split

View File

@@ -89,6 +89,11 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32;
layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32;
#endif
#ifndef BLOCK_SIZE
#define BLOCK_SIZE 1
#endif

View File

@@ -0,0 +1,149 @@
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
// Fetch 4 consecutive 4-bit K-cache quants starting at element `iqs` of block
// `ib` (plus buffer offset `a_offset`), one value per byte of the returned
// int32, ready for a packed 8-bit integer dot product.
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
#ifdef DATA_A_Q4_0
    // Q4_0 blocks are only 16-bit aligned: assemble the 32-bit word from two
    // consecutive 16-bit loads.
    uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
                              k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
#else
    uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
#endif
    // Elements 0..15 live in the low nibbles, elements 16..31 in the high
    // nibbles of the same bytes: bit 4 of iqs selects a shift of 0 or 4.
    uint shift = (iqs & 0x10) >> 2;
    vui >>= shift;
    return int32_t(vui & 0x0F0F0F0F);
}
#endif
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
// Fetch 4 consecutive 5-bit K-cache quants: the low 4 bits per element come
// from qs (same layout as Q4), the 5th bit of each element comes from the
// 32-bit qh bitfield.
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
#ifdef DATA_A_Q5_0
    // Q5_0 blocks are only 16-bit aligned: assemble 32-bit words from pairs
    // of 16-bit loads.
    uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
                              k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
    uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0],
                             k_packed.k_data_packed16[a_offset + ib].qh[1]));
#else
    uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
    uint qh = k_packed.k_data_packed16[a_offset + ib].qh;
#endif
    // Low/high nibble selection, as in the Q4 variant.
    uint shift = (iqs & 0x10) >> 2;
    vui >>= shift;
    // Bits iqs..iqs+3 of qh are the high bits of these 4 elements. The
    // multiply by 0x02040810 spreads bit i of qh_bits into bit 4 of byte i
    // (mask 0x10101010 keeps exactly that bit per byte).
    uint qh_bits = (qh >> iqs) & 0xF;
    return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
}
#endif
#if defined(DATA_A_Q8_0)
// Fetch 4 consecutive signed 8-bit K-cache quants as one packed int32
// (two adjacent 16-bit loads; Q8_0 data is only 16-bit aligned).
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
    return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1]));
}
#endif
#if defined(DATA_A_IQ4_NL)
// Fetch 4 consecutive IQ4_NL K-cache quants. The stored nibbles are indices
// into the non-linear kvalues_iq4nl_const table; the looked-up signed 8-bit
// values are returned packed into one int32.
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
    uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
                              k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
    // Bit 4 of iqs selects the low or high nibble of each byte.
    uint shift = (iqs & 0x10) >> 2;
    vui >>= shift;
    u8vec4 idx = unpack8(vui & 0x0F0F0F0F);
    return pack32(i8vec4(kvalues_iq4nl_const[idx.x],
                         kvalues_iq4nl_const[idx.y],
                         kvalues_iq4nl_const[idx.z],
                         kvalues_iq4nl_const[idx.w]));
}
#endif
#if QUANT_AUXF == 1
// Scale-only quant formats (single auxiliary float, e.g. IQ4_NL — see the
// QUANT_AUXF 1 define): return the block scale d.
FLOAT_TYPE get_k_d(uint ib, uint a_offset) {
    return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d);
}
#else
// Scale+offset formats: return the (d, m) pair of the block.
FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) {
    return FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + ib].dm);
}
#endif
// Stage one K-cache block from global memory into the shared-memory cache
// kblocksh. Each invocation copies the 32-bit chunk selected by `iqs` of
// global block `global_ib` into shmem slot `buf_ib`; the iqs == 0 invocation
// additionally stores the block's auxiliary scale(s).
void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) {
#if defined(DATA_A_Q4_0)
    // 16-bit aligned data: build 32-bit chunks from pairs of 16-bit loads.
    kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
                                              k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
#elif defined(DATA_A_Q4_1)
    kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
#elif defined(DATA_A_Q5_0)
    kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
                                              k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
    // The qh bitfield is shared by the whole block; copy it once.
    if (iqs == 0) {
        kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0],
                                             k_packed.k_data_packed16[a_offset + global_ib].qh[1]));
    }
#elif defined(DATA_A_Q5_1)
    kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
    if (iqs == 0) {
        kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh;
    }
#elif defined(DATA_A_Q8_0)
    kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
                                              k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
#elif defined(DATA_A_IQ4_NL)
    // IQ4_NL is dequantized to signed 8-bit values at staging time: each
    // 32-bit chunk of nibble indices expands to two chunks of table values
    // (low nibbles -> qs[iqs], high nibbles -> qs[iqs + 4]).
    const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
                                   k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
    const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
    const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
    kblocksh[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y],
                                              kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w]));
    kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y],
                                                 kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w]));
#endif
    // One invocation per block stores the auxiliary scale (d) or
    // scale/offset pair (dm), depending on QUANT_AUXF.
    if (iqs == 0) {
#if QUANT_AUXF == 1
        kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d);
#else
        kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm);
#endif
    }
}
// Read 4 packed K quants from the shared-memory block cache. `pos` is the
// index of the 4-element group within the block (8 groups per 32 elements).
int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) {
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
    // Groups 0..3 come from the low nibbles of words 0..3, groups 4..7 from
    // the high nibbles of the same words.
    uint sub = pos % 4;
    uint shift = ((pos % 8) >= 4) ? 4 : 0;
    return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
    uint sub = pos % 4;
    uint shift = ((pos % 8) >= 4) ? 4 : 0;
    int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
    // Merge in the 5th bits: bits pos*4..pos*4+3 of qh, spread into bit 4 of
    // each byte via the 0x02040810 multiply trick.
    uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF;
    return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
    // These formats are stored in shmem as ready-to-use packed 8-bit values.
    return kblocksh[buf_ib].qs[pos];
#endif
}
// Correction term added to the integer dot product to account for the
// unsigned storage of the K quants: Q4_0/Q5_0 values are stored with a +8
// resp. +16 bias, Q4_1/Q5_1 carry an additive offset k_dm.y.
// NOTE(review): assumes Qf[qib].ds.y holds the scaled sum of the quantized Q
// values for this group (Q8_1-style) — confirm against the Q staging code.
ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) {
#if defined(DATA_A_Q4_0)
    return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
#elif defined(DATA_A_Q5_0)
    return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
    return ACC_TYPE(Qf[qib].ds.y) * k_dm.y;
#else
    // Q8_0 / IQ4_NL quants are signed with no offset: no correction needed.
    return ACC_TYPE(0.0);
#endif
}
// Zero-fill one shared-memory block slot (used for out-of-range K blocks so
// the dot product contributes nothing). Mirrors the chunk/scale split of
// k_block_to_shmem: each invocation clears its `iqs` chunk, iqs == 0 also
// clears the scale(s).
void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) {
    kblocksh[buf_ib].qs[iqs] = 0;
#if defined(DATA_A_IQ4_NL)
    // IQ4_NL staging writes two chunks per iqs (see k_block_to_shmem).
    kblocksh[buf_ib].qs[iqs + 4] = 0;
#endif
    if (iqs == 0) {
#if QUANT_AUXF == 1
        kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f);
#else
        kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f);
#endif
    }
}

View File

@@ -32,6 +32,12 @@ struct block_a_cache {
int32_t qs[32/4];
FLOAT_TYPE dm;
};
#elif defined(DATA_A_IQ4_NL)
#define QUANT_R_MMQ 2
struct block_a_cache {
int32_t qs[8];
FLOAT_TYPE dm;
};
#elif defined(DATA_A_MXFP4)
#define QUANT_R_MMQ 2
struct block_a_cache {

View File

@@ -1692,6 +1692,7 @@ struct block_iq4_nl_packed16
#if defined(DATA_A_IQ4_NL)
#define QUANT_K QUANT_K_IQ4_NL
#define QUANT_R QUANT_R_IQ4_NL
#define QUANT_AUXF 1
#define A_TYPE block_iq4_nl
#define A_TYPE_PACKED16 block_iq4_nl_packed16
#endif

View File

@@ -406,8 +406,8 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
}
static std::vector<std::future<void>> compiles;
void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false, const std::string& suffix = "") {
name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")) + suffix;
std::string out_path = join_paths(output_dir, name + ".spv");
if (input_filepath == "") {
@@ -625,15 +625,16 @@ void process_shaders() {
for (const bool& fp16 : {false, true}) {
std::map<std::string, std::string> base_dict;
if (fp16) {
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV2", "f16vec2"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
} else {
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}};
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"FLOAT_TYPEV4", "vec4"}};
}
// flash attention
for (const bool& f16acc : {false, true}) {
std::map<std::string, std::string> fa_base_dict = base_dict;
fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float";
fa_base_dict["ACC_TYPEV2"] = fp16 && f16acc ? "f16vec2" : "vec2";
fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4";
if (fp16 && f16acc) {
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
@@ -672,6 +673,12 @@ void process_shaders() {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (tname != "f32") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8");
}
#endif
}
}
}

View File

@@ -0,0 +1,141 @@
{#- DeepSeek-V3.2 chat template. Tool calls use DSML markup
    (<DSMLfunction_calls> / <DSMLinvoke> / <DSMLparameter>); reasoning uses
    <think>...</think>. The generation prompt ends in <think> (thinking mode)
    or <think></think> (non-thinking mode). All tags use whitespace control,
    so these comment lines emit nothing. -#}
{%- if not add_generation_prompt is defined -%}
    {%- set add_generation_prompt = false -%}
{%- endif -%}
{#- "thinking" falls back to the legacy "enable_thinking" flag. -#}
{%- if not thinking is defined -%}
    {%- if enable_thinking is defined -%}
        {%- set thinking = enable_thinking -%}
    {%- else -%}
        {%- set thinking = false -%}
    {%- endif -%}
{%- endif -%}
{#- Special tokens and the tool-use instruction header/footer. -#}
{%- set dsml_token = 'DSML' -%}
{%- set thinking_start_token = '<think>' -%}
{%- set thinking_end_token = '</think>' -%}
{%- set tools_header = '## Tools\n\nYou have access to a set of tools you can use to answer the user\'s question.\nYou can invoke functions by writing a "<' + dsml_token + 'function_calls>" block like the following as part of your reply to the user:\n<' + dsml_token + 'function_calls>\n<' + dsml_token + 'invoke name="$FUNCTION_NAME">\n<' + dsml_token + 'parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</' + dsml_token + 'parameter>\n...\n</' + dsml_token + 'invoke>\n<' + dsml_token + 'invoke name="$FUNCTION_NAME2">\n...\n</' + dsml_token + 'invoke>\n</' + dsml_token + 'function_calls>\n\nString and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects).\n\nIf the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example:\n\n<' + dsml_token + 'function_calls>\n...\n</' + dsml_token + 'function_calls>\n\n<function_results>\n...\n</function_results>\n\n' + thinking_start_token + '...thinking about results' + thinking_end_token + '\n\nHere are the functions available in JSONSchema format:\n<functions>\n' -%}
{%- set tools_footer = '</functions>\n' -%}
{#- Merge every system message into a single system prompt. -#}
{%- set ns = namespace(system_prompt='', is_first_sp=true) -%}
{%- for message in messages -%}
    {%- if message['role'] == 'system' -%}
        {%- if ns.is_first_sp -%}
            {%- set ns.system_prompt = ns.system_prompt + (message['content'] or '') -%}
            {%- set ns.is_first_sp = false -%}
        {%- else -%}
            {%- set ns.system_prompt = ns.system_prompt + '\n\n' + (message['content'] or '') -%}
        {%- endif -%}
    {%- endif -%}
{%- endfor -%}
{#- Append the JSON schema of each function tool to the system prompt. -#}
{%- if tools is defined and tools -%}
    {%- set ts = namespace(schemas='') -%}
    {%- for tool in tools -%}
        {%- if tool['type'] == 'function' -%}
            {%- set ts.schemas = ts.schemas + (tool['function'] | tojson) + '\n' -%}
        {%- endif -%}
    {%- endfor -%}
    {%- if ns.system_prompt -%}
        {%- set ns.system_prompt = ns.system_prompt + '\n\n' + tools_header + ts.schemas + tools_footer -%}
    {%- else -%}
        {%- set ns.system_prompt = tools_header + ts.schemas + tools_footer -%}
    {%- endif -%}
{%- endif -%}
{{- bos_token -}}
{{- ns.system_prompt -}}
{#- Index of the last user/developer turn: only assistant turns after it keep
    their reasoning content. -#}
{%- set last_user_idx = namespace(value=-1) -%}
{%- for message in messages -%}
    {%- if message['role'] == 'user' or message['role'] == 'developer' -%}
        {%- set last_user_idx.value = loop.index0 -%}
    {%- endif -%}
{%- endfor -%}
{#- pending_asst_marker: the next assistant turn must open with <Assistant>;
    pending_tool_marker: the next assistant turn continues after tool results. -#}
{%- set state = namespace(pending_asst_marker=false, pending_tool_marker=false) -%}
{%- for message in messages -%}
    {%- if message['role'] == 'user' -%}
        {{- '<User>' + (message['content'] or '') -}}
        {%- set state.pending_asst_marker = true -%}
        {%- set state.pending_tool_marker = false -%}
    {%- elif message['role'] == 'assistant' -%}
        {%- set is_after_last_user = loop.index0 > last_user_idx.value -%}
        {%- if state.pending_asst_marker -%}
            {{- '<Assistant>' -}}
            {%- if is_after_last_user and thinking -%}
                {{- thinking_start_token -}}
                {%- if message['reasoning_content'] is defined and message['reasoning_content'] -%}
                    {{- message['reasoning_content'] -}}
                {%- endif -%}
                {{- thinking_end_token -}}
            {%- else -%}
                {{- thinking_end_token -}}
            {%- endif -%}
        {%- elif state.pending_tool_marker -%}
            {%- if is_after_last_user and thinking -%}
                {{- '\n\n' + thinking_start_token -}}
                {%- if message['reasoning_content'] is defined and message['reasoning_content'] -%}
                    {{- message['reasoning_content'] -}}
                {%- endif -%}
                {{- thinking_end_token -}}
            {%- else -%}
                {{- '\n\n' + thinking_end_token -}}
            {%- endif -%}
        {%- endif -%}
        {%- set state.pending_asst_marker = false -%}
        {%- set state.pending_tool_marker = false -%}
        {%- if message['content'] is defined and message['content'] -%}
            {{- message['content'] -}}
        {%- endif -%}
        {#- Render tool calls in DSML markup: string params verbatim, other
            types serialized as JSON with string="false". -#}
        {%- if message['tool_calls'] -%}
            {{- '\n\n<' + dsml_token + 'function_calls>\n' -}}
            {%- for tool in message['tool_calls'] -%}
                {%- set func = tool['function'] -%}
                {{- '<' + dsml_token + 'invoke name="' + func['name'] + '">\n' -}}
                {%- set args = func['arguments'] -%}
                {%- if args is string -%}
                    {%- set args = args | from_json -%}
                {%- endif -%}
                {%- for key, val in args.items() -%}
                    {%- if val is string -%}
                        {{- '<' + dsml_token + 'parameter name="' + key + '" string="true">' + val + '</' + dsml_token + 'parameter>\n' -}}
                    {%- else -%}
                        {{- '<' + dsml_token + 'parameter name="' + key + '" string="false">' + (val | tojson) + '</' + dsml_token + 'parameter>\n' -}}
                    {%- endif -%}
                {%- endfor -%}
                {{- '</' + dsml_token + 'invoke>\n' -}}
            {%- endfor -%}
            {{- '</' + dsml_token + 'function_calls>' -}}
        {%- endif -%}
        {{- '<end▁of▁sentence>' -}}
    {%- elif message['role'] == 'tool' -%}
        {#- Locate the assistant turn that issued this tool call so that
            consecutive results are grouped inside one <function_results>. -#}
        {%- set outer_index = loop.index0 -%}
        {%- set assistant_idx = namespace(value=-1) -%}
        {%- for prev_msg in messages -%}
            {%- if prev_msg['role'] == 'assistant' and prev_msg['tool_calls'] and loop.index0 < outer_index -%}
                {%- set assistant_idx.value = loop.index0 -%}
            {%- endif -%}
        {%- endfor -%}
        {%- set call_order = outer_index - assistant_idx.value -%}
        {%- set assistant_msg = messages[assistant_idx.value] -%}
        {%- set tool_call_count = assistant_msg['tool_calls'] | length -%}
        {%- if call_order == 1 -%}
            {{- '\n\n<function_results>' -}}
        {%- endif -%}
        {{- '\n<result>' + (message['content'] or '') + '</result>' -}}
        {%- if call_order == tool_call_count -%}
            {{- '\n</function_results>' -}}
            {%- set state.pending_asst_marker = false -%}
            {%- set state.pending_tool_marker = true -%}
        {%- endif -%}
    {%- endif -%}
{%- endfor -%}
{#- Generation prompt: open the assistant turn, pre-opening <think> when
    thinking is enabled, otherwise emitting an empty think block. -#}
{%- if add_generation_prompt -%}
    {%- if state.pending_asst_marker -%}
        {{- '<Assistant>' -}}
        {%- if thinking -%}
            {{- thinking_start_token -}}
        {%- else -%}
            {{- thinking_start_token + thinking_end_token -}}
        {%- endif -%}
    {%- elif state.pending_tool_marker -%}
        {%- if thinking -%}
            {{- '\n\n' + thinking_start_token -}}
        {%- else -%}
            {{- '\n\n' + thinking_start_token + thinking_end_token -}}
        {%- endif -%}
    {%- endif -%}
{%- endif -%}

View File

@@ -258,6 +258,66 @@ void test_gbnf_generation(testing &t) {
)""", gbnf);
});
t.test("silent parser emits nothing in gbnf", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.literal("hello") + p.gbnf(p.literal("world"), "");
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= "hello"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("silent choice inside sequence emits nothing", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.literal("a") + p.gbnf(p.literal("b") | p.literal("c"), "") + p.literal("d");
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= "a" "d"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("silent wrapped in tag emits nothing", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.literal("a") + p.tag("t", p.gbnf(p.literal("b"), ""));
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= "a"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("gbnf parser emits custom grammar", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.literal("a") + p.gbnf(p.literal("b"), "[a-z]+");
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= "a" [a-z]+
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("nested transparent wrappers get parenthesized", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.literal("x") + p.tag("outer", p.atomic(p.literal("a") | p.literal("b")));

View File

@@ -8580,7 +8580,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (int nb : { 1, 3, 32, 75, }) {
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0, GGML_TYPE_IQ4_NL}) {
if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72) continue;
test_cases.emplace_back(new test_flash_attn_ext(
hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));

View File

@@ -2118,6 +2118,31 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.tools({ amount_tool })
.expect(message_with_tool_calls("amount", R"({"orig": 1.5e10})"))
.run();
// Edge cases
tst.test(
"<|channel>thought\n<channel|>Hello, world!\nWhat's up?<channel|>")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.expect(message_assist)
.run();
tst.test(
"<|channel>thought\n<channel|>Hello, world!\nWhat's up?<|channel>thought\n<channel|>")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.expect(message_assist)
.run();
tst.test(
"<|channel>thought\n<channel|>Hello, world!\nWhat's up?<|channel>thought\n<channel|><channel|>")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.expect(message_assist)
.run();
tst.test(
"<|channel><|channel>thought\n<channel|>Hello, world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.expect(message_assist)
.run();
}
{
@@ -2576,6 +2601,215 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
expect(simple_assist_msg("CONTENT", "")).run();
}
// DeepSeek V3.2 tests - format uses DSML markup:
// <DSMLfunction_calls>
// <DSMLinvoke name="foo">
// <DSMLparameter name="bar" string="true|false">value</DSMLparameter>
// </DSMLinvoke>
// </DSMLfunction_calls>
// Reasoning uses <think>...</think>. The generation prompt ends in <think> (thinking mode)
// or <think></think> (non-thinking mode).
{
auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-V3.2.jinja", detailed_debug);
// Pure content (non-thinking mode)
tst.test("Hello, world!\nWhat's up?")
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.expect(message_assist)
.run();
// Thinking + content
tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.expect(message_assist_thoughts)
.run();
// Thinking + tool call (single, string param)
tst.test(
"Let me check the time</think>\n\n"
"<DSMLfunction_calls>\n"
"<DSMLinvoke name=\"get_time\">\n"
"<DSMLparameter name=\"city\" string=\"true\">Tokyo</DSMLparameter>\n"
"</DSMLinvoke>\n"
"</DSMLfunction_calls>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.tools({ get_time_tool })
.expect(message_with_tool_calls_and_reasoning("get_time", R"({"city": "Tokyo"})", "Let me check the time"))
.run();
// Tool call without reasoning (non-thinking mode), integer param (string="false")
tst.test(
"<DSMLfunction_calls>\n"
"<DSMLinvoke name=\"special_function\">\n"
"<DSMLparameter name=\"arg1\" string=\"false\">1</DSMLparameter>\n"
"</DSMLinvoke>\n"
"</DSMLfunction_calls>")
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.tools({ special_function_tool })
.expect(message_assist_call)
.run();
// Multiple parallel tool calls with reasoning
tst.test(
"Calling both</think>\n\n"
"<DSMLfunction_calls>\n"
"<DSMLinvoke name=\"get_time\">\n"
"<DSMLparameter name=\"city\" string=\"true\">Paris</DSMLparameter>\n"
"</DSMLinvoke>\n"
"<DSMLinvoke name=\"get_weather\">\n"
"<DSMLparameter name=\"city\" string=\"true\">Paris</DSMLparameter>\n"
"</DSMLinvoke>\n"
"</DSMLfunction_calls>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.parallel_tool_calls(true)
.tools({ get_time_tool, get_weather_tool })
.expect(message_with_reasoning_content_and_multiple_tool_calls(
"Calling both", "",
{ { "get_time", R"({"city": "Paris"})" }, { "get_weather", R"({"city": "Paris"})" } }))
.run();
// Tool call with content before tool calls
tst.test(
"Thinking about it</think>"
"Let me call the function.\n\n"
"<DSMLfunction_calls>\n"
"<DSMLinvoke name=\"special_function\">\n"
"<DSMLparameter name=\"arg1\" string=\"false\">1</DSMLparameter>\n"
"</DSMLinvoke>\n"
"</DSMLfunction_calls>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.tools({ special_function_tool })
.expect_reasoning("Thinking about it")
.expect_content("Let me call the function.")
.expect_tool_calls({
{ "special_function", R"({"arg1": 1})", {} },
})
.run();
// Tool call with negative number
tst.test(
"Test negative</think>\n\n"
"<DSMLfunction_calls>\n"
"<DSMLinvoke name=\"magic_int\">\n"
"<DSMLparameter name=\"ref\" string=\"false\">-14</DSMLparameter>\n"
"</DSMLinvoke>\n"
"</DSMLfunction_calls>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.tools({ magic_int_tool })
.expect_reasoning("Test negative")
.expect_tool_calls({
{ "magic_int", R"({"ref": -14})", {} },
})
.run();
// Tool call with decimal number
tst.test(
"Test decimal</think>\n\n"
"<DSMLfunction_calls>\n"
"<DSMLinvoke name=\"amount\">\n"
"<DSMLparameter name=\"orig\" string=\"false\">3.14</DSMLparameter>\n"
"</DSMLinvoke>\n"
"</DSMLfunction_calls>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.tools({ amount_tool })
.expect_reasoning("Test decimal")
.expect_tool_calls({
{ "amount", R"({"orig": 3.14})", {} },
})
.run();
// Tool call with boolean
tst.test(
"Test boolean</think>\n\n"
"<DSMLfunction_calls>\n"
"<DSMLinvoke name=\"toggle\">\n"
"<DSMLparameter name=\"enabled\" string=\"false\">true</DSMLparameter>\n"
"</DSMLinvoke>\n"
"</DSMLfunction_calls>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.tools({ toggle_tool })
.expect_reasoning("Test boolean")
.expect_tool_calls({
{ "toggle", R"({"enabled": true})", {} },
})
.run();
// Tool call with array parameter (JSON-formatted)
tst.test(
"Test array</think>\n\n"
"<DSMLfunction_calls>\n"
"<DSMLinvoke name=\"todo_list\">\n"
"<DSMLparameter name=\"todos\" string=\"false\">[\"buy milk\",\"walk dog\"]</DSMLparameter>\n"
"</DSMLinvoke>\n"
"</DSMLfunction_calls>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.tools({ todo_list })
.expect_reasoning("Test array")
.expect_tool_calls({
{ "todo_list", R"({"todos": ["buy milk", "walk dog"]})", {} },
})
.run();
// Tool call with object parameter (JSON-formatted)
tst.test(
"Test object</think>\n\n"
"<DSMLfunction_calls>\n"
"<DSMLinvoke name=\"set_config\">\n"
"<DSMLparameter name=\"config\" string=\"false\">{\"theme\":\"dark\",\"level\":2}</DSMLparameter>\n"
"</DSMLinvoke>\n"
"</DSMLfunction_calls>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.tools({ config_tool })
.expect_reasoning("Test object")
.expect_tool_calls({
{ "set_config", R"({"config": {"theme": "dark", "level": 2}})", {} },
})
.run();
// Edge case: empty reasoning
tst.test(
"</think>\n\n"
"<DSMLfunction_calls>\n"
"<DSMLinvoke name=\"get_time\">\n"
"<DSMLparameter name=\"city\" string=\"true\">XYZCITY</DSMLparameter>\n"
"</DSMLinvoke>\n"
"</DSMLfunction_calls>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.tools({ get_time_tool })
.expect(message_with_tool_calls("get_time", R"({"city": "XYZCITY"})"))
.run();
// Edge case: tool call with multiple params (mixed types, string first)
tst.test(
"Multi-arg call</think>\n\n"
"<DSMLfunction_calls>\n"
"<DSMLinvoke name=\"magic_int\">\n"
"<DSMLparameter name=\"ref\" string=\"false\">42</DSMLparameter>\n"
"<DSMLparameter name=\"name\" string=\"true\">foo bar</DSMLparameter>\n"
"</DSMLinvoke>\n"
"</DSMLfunction_calls>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.tools({ magic_int_tool })
.expect_reasoning("Multi-arg call")
.expect_tool_calls({
{ "magic_int", R"({"ref": 42, "name": "foo bar"})", {} },
})
.run();
}
// GLM-4.6 tests - format: <tool_call>function_name\n<arg_key>...</arg_key>\n<arg_value>...</arg_value>\n</tool_call>
{
auto tst = peg_tester("models/templates/GLM-4.6.jinja", detailed_debug);

View File

@@ -98,6 +98,7 @@ add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
add_test_audio "ggml-org/Voxtral-Mini-3B-2507-GGUF:Q4_K_M"
add_test_audio "ggml-org/LFM2-Audio-1.5B-GGUF:Q8_0"
add_test_audio "ggml-org/gemma-4-E2B-it-GGUF:Q8_0" --jinja
add_test_audio "ggml-org/Qwen3-ASR-0.6B-GGUF:Q8_0"
# to test the big models, run: ./tests.sh big
if [ "$RUN_BIG_TESTS" = true ]; then

View File

@@ -926,7 +926,8 @@ void server_models_routes::init_routes() {
res_ok(res, {
// TODO: add support for this on web UI
{"role", "router"},
{"max_instances", 4}, // dummy value for testing
{"max_instances", params.models_max},
{"models_autoload", params.models_autoload},
// this is a dummy response to make sure webui doesn't break
{"model_alias", "llama-server"},
{"model_path", "none"},
@@ -935,6 +936,7 @@ void server_models_routes::init_routes() {
{"n_ctx", 0},
}},
{"webui_settings", webui_settings},
{"build_info", build_info},
});
return res;
}

View File

@@ -9,6 +9,19 @@ def create_server():
server = ServerPreset.router()
def test_router_props():
    """Router-mode /props must expose the router role, the configured
    instance limit, the autoload flag, and the server build info."""
    global server
    # Configure the preset before start() so /props reflects these values.
    server.models_max = 2
    server.no_models_autoload = True
    server.start()
    res = server.make_request("GET", "/props")
    assert res.status_code == 200
    assert res.body["role"] == "router"
    # max_instances mirrors --models-max (no longer a hard-coded dummy).
    assert res.body["max_instances"] == 2
    assert res.body["models_autoload"] is False
    # build_info presumably looks like "b<number>-<sha>"; only the tag
    # prefix is checked here.
    assert res.body["build_info"].startswith("b")
@pytest.mark.parametrize(
"model,success",
[