mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-16 16:27:32 +03:00
Compare commits
6 Commits
mtmd-video
...
b8785
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6a6780a232 | ||
|
|
e489a5ca0e | ||
|
|
e21cdc11a0 | ||
|
|
e974923698 | ||
|
|
1c0d9081fd | ||
|
|
a8bad3842e |
2
.github/workflows/close-issue.yml
vendored
2
.github/workflows/close-issue.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap"
|
||||
exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap,security"
|
||||
days-before-issue-stale: 30
|
||||
days-before-issue-close: 14
|
||||
stale-issue-label: "stale"
|
||||
|
||||
197
common/chat.cpp
197
common/chat.cpp
@@ -1091,6 +1091,14 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
|
||||
if (inputs.add_generation_prompt && string_ends_with(data.prompt, "<turn|>\n")) {
|
||||
// This may happen if the model generates content + tool_call, the
|
||||
// template does not add the model's next turn and confuses the model
|
||||
// from emitting its proper reasoning token sequence.
|
||||
data.prompt += "<|turn>model\n";
|
||||
}
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "<|channel>thought";
|
||||
@@ -1118,7 +1126,8 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
p.rule("thought", p.content(p.literal("<|channel>thought") + p.space() + p.until("<channel|>") + p.literal("<channel|>")));
|
||||
}
|
||||
|
||||
auto thought = (p.peek(p.literal("<|channel>")) + p.ref("thought")) | p.negate(p.literal("<|channel>"));
|
||||
auto consume_empty_channels = p.gbnf(p.zero_or_more(p.literal("<|channel>") + p.negate(p.literal("thought"))), "");
|
||||
auto thought = (p.peek(p.literal("<|channel>")) + consume_empty_channels + p.ref("thought")) | p.negate(p.literal("<|channel>"));
|
||||
|
||||
if (has_response_format) {
|
||||
auto response_format = p.literal("```json") <<
|
||||
@@ -1182,12 +1191,16 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
/* max = */ inputs.parallel_tool_calls ? -1 : 1
|
||||
));
|
||||
|
||||
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<|tool_call>"})));
|
||||
auto scan_to_toolcall = p.rule("scan-to-toolcall", p.until("<|tool_call>"));
|
||||
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<channel|>", "<|tool_call>"})));
|
||||
auto message = p.rule("message", thought + content);
|
||||
return start + p.zero_or_more(message) + tool_call;
|
||||
return start + p.zero_or_more(message) + scan_to_toolcall + tool_call;
|
||||
}
|
||||
|
||||
auto content = p.rule("content", p.content(p.until("<|channel>")));
|
||||
// Gemma 4 may emit an extra <|channel>thought\n<channel|> at the end of the content. It may
|
||||
// also emit a single trailing <channel|> token. Consume all complete reasoning blocks and
|
||||
// then stop at the first unmatched <channel|> token.
|
||||
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<channel|>"})));
|
||||
auto message = p.rule("message", thought + content);
|
||||
return start + p.one_or_more(message);
|
||||
});
|
||||
@@ -1656,6 +1669,173 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_deepseek_v3_2(const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "<think>";
|
||||
data.thinking_end_tag = "</think>";
|
||||
data.preserved_tokens = {
|
||||
"|DSML|",
|
||||
"<think>",
|
||||
"</think>",
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
|
||||
|
||||
const std::string DSML = "|DSML|";
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
const std::string FC_START = "<" + DSML + "function_calls>";
|
||||
const std::string FC_END = "</" + DSML + "function_calls>";
|
||||
const std::string INVOKE_START = "<" + DSML + "invoke";
|
||||
const std::string INVOKE_END = "</" + DSML + "invoke>";
|
||||
const std::string PARAM_START = "<" + DSML + "parameter";
|
||||
const std::string PARAM_END = "</" + DSML + "parameter>";
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
|
||||
auto end = p.end();
|
||||
|
||||
auto reasoning = p.eps();
|
||||
if (extract_reasoning && inputs.enable_thinking) {
|
||||
reasoning = p.optional(THINK_START + p.reasoning(p.until(THINK_END)) + THINK_END);
|
||||
} else if (extract_reasoning) {
|
||||
// Thinking disabled but reasoning extraction requested: the generation prompt
|
||||
// contains an empty <think></think> pair that must still be consumed.
|
||||
reasoning = p.optional(p.literal(THINK_START) + p.until(THINK_END) + p.literal(THINK_END));
|
||||
}
|
||||
|
||||
if (has_response_format) {
|
||||
auto response_format = p.rule("response-format",
|
||||
p.literal("```json") + p.space() +
|
||||
p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)) +
|
||||
p.space() + p.literal("```"));
|
||||
return generation_prompt + reasoning + response_format + end;
|
||||
}
|
||||
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return generation_prompt + reasoning + p.content(p.rest()) + end;
|
||||
}
|
||||
|
||||
auto tool_choice = p.choice();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto params = function.contains("parameters") ? function.at("parameters") : json::object();
|
||||
const auto & props = params.contains("properties") ? params.at("properties") : json::object();
|
||||
|
||||
std::set<std::string> required;
|
||||
if (params.contains("required")) {
|
||||
params.at("required").get_to(required);
|
||||
}
|
||||
|
||||
auto schema_info = common_schema_info();
|
||||
schema_info.resolve_refs(params);
|
||||
|
||||
std::vector<common_peg_parser> required_parsers;
|
||||
std::vector<common_peg_parser> optional_parsers;
|
||||
for (const auto & [param_name, param_schema] : props.items()) {
|
||||
bool is_required = required.find(param_name) != required.end();
|
||||
bool is_string = schema_info.resolves_to_string(param_schema);
|
||||
|
||||
auto arg = p.tool_arg(
|
||||
p.tool_arg_open(
|
||||
p.literal(PARAM_START + " name=\"") +
|
||||
p.tool_arg_name(p.literal(param_name)) +
|
||||
p.literal("\" string=\"" + std::string(is_string ? "true" : "false") + "\">")) +
|
||||
(is_string
|
||||
? p.tool_arg_string_value(p.until(PARAM_END))
|
||||
: p.tool_arg_json_value(p.schema(p.json(),
|
||||
"tool-" + name + "-arg-" + param_name + "-schema",
|
||||
param_schema, false))) +
|
||||
p.tool_arg_close(p.literal(PARAM_END)));
|
||||
|
||||
auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg);
|
||||
if (is_required) {
|
||||
required_parsers.push_back(named_arg);
|
||||
} else {
|
||||
optional_parsers.push_back(named_arg);
|
||||
}
|
||||
}
|
||||
|
||||
common_peg_parser args_seq = p.eps();
|
||||
for (size_t i = 0; i < required_parsers.size(); i++) {
|
||||
if (i > 0) {
|
||||
args_seq = args_seq + p.space();
|
||||
}
|
||||
args_seq = args_seq + required_parsers[i];
|
||||
}
|
||||
|
||||
if (!optional_parsers.empty()) {
|
||||
common_peg_parser any_opt = p.choice();
|
||||
for (const auto & opt : optional_parsers) {
|
||||
any_opt |= opt;
|
||||
}
|
||||
args_seq = args_seq + p.repeat(p.space() + any_opt, 0, -1);
|
||||
}
|
||||
|
||||
common_peg_parser invoke_body = args_seq;
|
||||
auto func_parser = p.tool(
|
||||
p.tool_open(p.literal(INVOKE_START + " name=\"") +
|
||||
p.tool_name(p.literal(name)) + p.literal("\">\n")) +
|
||||
invoke_body + p.space() +
|
||||
p.tool_close(p.literal(INVOKE_END)));
|
||||
|
||||
tool_choice |= p.rule("tool-" + name, func_parser);
|
||||
});
|
||||
|
||||
auto require_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
common_peg_parser tool_calls = p.eps();
|
||||
if (inputs.parallel_tool_calls) {
|
||||
tool_calls = p.trigger_rule("tool-call",
|
||||
p.literal(FC_START) + p.space() + tool_choice +
|
||||
p.zero_or_more(p.space() + tool_choice) + p.space() + p.literal(FC_END));
|
||||
} else {
|
||||
tool_calls = p.trigger_rule("tool-call",
|
||||
p.literal(FC_START) + p.space() + tool_choice + p.space() + p.literal(FC_END));
|
||||
}
|
||||
|
||||
if (!require_tools) {
|
||||
tool_calls = p.optional(tool_calls);
|
||||
}
|
||||
|
||||
auto content_before_tools = p.content(p.until(FC_START));
|
||||
return generation_prompt + reasoning + content_before_tools + tool_calls + end;
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
auto schema = function.contains("parameters") ? function.at("parameters") : json::object();
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
if (has_response_format) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
}
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
data.grammar_triggers = {
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, FC_START },
|
||||
};
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
namespace workaround {
|
||||
|
||||
static void map_developer_role_to_system(json & messages) {
|
||||
@@ -1927,6 +2107,15 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
return common_chat_params_init_gigachat_v3(tmpl, params);
|
||||
}
|
||||
|
||||
// DeepSeek V3.2 format detection: template defines dsml_token and uses it for tool calls.
|
||||
// The template source contains the token as a variable assignment, not as a literal in markup.
|
||||
if (src.find("dsml_token") != std::string::npos &&
|
||||
src.find("function_calls") != std::string::npos &&
|
||||
src.find("DSML") != std::string::npos) {
|
||||
LOG_DBG("Using specialized template: DeepSeek V3.2\n");
|
||||
return common_chat_params_init_deepseek_v3_2(tmpl, params);
|
||||
}
|
||||
|
||||
// Gemma4 format detection
|
||||
if (src.find("'<|tool_call>call:'") != std::string::npos) {
|
||||
if (src.find("{#- OpenAI Chat Completions:") == std::string::npos) {
|
||||
|
||||
@@ -890,6 +890,10 @@ struct parser_executor {
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_gbnf_parser & p) {
|
||||
return arena.parse(p.child, ctx, start_pos);
|
||||
}
|
||||
};
|
||||
|
||||
common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const {
|
||||
@@ -957,7 +961,8 @@ void common_peg_arena::resolve_refs() {
|
||||
std::is_same_v<T, common_peg_and_parser> ||
|
||||
std::is_same_v<T, common_peg_not_parser> ||
|
||||
std::is_same_v<T, common_peg_tag_parser> ||
|
||||
std::is_same_v<T, common_peg_atomic_parser>) {
|
||||
std::is_same_v<T, common_peg_atomic_parser> ||
|
||||
std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
p.child = resolve_ref(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
p.child = resolve_ref(p.child);
|
||||
@@ -1036,6 +1041,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
|
||||
return "Not(" + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_atomic_parser>) {
|
||||
return "Atomic(" + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_any_parser>) {
|
||||
return "Any";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_space_parser>) {
|
||||
@@ -1565,6 +1572,7 @@ static std::unordered_set<std::string> collect_reachable_rules(
|
||||
std::is_same_v<T, common_peg_not_parser> ||
|
||||
std::is_same_v<T, common_peg_tag_parser> ||
|
||||
std::is_same_v<T, common_peg_atomic_parser> ||
|
||||
std::is_same_v<T, common_peg_gbnf_parser> ||
|
||||
std::is_same_v<T, common_peg_schema_parser>) {
|
||||
visit(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
@@ -1651,10 +1659,13 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
} else if constexpr (std::is_same_v<T, common_peg_sequence_parser>) {
|
||||
std::string s;
|
||||
for (const auto & child : p.children) {
|
||||
auto child_gbnf = to_gbnf(child);
|
||||
if (child_gbnf.empty()) {
|
||||
continue;
|
||||
}
|
||||
if (!s.empty()) {
|
||||
s += " ";
|
||||
}
|
||||
auto child_gbnf = to_gbnf(child);
|
||||
const auto & child_parser = effective_parser(child);
|
||||
if (std::holds_alternative<common_peg_choice_parser>(child_parser) ||
|
||||
std::holds_alternative<common_peg_sequence_parser>(child_parser)) {
|
||||
@@ -1754,6 +1765,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
return to_gbnf(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_atomic_parser>) {
|
||||
return to_gbnf(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return p.grammar;
|
||||
} else {
|
||||
static_assert(is_always_false_v<T>);
|
||||
}
|
||||
@@ -1888,6 +1901,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
|
||||
{"child", p.child},
|
||||
{"tag", p.tag}
|
||||
};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}};
|
||||
}
|
||||
}, variant);
|
||||
}
|
||||
@@ -2050,6 +2065,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
|
||||
};
|
||||
}
|
||||
|
||||
if (type == "gbnf") {
|
||||
if (!j.contains("child") || !j.contains("grammar")) {
|
||||
throw std::runtime_error("gbnf parser missing required fields");
|
||||
}
|
||||
return common_peg_gbnf_parser{
|
||||
j["child"].get<common_peg_parser_id>(),
|
||||
j["grammar"].get<std::string>(),
|
||||
};
|
||||
}
|
||||
|
||||
throw std::runtime_error("Unknown parser type: " + type);
|
||||
}
|
||||
|
||||
|
||||
@@ -270,6 +270,11 @@ struct common_peg_tag_parser {
|
||||
std::string tag;
|
||||
};
|
||||
|
||||
struct common_peg_gbnf_parser {
|
||||
common_peg_parser_id child;
|
||||
std::string grammar;
|
||||
};
|
||||
|
||||
// Variant holding all parser types
|
||||
using common_peg_parser_variant = std::variant<
|
||||
common_peg_epsilon_parser,
|
||||
@@ -290,7 +295,8 @@ using common_peg_parser_variant = std::variant<
|
||||
common_peg_rule_parser,
|
||||
common_peg_ref_parser,
|
||||
common_peg_atomic_parser,
|
||||
common_peg_tag_parser
|
||||
common_peg_tag_parser,
|
||||
common_peg_gbnf_parser
|
||||
>;
|
||||
|
||||
class common_peg_arena {
|
||||
@@ -504,6 +510,10 @@ class common_peg_parser_builder {
|
||||
// Unlike rules, you can tag multiple nodes with the same tag.
|
||||
common_peg_parser tag(const std::string & tag, const common_peg_parser & p) { return add(common_peg_tag_parser{p.id(), tag}); }
|
||||
|
||||
// Wraps a child parser but emits a custom GBNF grammar string instead of
|
||||
// the child's grammar. Parsing delegates entirely to the child.
|
||||
common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); }
|
||||
|
||||
void set_root(const common_peg_parser & p);
|
||||
|
||||
common_peg_arena build();
|
||||
|
||||
@@ -114,6 +114,10 @@ NOTE: some models may require large context window, for example: `-c 8192`
|
||||
|
||||
# Mistral's Voxtral
|
||||
(tool_name) -hf ggml-org/Voxtral-Mini-3B-2507-GGUF
|
||||
|
||||
# Qwen3-ASR
|
||||
(tool_name) -hf ggml-org/Qwen3-ASR-0.6B-GGUF
|
||||
(tool_name) -hf ggml-org/Qwen3-ASR-1.7B-GGUF
|
||||
```
|
||||
|
||||
**Mixed modalities**:
|
||||
@@ -124,6 +128,11 @@ NOTE: some models may require large context window, for example: `-c 8192`
|
||||
(tool_name) -hf ggml-org/Qwen2.5-Omni-3B-GGUF
|
||||
(tool_name) -hf ggml-org/Qwen2.5-Omni-7B-GGUF
|
||||
|
||||
# Qwen3 Omni
|
||||
# Capabilities: audio input, vision input
|
||||
(tool_name) -hf ggml-org/Qwen3-Omni-30B-A3B-Instruct-GGUF
|
||||
(tool_name) -hf ggml-org/Qwen3-Omni-30B-A3B-Thinking-GGUF
|
||||
|
||||
# Gemma 4
|
||||
# Capabilities: audio input, vision input
|
||||
(tool_name) -hf ggml-org/gemma-4-E2B-it-GGUF
|
||||
|
||||
@@ -3079,6 +3079,10 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
||||
case GGML_TYPE_MXFP4:
|
||||
lut_size = 4*16;
|
||||
break;
|
||||
case GGML_TYPE_NVFP4:
|
||||
// Same kvalues budget as MXFP4 plus ue4m3_fp32_lut[128] (types.glsl, DATA_A_NVFP4).
|
||||
lut_size = 4*16 + 128u * (uint32_t)sizeof(float);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -3558,6 +3562,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_NVFP4], matmul_nvfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
|
||||
GGML_ASSERT(device->subgroup_ballot);
|
||||
|
||||
@@ -3588,6 +3593,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
|
||||
#undef CREATE_MM
|
||||
#undef CREATE_MM2
|
||||
} else
|
||||
@@ -3651,6 +3657,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
} else {
|
||||
CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
@@ -3674,6 +3681,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
}
|
||||
|
||||
GGML_ASSERT(device->subgroup_ballot);
|
||||
@@ -3708,6 +3716,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
|
||||
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
|
||||
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
|
||||
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
|
||||
#undef CREATE_MM2
|
||||
#undef CREATE_MM
|
||||
} else
|
||||
@@ -3773,6 +3782,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
@@ -3819,6 +3829,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
@@ -3864,6 +3875,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
@@ -3939,6 +3951,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
@@ -3983,6 +3996,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_subgroup_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
} else {
|
||||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
@@ -4010,6 +4024,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
}
|
||||
}
|
||||
// reusing CREATE_MM from the fp32 path
|
||||
@@ -4108,6 +4123,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f32_f32", arr_dmmv_nvfp4_f32_f32_len[reduc16], arr_dmmv_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
||||
@@ -4133,6 +4149,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f16_f32", arr_dmmv_nvfp4_f16_f32_len[reduc16], arr_dmmv_nvfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
@@ -4184,6 +4201,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_NVFP4], "mul_mat_vec_id_nvfp4_f32", arr_dmmv_id_nvfp4_f32_f32_len[reduc16], arr_dmmv_id_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
@@ -4239,6 +4257,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_NVFP4], "dequant_nvfp4", dequant_nvfp4_len, dequant_nvfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
||||
|
||||
// get_rows
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||
@@ -4265,6 +4284,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_NVFP4], "get_rows_nvfp4", get_rows_nvfp4_len, get_rows_nvfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||
@@ -4291,6 +4311,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
|
||||
@@ -6089,6 +6110,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
break;
|
||||
default:
|
||||
return nullptr;
|
||||
@@ -6161,6 +6183,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
break;
|
||||
default:
|
||||
return nullptr;
|
||||
@@ -6227,6 +6250,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
break;
|
||||
default:
|
||||
return nullptr;
|
||||
@@ -6318,6 +6342,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
break;
|
||||
default:
|
||||
return nullptr;
|
||||
@@ -6387,6 +6412,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
break;
|
||||
default:
|
||||
return nullptr;
|
||||
@@ -15373,6 +15399,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
@@ -15488,6 +15515,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
case GGML_TYPE_I32:
|
||||
return true;
|
||||
default:
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "generic_unary_head.glsl"
|
||||
#include "dequant_funcs.glsl"
|
||||
|
||||
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4)
|
||||
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4)
|
||||
// 16 invocations needed for init_iq_shmem
|
||||
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
|
||||
#else
|
||||
|
||||
@@ -450,6 +450,25 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_NVFP4)
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const uint sub = iqs >> 4;
|
||||
const float d = ue4m3_to_fp32(data_a[a_offset + ib].d[sub]);
|
||||
const uint j = iqs & 7;
|
||||
const uint shift = (iqs & 8) >> 1; // 0 or 4
|
||||
const uint vui0 = uint(data_a[a_offset + ib].qs[sub * 8u + j]);
|
||||
const uint vui1 = uint(data_a[a_offset + ib].qs[sub * 8u + j + 1]);
|
||||
const uint qs0 = (vui0 >> shift) & 0xF;
|
||||
const uint qs1 = (vui1 >> shift) & 0xF;
|
||||
return vec2(float(kvalues_mxfp4[qs0]), float(kvalues_mxfp4[qs1])) * d * 0.5;
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
const vec2 v0 = dequantize(ib, iqs, a_offset);
|
||||
const vec2 v1 = dequantize(ib, iqs + 2u, a_offset);
|
||||
return vec4(v0.x, v0.y, v1.x, v1.y);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
|
||||
vec2 get_dm(uint ib, uint a_offset) {
|
||||
return vec2(0, 0);
|
||||
@@ -484,6 +503,12 @@ vec2 get_dm(uint ib, uint a_offset) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_NVFP4)
|
||||
vec2 get_dm(uint ib, uint a_offset) {
|
||||
return vec2(1.0, 0.0);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
|
||||
vec2 get_dm(uint ib, uint a_offset) {
|
||||
const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm);
|
||||
|
||||
@@ -697,6 +697,24 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_NVFP4)
|
||||
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVFP4 {
|
||||
block_nvfp4 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint sub = (idx & 0x30) >> 4;
|
||||
const uint iqs = ((idx & 0x30) >> 1) + (idx & 0x7);
|
||||
const uint shift = (idx & 0x8) >> 1;
|
||||
const float d = ue4m3_to_fp32(bl.block.d[sub]);
|
||||
uint qs = uint(bl.block.qs[iqs]);
|
||||
qs = (qs >> shift) & 0xF;
|
||||
return float16_t(kvalues_mxfp4[qs] * d * 0.5);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q1_0)
|
||||
#define dequantFuncA dequantFuncQ1_0
|
||||
#elif defined(DATA_A_Q4_0)
|
||||
@@ -743,6 +761,8 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
|
||||
#define dequantFuncA dequantFuncIQ4_NL
|
||||
#elif defined(DATA_A_MXFP4)
|
||||
#define dequantFuncA dequantFuncMXFP4
|
||||
#elif defined(DATA_A_NVFP4)
|
||||
#define dequantFuncA dequantFuncNVFP4
|
||||
#elif defined(DATA_A_F32)
|
||||
#define dequantFuncA dequantFuncF32
|
||||
#endif
|
||||
|
||||
32
ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp
Normal file
32
ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp
Normal file
@@ -0,0 +1,32 @@
|
||||
#version 450
|
||||
|
||||
#include "dequant_head.glsl"
|
||||
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {block_nvfp4 data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
|
||||
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
|
||||
const uint tid = gl_LocalInvocationID.x % 64;
|
||||
const uint sub = tid / 16;
|
||||
const uint ir = tid % 16;
|
||||
const uint ib = 16 * i + ir;
|
||||
if (ib >= p.nel / 64) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint q_idx = 8 * sub;
|
||||
const uint b_idx = 1024 * i + 64 * ir + 16 * sub;
|
||||
|
||||
const float d = ue4m3_to_fp32(data_a[ib].d[sub]);
|
||||
|
||||
[[unroll]] for (uint l = 0; l < 8; ++l) {
|
||||
data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));
|
||||
data_b[b_idx + l + 8] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]));
|
||||
}
|
||||
}
|
||||
@@ -501,6 +501,23 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
kvalues_mxfp4[vui2 & 0xF] * d);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d,
|
||||
kvalues_mxfp4[vui2 >> 4] * d);
|
||||
#elif defined(DATA_A_NVFP4)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
// lo and hi nibbles are 8 elements apart, which doesn't quite line up with
|
||||
// how the thread mapping and buf_idx calculation works for other types.
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row & 3) + (row & ~3) * 2;
|
||||
|
||||
const uint ib = idx / 16u;
|
||||
const uint sub = (idx & 0xC) >> 2;
|
||||
const uint iqs = (idx & 0xF) * 2;
|
||||
const float d = ue4m3_to_fp32(data_a[ib].d[sub]) * 0.5;
|
||||
const uint vui = uint(data_a[ib].qs[iqs]);
|
||||
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d,
|
||||
kvalues_mxfp4[vui2 & 0xF] * d);
|
||||
buf_a[buf_idx + 4] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d,
|
||||
kvalues_mxfp4[vui2 >> 4] * d);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -1713,6 +1713,22 @@ struct block_mxfp4
|
||||
#define A_TYPE block_mxfp4
|
||||
#endif
|
||||
|
||||
#define QUANT_K_NVFP4 64
|
||||
#define QUANT_R_NVFP4 1
|
||||
|
||||
struct block_nvfp4
|
||||
{
|
||||
uint8_t d[QUANT_K_NVFP4 / 16];
|
||||
uint8_t qs[QUANT_K_NVFP4 / 2];
|
||||
};
|
||||
|
||||
#if defined(DATA_A_NVFP4)
|
||||
#define QUANT_K QUANT_K_NVFP4
|
||||
#define QUANT_R QUANT_R_NVFP4
|
||||
#define QUANT_AUXF 1
|
||||
#define A_TYPE block_nvfp4
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
|
||||
const int8_t kvalues_iq4nl_const[16] = {
|
||||
int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
|
||||
@@ -1732,7 +1748,7 @@ void init_iq_shmem(uvec3 wgsize)
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_MXFP4)
|
||||
#if defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4)
|
||||
const int8_t kvalues_mxfp4_const[16] = {
|
||||
int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12),
|
||||
int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12),
|
||||
@@ -1740,6 +1756,24 @@ const int8_t kvalues_mxfp4_const[16] = {
|
||||
|
||||
shared int8_t kvalues_mxfp4[16];
|
||||
|
||||
#if defined(DATA_A_NVFP4)
|
||||
// UE4M3 scale in NVFP4 blocks use only 7 bits; sign (bit 7) is always zero.
|
||||
shared float ue4m3_fp32_lut[128];
|
||||
|
||||
float ue4m3_to_fp32_build(uint u) {
|
||||
if (u == 0u || u == 127u) {
|
||||
return 0.0;
|
||||
}
|
||||
const uint exp = (u >> 3) & 15u;
|
||||
const uint man = u & 7u;
|
||||
if (exp == 0u) {
|
||||
return float(man) * (1.0 / 512.0);
|
||||
}
|
||||
const uint bits = (exp + 120u) << 23 | (man << 20);
|
||||
return uintBitsToFloat(bits);
|
||||
}
|
||||
#endif
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
@@ -1747,6 +1781,11 @@ void init_iq_shmem(uvec3 wgsize)
|
||||
for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) {
|
||||
kvalues_mxfp4[i] = kvalues_mxfp4_const[i];
|
||||
}
|
||||
#if defined(DATA_A_NVFP4)
|
||||
for (uint i = gl_LocalInvocationIndex.x; i < 128u; i += wgsize.x) {
|
||||
ue4m3_fp32_lut[i] = ue4m3_to_fp32_build(i);
|
||||
}
|
||||
#endif
|
||||
barrier();
|
||||
}
|
||||
#endif
|
||||
@@ -1783,6 +1822,12 @@ float e8m0_to_fp32(uint8_t x) {
|
||||
return uintBitsToFloat(bits);
|
||||
}
|
||||
|
||||
#if defined(DATA_A_NVFP4)
|
||||
float ue4m3_to_fp32(uint8_t x) {
|
||||
return ue4m3_fp32_lut[uint(x)];
|
||||
}
|
||||
#endif
|
||||
|
||||
#if BDA
|
||||
|
||||
#extension GL_EXT_buffer_reference : enable
|
||||
|
||||
@@ -66,6 +66,7 @@ const std::vector<std::string> type_names = {
|
||||
"iq4_xs",
|
||||
"iq4_nl",
|
||||
"mxfp4",
|
||||
"nvfp4",
|
||||
"bf16",
|
||||
};
|
||||
|
||||
@@ -556,7 +557,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
std::string load_vec_quant = "2";
|
||||
if ((tname == "q1_0") || (tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
|
||||
load_vec_quant = "8";
|
||||
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4"))
|
||||
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4") || (tname == "nvfp4"))
|
||||
load_vec_quant = "4";
|
||||
|
||||
if (tname == "bf16") {
|
||||
|
||||
141
models/templates/deepseek-ai-DeepSeek-V3.2.jinja
Normal file
141
models/templates/deepseek-ai-DeepSeek-V3.2.jinja
Normal file
@@ -0,0 +1,141 @@
|
||||
{%- if not add_generation_prompt is defined -%}
|
||||
{%- set add_generation_prompt = false -%}
|
||||
{%- endif -%}
|
||||
{%- if not thinking is defined -%}
|
||||
{%- if enable_thinking is defined -%}
|
||||
{%- set thinking = enable_thinking -%}
|
||||
{%- else -%}
|
||||
{%- set thinking = false -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- set dsml_token = '|DSML|' -%}
|
||||
{%- set thinking_start_token = '<think>' -%}
|
||||
{%- set thinking_end_token = '</think>' -%}
|
||||
{%- set tools_header = '## Tools\n\nYou have access to a set of tools you can use to answer the user\'s question.\nYou can invoke functions by writing a "<' + dsml_token + 'function_calls>" block like the following as part of your reply to the user:\n<' + dsml_token + 'function_calls>\n<' + dsml_token + 'invoke name="$FUNCTION_NAME">\n<' + dsml_token + 'parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</' + dsml_token + 'parameter>\n...\n</' + dsml_token + 'invoke>\n<' + dsml_token + 'invoke name="$FUNCTION_NAME2">\n...\n</' + dsml_token + 'invoke>\n</' + dsml_token + 'function_calls>\n\nString and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects).\n\nIf the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example:\n\n<' + dsml_token + 'function_calls>\n...\n</' + dsml_token + 'function_calls>\n\n<function_results>\n...\n</function_results>\n\n' + thinking_start_token + '...thinking about results' + thinking_end_token + '\n\nHere are the functions available in JSONSchema format:\n<functions>\n' -%}
|
||||
{%- set tools_footer = '</functions>\n' -%}
|
||||
{%- set ns = namespace(system_prompt='', is_first_sp=true) -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'system' -%}
|
||||
{%- if ns.is_first_sp -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + (message['content'] or '') -%}
|
||||
{%- set ns.is_first_sp = false -%}
|
||||
{%- else -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + '\n\n' + (message['content'] or '') -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if tools is defined and tools -%}
|
||||
{%- set ts = namespace(schemas='') -%}
|
||||
{%- for tool in tools -%}
|
||||
{%- if tool['type'] == 'function' -%}
|
||||
{%- set ts.schemas = ts.schemas + (tool['function'] | tojson) + '\n' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if ns.system_prompt -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + '\n\n' + tools_header + ts.schemas + tools_footer -%}
|
||||
{%- else -%}
|
||||
{%- set ns.system_prompt = tools_header + ts.schemas + tools_footer -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{{- bos_token -}}
|
||||
{{- ns.system_prompt -}}
|
||||
{%- set last_user_idx = namespace(value=-1) -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'user' or message['role'] == 'developer' -%}
|
||||
{%- set last_user_idx.value = loop.index0 -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- set state = namespace(pending_asst_marker=false, pending_tool_marker=false) -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
{{- '<|User|>' + (message['content'] or '') -}}
|
||||
{%- set state.pending_asst_marker = true -%}
|
||||
{%- set state.pending_tool_marker = false -%}
|
||||
{%- elif message['role'] == 'assistant' -%}
|
||||
{%- set is_after_last_user = loop.index0 > last_user_idx.value -%}
|
||||
{%- if state.pending_asst_marker -%}
|
||||
{{- '<|Assistant|>' -}}
|
||||
{%- if is_after_last_user and thinking -%}
|
||||
{{- thinking_start_token -}}
|
||||
{%- if message['reasoning_content'] is defined and message['reasoning_content'] -%}
|
||||
{{- message['reasoning_content'] -}}
|
||||
{%- endif -%}
|
||||
{{- thinking_end_token -}}
|
||||
{%- else -%}
|
||||
{{- thinking_end_token -}}
|
||||
{%- endif -%}
|
||||
{%- elif state.pending_tool_marker -%}
|
||||
{%- if is_after_last_user and thinking -%}
|
||||
{{- '\n\n' + thinking_start_token -}}
|
||||
{%- if message['reasoning_content'] is defined and message['reasoning_content'] -%}
|
||||
{{- message['reasoning_content'] -}}
|
||||
{%- endif -%}
|
||||
{{- thinking_end_token -}}
|
||||
{%- else -%}
|
||||
{{- '\n\n' + thinking_end_token -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- set state.pending_asst_marker = false -%}
|
||||
{%- set state.pending_tool_marker = false -%}
|
||||
{%- if message['content'] is defined and message['content'] -%}
|
||||
{{- message['content'] -}}
|
||||
{%- endif -%}
|
||||
{%- if message['tool_calls'] -%}
|
||||
{{- '\n\n<' + dsml_token + 'function_calls>\n' -}}
|
||||
{%- for tool in message['tool_calls'] -%}
|
||||
{%- set func = tool['function'] -%}
|
||||
{{- '<' + dsml_token + 'invoke name="' + func['name'] + '">\n' -}}
|
||||
{%- set args = func['arguments'] -%}
|
||||
{%- if args is string -%}
|
||||
{%- set args = args | from_json -%}
|
||||
{%- endif -%}
|
||||
{%- for key, val in args.items() -%}
|
||||
{%- if val is string -%}
|
||||
{{- '<' + dsml_token + 'parameter name="' + key + '" string="true">' + val + '</' + dsml_token + 'parameter>\n' -}}
|
||||
{%- else -%}
|
||||
{{- '<' + dsml_token + 'parameter name="' + key + '" string="false">' + (val | tojson) + '</' + dsml_token + 'parameter>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- '</' + dsml_token + 'invoke>\n' -}}
|
||||
{%- endfor -%}
|
||||
{{- '</' + dsml_token + 'function_calls>' -}}
|
||||
{%- endif -%}
|
||||
{{- '<|end▁of▁sentence|>' -}}
|
||||
{%- elif message['role'] == 'tool' -%}
|
||||
{%- set outer_index = loop.index0 -%}
|
||||
{%- set assistant_idx = namespace(value=-1) -%}
|
||||
{%- for prev_msg in messages -%}
|
||||
{%- if prev_msg['role'] == 'assistant' and prev_msg['tool_calls'] and loop.index0 < outer_index -%}
|
||||
{%- set assistant_idx.value = loop.index0 -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- set call_order = outer_index - assistant_idx.value -%}
|
||||
{%- set assistant_msg = messages[assistant_idx.value] -%}
|
||||
{%- set tool_call_count = assistant_msg['tool_calls'] | length -%}
|
||||
{%- if call_order == 1 -%}
|
||||
{{- '\n\n<function_results>' -}}
|
||||
{%- endif -%}
|
||||
{{- '\n<result>' + (message['content'] or '') + '</result>' -}}
|
||||
{%- if call_order == tool_call_count -%}
|
||||
{{- '\n</function_results>' -}}
|
||||
{%- set state.pending_asst_marker = false -%}
|
||||
{%- set state.pending_tool_marker = true -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{%- if state.pending_asst_marker -%}
|
||||
{{- '<|Assistant|>' -}}
|
||||
{%- if thinking -%}
|
||||
{{- thinking_start_token -}}
|
||||
{%- else -%}
|
||||
{{- thinking_start_token + thinking_end_token -}}
|
||||
{%- endif -%}
|
||||
{%- elif state.pending_tool_marker -%}
|
||||
{%- if thinking -%}
|
||||
{{- '\n\n' + thinking_start_token -}}
|
||||
{%- else -%}
|
||||
{{- '\n\n' + thinking_start_token + thinking_end_token -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
@@ -258,6 +258,66 @@ void test_gbnf_generation(testing &t) {
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("silent parser emits nothing in gbnf", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello") + p.gbnf(p.literal("world"), "");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "hello"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("silent choice inside sequence emits nothing", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("a") + p.gbnf(p.literal("b") | p.literal("c"), "") + p.literal("d");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "a" "d"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("silent wrapped in tag emits nothing", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("a") + p.tag("t", p.gbnf(p.literal("b"), ""));
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "a"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("gbnf parser emits custom grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("a") + p.gbnf(p.literal("b"), "[a-z]+");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "a" [a-z]+
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("nested transparent wrappers get parenthesized", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("x") + p.tag("outer", p.atomic(p.literal("a") | p.literal("b")));
|
||||
|
||||
@@ -2118,6 +2118,31 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.tools({ amount_tool })
|
||||
.expect(message_with_tool_calls("amount", R"({"orig": 1.5e10})"))
|
||||
.run();
|
||||
|
||||
// Edge cases
|
||||
tst.test(
|
||||
"<|channel>thought\n<channel|>Hello, world!\nWhat's up?<channel|>")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
"<|channel>thought\n<channel|>Hello, world!\nWhat's up?<|channel>thought\n<channel|>")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
"<|channel>thought\n<channel|>Hello, world!\nWhat's up?<|channel>thought\n<channel|><channel|>")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
"<|channel><|channel>thought\n<channel|>Hello, world!\nWhat's up?")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
}
|
||||
|
||||
{
|
||||
@@ -2576,6 +2601,215 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
expect(simple_assist_msg("CONTENT", "")).run();
|
||||
}
|
||||
|
||||
// DeepSeek V3.2 tests - format uses DSML markup:
|
||||
// <|DSML|function_calls>
|
||||
// <|DSML|invoke name="foo">
|
||||
// <|DSML|parameter name="bar" string="true|false">value</|DSML|parameter>
|
||||
// </|DSML|invoke>
|
||||
// </|DSML|function_calls>
|
||||
// Reasoning uses <think>...</think>. The generation prompt ends in <think> (thinking mode)
|
||||
// or <think></think> (non-thinking mode).
|
||||
{
|
||||
auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-V3.2.jinja", detailed_debug);
|
||||
|
||||
// Pure content (non-thinking mode)
|
||||
tst.test("Hello, world!\nWhat's up?")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
|
||||
// Thinking + content
|
||||
tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.expect(message_assist_thoughts)
|
||||
.run();
|
||||
|
||||
// Thinking + tool call (single, string param)
|
||||
tst.test(
|
||||
"Let me check the time</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"get_time\">\n"
|
||||
"<|DSML|parameter name=\"city\" string=\"true\">Tokyo</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ get_time_tool })
|
||||
.expect(message_with_tool_calls_and_reasoning("get_time", R"({"city": "Tokyo"})", "Let me check the time"))
|
||||
.run();
|
||||
|
||||
// Tool call without reasoning (non-thinking mode), integer param (string="false")
|
||||
tst.test(
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"special_function\">\n"
|
||||
"<|DSML|parameter name=\"arg1\" string=\"false\">1</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.run();
|
||||
|
||||
// Multiple parallel tool calls with reasoning
|
||||
tst.test(
|
||||
"Calling both</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"get_time\">\n"
|
||||
"<|DSML|parameter name=\"city\" string=\"true\">Paris</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"<|DSML|invoke name=\"get_weather\">\n"
|
||||
"<|DSML|parameter name=\"city\" string=\"true\">Paris</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.parallel_tool_calls(true)
|
||||
.tools({ get_time_tool, get_weather_tool })
|
||||
.expect(message_with_reasoning_content_and_multiple_tool_calls(
|
||||
"Calling both", "",
|
||||
{ { "get_time", R"({"city": "Paris"})" }, { "get_weather", R"({"city": "Paris"})" } }))
|
||||
.run();
|
||||
|
||||
// Tool call with content before tool calls
|
||||
tst.test(
|
||||
"Thinking about it</think>"
|
||||
"Let me call the function.\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"special_function\">\n"
|
||||
"<|DSML|parameter name=\"arg1\" string=\"false\">1</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ special_function_tool })
|
||||
.expect_reasoning("Thinking about it")
|
||||
.expect_content("Let me call the function.")
|
||||
.expect_tool_calls({
|
||||
{ "special_function", R"({"arg1": 1})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Tool call with negative number
|
||||
tst.test(
|
||||
"Test negative</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"magic_int\">\n"
|
||||
"<|DSML|parameter name=\"ref\" string=\"false\">-14</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ magic_int_tool })
|
||||
.expect_reasoning("Test negative")
|
||||
.expect_tool_calls({
|
||||
{ "magic_int", R"({"ref": -14})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Tool call with decimal number
|
||||
tst.test(
|
||||
"Test decimal</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"amount\">\n"
|
||||
"<|DSML|parameter name=\"orig\" string=\"false\">3.14</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ amount_tool })
|
||||
.expect_reasoning("Test decimal")
|
||||
.expect_tool_calls({
|
||||
{ "amount", R"({"orig": 3.14})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Tool call with boolean
|
||||
tst.test(
|
||||
"Test boolean</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"toggle\">\n"
|
||||
"<|DSML|parameter name=\"enabled\" string=\"false\">true</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ toggle_tool })
|
||||
.expect_reasoning("Test boolean")
|
||||
.expect_tool_calls({
|
||||
{ "toggle", R"({"enabled": true})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Tool call with array parameter (JSON-formatted)
|
||||
tst.test(
|
||||
"Test array</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"todo_list\">\n"
|
||||
"<|DSML|parameter name=\"todos\" string=\"false\">[\"buy milk\",\"walk dog\"]</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ todo_list })
|
||||
.expect_reasoning("Test array")
|
||||
.expect_tool_calls({
|
||||
{ "todo_list", R"({"todos": ["buy milk", "walk dog"]})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Tool call with object parameter (JSON-formatted)
|
||||
tst.test(
|
||||
"Test object</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"set_config\">\n"
|
||||
"<|DSML|parameter name=\"config\" string=\"false\">{\"theme\":\"dark\",\"level\":2}</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ config_tool })
|
||||
.expect_reasoning("Test object")
|
||||
.expect_tool_calls({
|
||||
{ "set_config", R"({"config": {"theme": "dark", "level": 2}})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Edge case: empty reasoning
|
||||
tst.test(
|
||||
"</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"get_time\">\n"
|
||||
"<|DSML|parameter name=\"city\" string=\"true\">XYZCITY</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ get_time_tool })
|
||||
.expect(message_with_tool_calls("get_time", R"({"city": "XYZCITY"})"))
|
||||
.run();
|
||||
|
||||
// Edge case: tool call with multiple params (mixed types, string first)
|
||||
tst.test(
|
||||
"Multi-arg call</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"magic_int\">\n"
|
||||
"<|DSML|parameter name=\"ref\" string=\"false\">42</|DSML|parameter>\n"
|
||||
"<|DSML|parameter name=\"name\" string=\"true\">foo bar</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ magic_int_tool })
|
||||
.expect_reasoning("Multi-arg call")
|
||||
.expect_tool_calls({
|
||||
{ "magic_int", R"({"ref": 42, "name": "foo bar"})", {} },
|
||||
})
|
||||
.run();
|
||||
}
|
||||
|
||||
// GLM-4.6 tests - format: <tool_call>function_name\n<arg_key>...</arg_key>\n<arg_value>...</arg_value>\n</tool_call>
|
||||
{
|
||||
auto tst = peg_tester("models/templates/GLM-4.6.jinja", detailed_debug);
|
||||
|
||||
@@ -32,9 +32,6 @@ struct clip_graph {
|
||||
float kq_scale; // TODO: maybe move this to hparams
|
||||
const clip_flash_attn_type flash_attn_type;
|
||||
|
||||
// TODO [QWEN_VIDEO]: improve this in the future
|
||||
int nt = 1; // number of temporal dim, to be used by Qwen-VL models
|
||||
|
||||
ggml_context_ptr ctx0_ptr;
|
||||
ggml_context * ctx0;
|
||||
ggml_cgraph * gf;
|
||||
|
||||
@@ -448,7 +448,6 @@ struct clip_image_u8_batch {
|
||||
struct clip_image_f32_batch {
|
||||
std::vector<clip_image_f32_ptr> entries;
|
||||
bool is_audio = false;
|
||||
bool is_seq = true;
|
||||
|
||||
// for llava-uhd style models, we need to know the grid size
|
||||
// note: entries.size() == grid_x * grid_y + 1 (one overview image)
|
||||
@@ -459,7 +458,6 @@ struct clip_image_f32_batch {
|
||||
clip_image_f32_batch new_batch{
|
||||
/* entries */ {},
|
||||
/* is_audio */ is_audio,
|
||||
/* is_seq */ is_seq,
|
||||
/* grid_x */ grid_x,
|
||||
/* grid_y */ grid_y,
|
||||
};
|
||||
|
||||
@@ -515,7 +515,7 @@ ggml_tensor * clip_graph::build_inp() {
|
||||
}
|
||||
|
||||
ggml_tensor * clip_graph::build_inp_raw(int channels) {
|
||||
ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, img.nx, img.ny, channels, nt);
|
||||
ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, channels);
|
||||
ggml_set_name(inp_raw, "inp_raw");
|
||||
ggml_set_input(inp_raw);
|
||||
return inp_raw;
|
||||
@@ -951,9 +951,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
GGML_ABORT("missing cgraph builder");
|
||||
}
|
||||
|
||||
// TODO [QWEN_VIDEO]: improve this in the future
|
||||
builder->nt = imgs.entries.size();
|
||||
|
||||
return builder->build();
|
||||
}
|
||||
|
||||
@@ -3045,11 +3042,10 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3
|
||||
bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) {
|
||||
const clip_image_f32_batch & imgs = *imgs_c_ptr;
|
||||
int batch_size = imgs.entries.size();
|
||||
bool support_seq = clip_model_supports_seq_input(ctx);
|
||||
|
||||
// TODO @ngxson : implement batch size > 1 as a loop
|
||||
// we don't need true batching support because the cgraph will gonna be big anyway
|
||||
if (batch_size != 1 && !support_seq) {
|
||||
if (batch_size != 1) {
|
||||
return false; // only support batch size of 1
|
||||
}
|
||||
|
||||
@@ -3121,8 +3117,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
// └─────┘ │
|
||||
// ──────┘ x B
|
||||
|
||||
// IMPORTANT: [QWEN_VIDEO] the batch dim is currently used for temporal dim in Qwen-VL models
|
||||
|
||||
for (size_t i = 0; i < imgs.entries.size(); i++) {
|
||||
const int nx = imgs.entries[i]->nx;
|
||||
const int ny = imgs.entries[i]->ny;
|
||||
@@ -3753,17 +3747,6 @@ bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
bool clip_model_supports_seq_input(const struct clip_ctx * ctx) {
|
||||
switch (ctx->proj_type()) {
|
||||
case PROJECTOR_TYPE_QWEN2VL:
|
||||
case PROJECTOR_TYPE_QWEN25VL:
|
||||
case PROJECTOR_TYPE_QWEN3VL:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
|
||||
clip_image_f32 clip_img;
|
||||
clip_img.buf.resize(h * w * 3);
|
||||
|
||||
@@ -116,6 +116,3 @@ void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel
|
||||
bool clip_has_vision_encoder(const struct clip_ctx * ctx);
|
||||
bool clip_has_audio_encoder(const struct clip_ctx * ctx);
|
||||
bool clip_has_whisper_encoder(const struct clip_ctx * ctx);
|
||||
|
||||
// true if model graph support image->nt (temporal dimension) as input
|
||||
bool clip_model_supports_seq_input(const struct clip_ctx * ctx);
|
||||
|
||||
@@ -26,11 +26,10 @@ struct clip_graph_pixtral : clip_graph {
|
||||
struct clip_graph_qwen2vl : clip_graph {
|
||||
clip_graph_qwen2vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||
ggml_cgraph * build() override;
|
||||
ggml_tensor * build_inp_with_temporal_merge();
|
||||
};
|
||||
|
||||
struct clip_graph_qwen3vl : clip_graph_qwen2vl {
|
||||
clip_graph_qwen3vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph_qwen2vl(ctx, img) {}
|
||||
struct clip_graph_qwen3vl : clip_graph {
|
||||
clip_graph_qwen3vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||
ggml_cgraph * build() override;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,31 +1,5 @@
|
||||
#include "models.h"
|
||||
|
||||
ggml_tensor * clip_graph_qwen2vl::build_inp_with_temporal_merge() {
|
||||
ggml_tensor * inp_raw = build_inp_raw();
|
||||
|
||||
GGML_ASSERT(img.nx % (patch_size * 2) == 0);
|
||||
GGML_ASSERT(img.ny % (patch_size * 2) == 0);
|
||||
|
||||
const size_t nb1 = ggml_row_size(inp_raw->type, img.nx);
|
||||
const size_t nb2 = nb1 * img.ny;
|
||||
|
||||
if (nt == 1) {
|
||||
// still image input
|
||||
return ggml_add(ctx0,
|
||||
ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1),
|
||||
ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1));
|
||||
} else if (nt == 2) {
|
||||
// 2 frames input (video input)
|
||||
ggml_tensor * inp_0 = ggml_view_3d(ctx0, inp_raw, img.nx, img.ny, 3, nb1, nb2, 0);
|
||||
ggml_tensor * inp_1 = ggml_view_3d(ctx0, inp_raw, img.nx, img.ny, 3, nb1, nb2, nb2 * 3);
|
||||
return ggml_add(ctx0,
|
||||
ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_0, patch_size, patch_size, 0, 0, 1, 1),
|
||||
ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_1, patch_size, patch_size, 0, 0, 1, 1));
|
||||
} else {
|
||||
GGML_ASSERT(false && "nt > 2 is not supported");
|
||||
}
|
||||
}
|
||||
|
||||
ggml_cgraph * clip_graph_qwen2vl::build() {
|
||||
GGML_ASSERT(model.patch_bias == nullptr);
|
||||
GGML_ASSERT(model.class_embedding == nullptr);
|
||||
@@ -42,10 +16,17 @@ ggml_cgraph * clip_graph_qwen2vl::build() {
|
||||
|
||||
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
|
||||
|
||||
ggml_tensor * inp = build_inp_with_temporal_merge();
|
||||
ggml_tensor * inp_raw = build_inp_raw();
|
||||
ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
|
||||
GGML_ASSERT(img.nx % (patch_size * 2) == 0);
|
||||
GGML_ASSERT(img.ny % (patch_size * 2) == 0);
|
||||
|
||||
// second conv dimension
|
||||
{
|
||||
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
inp = ggml_add(ctx0, inp, inp_1);
|
||||
|
||||
inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b]
|
||||
inp = ggml_cont_4d(
|
||||
ctx0, inp,
|
||||
|
||||
@@ -13,10 +13,17 @@ ggml_cgraph * clip_graph_qwen3vl::build() {
|
||||
|
||||
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
|
||||
|
||||
ggml_tensor * inp = build_inp_with_temporal_merge();
|
||||
ggml_tensor * inp_raw = build_inp_raw();
|
||||
ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
|
||||
// spatial merge
|
||||
GGML_ASSERT(img.nx % (patch_size * 2) == 0);
|
||||
GGML_ASSERT(img.ny % (patch_size * 2) == 0);
|
||||
|
||||
// second conv dimension
|
||||
{
|
||||
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
inp = ggml_add(ctx0, inp, inp_1);
|
||||
|
||||
inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b]
|
||||
inp = ggml_cont_4d(
|
||||
ctx0, inp,
|
||||
|
||||
@@ -25,11 +25,9 @@
|
||||
|
||||
// represents raw image data, layout is RGBRGBRGB...
|
||||
// length of data must be nx * ny * 3
|
||||
// for sequence of images (i.e. video): data is nt sequential RGB frames, each nx * ny * 3 bytes
|
||||
struct mtmd_bitmap {
|
||||
uint32_t nx;
|
||||
uint32_t ny;
|
||||
uint32_t nt = 1; // 1 for single images, >= 2 (even) for sequence
|
||||
std::vector<unsigned char> data;
|
||||
std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking
|
||||
bool is_audio = false; // true if the bitmap is audio
|
||||
@@ -39,8 +37,8 @@ struct mtmd_image_tokens {
|
||||
uint32_t nx; // number of tokens in x direction
|
||||
uint32_t ny; // number of tokens in y direction
|
||||
bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position)
|
||||
uint32_t n_tokens() const { return nx * ny; }
|
||||
clip_image_f32_batch batch_f32; // preprocessed image patches
|
||||
uint32_t n_tokens() const { return nx * ny; } // TODO [QWEN_VIDEO]: we don't count nt here to be compatible with Qwen-VL, but other models in the future might have different logic
|
||||
std::string id; // optional user-defined ID, useful for KV cache tracking
|
||||
|
||||
mtmd_image_tokens clone() {
|
||||
@@ -877,73 +875,6 @@ struct mtmd_tokenizer {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t add_seq_image(const mtmd_bitmap * bitmap) {
|
||||
GGML_ASSERT(ctx->ctx_v);
|
||||
GGML_ASSERT(bitmap->nt > 1);
|
||||
// TODO [QWEN_VIDEO]: we only support even frames (Qwen-VL style) for now
|
||||
GGML_ASSERT(bitmap->nt % 2 == 0);
|
||||
bool support_seq = clip_model_supports_seq_input(ctx->ctx_v);
|
||||
if (!support_seq) {
|
||||
LOG_ERR("%s: error: model does not support sequential image input (usually requires Qwen-VL style models)\n", __func__);
|
||||
return 2;
|
||||
}
|
||||
|
||||
const uint32_t n_frames = bitmap->nt;
|
||||
const size_t frame_bytes = (size_t)bitmap->nx * bitmap->ny * 3;
|
||||
|
||||
// preprocess each frame individually
|
||||
clip_image_f32_batch all_frames;
|
||||
all_frames.is_seq = true;
|
||||
all_frames.grid_x = 0; // currently, we don't support tiling for video input
|
||||
all_frames.grid_y = 0; // currently, we don't support tiling for video input
|
||||
|
||||
for (uint32_t f = 0; f < n_frames; f++) {
|
||||
clip_image_u8_ptr img_u8(clip_image_u8_init());
|
||||
img_u8->nx = bitmap->nx;
|
||||
img_u8->ny = bitmap->ny;
|
||||
img_u8->buf.resize(frame_bytes);
|
||||
std::memcpy(img_u8->buf.data(), bitmap->data.data() + f * frame_bytes, frame_bytes);
|
||||
|
||||
clip_image_f32_batch frame_batch;
|
||||
bool ok = ctx->image_preproc->preprocess(*img_u8, frame_batch);
|
||||
if (!ok) {
|
||||
LOG_ERR("Unable to preprocess image\n");
|
||||
return 2;
|
||||
}
|
||||
GGML_ASSERT(frame_batch.entries.size() == 1);
|
||||
all_frames.entries.push_back(std::move(frame_batch.entries[0]));
|
||||
}
|
||||
|
||||
mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
|
||||
if (mtmd_decode_use_mrope(ctx)) {
|
||||
// for Qwen2VL, we need this information for M-RoPE decoding positions
|
||||
image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_v, all_frames.entries[0].get());
|
||||
image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_v, all_frames.entries[0].get());
|
||||
image_tokens->use_mrope_pos = true;
|
||||
} else {
|
||||
GGML_ASSERT(false && "not supported");
|
||||
}
|
||||
image_tokens->batch_f32 = std::move(all_frames);
|
||||
image_tokens->id = bitmap->id; // optional
|
||||
|
||||
LOG_DBG("seq_image: nt=%u, nx=%u, ny=%u, n_tokens=%u\n",
|
||||
bitmap->nt, image_tokens->nx, image_tokens->ny, image_tokens->n_tokens());
|
||||
|
||||
mtmd_input_chunk chunk{
|
||||
MTMD_INPUT_CHUNK_TYPE_IMAGE,
|
||||
{}, // text tokens
|
||||
std::move(image_tokens),
|
||||
nullptr, // audio tokens
|
||||
};
|
||||
cur.entries.emplace_back(std::move(chunk));
|
||||
|
||||
if (!ctx->img_end.empty()) {
|
||||
add_text(ctx->img_end, true);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::vector<mtmd_input_chunk> split_batch_to_chunk(clip_image_f32_batch && batch_f32, const std::string & id) {
|
||||
std::vector<mtmd_input_chunk> chunks;
|
||||
|
||||
@@ -1062,7 +993,6 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens)
|
||||
|| clip_is_glm(ctx_clip)
|
||||
|| proj_type == PROJECTOR_TYPE_INTERNVL) {
|
||||
// TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
|
||||
// video: each entry is one frame pair, encoded with per-frame attention
|
||||
const auto & entries = image_tokens->batch_f32.entries;
|
||||
for (size_t i = 0; i < entries.size(); i++) {
|
||||
int n_tokens_per_image = clip_n_output_tokens(ctx_clip, entries[i].get());
|
||||
@@ -1145,54 +1075,17 @@ mtmd_bitmap * mtmd_bitmap_init(uint32_t nx,
|
||||
mtmd_bitmap * bitmap = new mtmd_bitmap;
|
||||
bitmap->nx = nx;
|
||||
bitmap->ny = ny;
|
||||
bitmap->nt = 1;
|
||||
size_t data_size = (size_t)nx * ny * 3;
|
||||
bitmap->data.resize(data_size);
|
||||
std::memcpy(bitmap->data.data(), data, data_size);
|
||||
return bitmap;
|
||||
}
|
||||
|
||||
mtmd_bitmap * mtmd_bitmap_init_from_seq(uint32_t nx,
|
||||
uint32_t ny,
|
||||
uint32_t nt,
|
||||
const unsigned char * data) {
|
||||
if (nt == 0) {
|
||||
LOG_ERR("%s: error: nt must be greater than 0 for sequence input\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
if (nt == 1) {
|
||||
// if nt == 1, it's not really a sequence, we can treat it as a single image
|
||||
return mtmd_bitmap_init(nx, ny, data);
|
||||
}
|
||||
// TODO [QWEN_VIDEO]: we only support Qwen-VL style for now, which requires even number of frames
|
||||
// therefore, we duplicate the last frame if nt is odd, to avoid issues in video preprocessing
|
||||
bool is_odd = (nt % 2 == 1);
|
||||
if (is_odd) {
|
||||
nt += 1;
|
||||
}
|
||||
size_t frame_size = (size_t)nx * ny * 3;
|
||||
mtmd_bitmap * bitmap = new mtmd_bitmap;
|
||||
bitmap->nx = nx;
|
||||
bitmap->ny = ny;
|
||||
bitmap->nt = nt;
|
||||
size_t data_size = frame_size * nt;
|
||||
bitmap->data.resize(data_size);
|
||||
std::memcpy(bitmap->data.data(), data, data_size);
|
||||
if (is_odd) {
|
||||
// duplicate the last frame
|
||||
std::memcpy(bitmap->data.data() + (nt - 1) * frame_size,
|
||||
data + (nt - 2) * frame_size,
|
||||
frame_size);
|
||||
}
|
||||
return bitmap;
|
||||
}
|
||||
|
||||
mtmd_bitmap * mtmd_bitmap_init_from_audio(size_t n_samples,
|
||||
const float * data) {
|
||||
mtmd_bitmap * bitmap = new mtmd_bitmap;
|
||||
bitmap->nx = n_samples;
|
||||
bitmap->ny = 1;
|
||||
bitmap->nt = 1;
|
||||
bitmap->is_audio = true;
|
||||
size_t data_size = n_samples * sizeof(float);
|
||||
bitmap->data.resize(data_size);
|
||||
@@ -1208,10 +1101,6 @@ uint32_t mtmd_bitmap_get_ny(const mtmd_bitmap * bitmap) {
|
||||
return bitmap->ny;
|
||||
}
|
||||
|
||||
uint32_t mtmd_bitmap_get_nt(const mtmd_bitmap * bitmap) {
|
||||
return bitmap->nt;
|
||||
}
|
||||
|
||||
const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap) {
|
||||
return bitmap->data.data();
|
||||
}
|
||||
@@ -1224,10 +1113,6 @@ bool mtmd_bitmap_is_audio(const mtmd_bitmap * bitmap) {
|
||||
return bitmap->is_audio;
|
||||
}
|
||||
|
||||
bool mtmd_bitmap_is_seq(const mtmd_bitmap * bitmap) {
|
||||
return bitmap->nt >= 2;
|
||||
}
|
||||
|
||||
const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap) {
|
||||
return bitmap->id.c_str();
|
||||
}
|
||||
@@ -1370,8 +1255,8 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
|
||||
|
||||
llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
|
||||
if (image_tokens->use_mrope_pos) {
|
||||
// for M-RoPE, n_pos = max(t, h, w)
|
||||
// t is omitted as we don't support batching
|
||||
// for M-RoPE, temporal dimension = max(t,h,w)
|
||||
// t is omitted as we don't support video input
|
||||
return std::max(image_tokens->nx, image_tokens->ny);
|
||||
}
|
||||
return image_tokens->n_tokens();
|
||||
|
||||
@@ -135,23 +135,16 @@ MTMD_API int mtmd_get_audio_sample_rate(mtmd_context * ctx);
|
||||
// if bitmap is image:
|
||||
// length of data must be nx * ny * 3
|
||||
// the data is in RGBRGBRGB... format
|
||||
// if bitmap is sequence of images (i.e. video):
|
||||
// nt is the number of frames
|
||||
// length of data must be nx * ny * 3 * nt
|
||||
// frames are sequential RGB, each nx * ny * 3 bytes
|
||||
// if bitmap is audio:
|
||||
// length of data must be n_samples * sizeof(float)
|
||||
// the data is in float format (PCM F32)
|
||||
MTMD_API mtmd_bitmap * mtmd_bitmap_init (uint32_t nx, uint32_t ny, const unsigned char * data);
|
||||
MTMD_API mtmd_bitmap * mtmd_bitmap_init_from_seq (uint32_t nx, uint32_t ny, uint32_t nt, const unsigned char * data);
|
||||
MTMD_API mtmd_bitmap * mtmd_bitmap_init_from_audio(size_t n_samples, const float * data);
|
||||
MTMD_API uint32_t mtmd_bitmap_get_nx (const mtmd_bitmap * bitmap);
|
||||
MTMD_API uint32_t mtmd_bitmap_get_ny (const mtmd_bitmap * bitmap);
|
||||
MTMD_API uint32_t mtmd_bitmap_get_nt (const mtmd_bitmap * bitmap);
|
||||
MTMD_API const unsigned char * mtmd_bitmap_get_data (const mtmd_bitmap * bitmap);
|
||||
MTMD_API size_t mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap);
|
||||
MTMD_API bool mtmd_bitmap_is_audio (const mtmd_bitmap * bitmap);
|
||||
MTMD_API bool mtmd_bitmap_is_seq (const mtmd_bitmap * bitmap);
|
||||
MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap);
|
||||
// bitmap ID is optional, but useful for KV cache tracking
|
||||
// these getters/setters are dedicated functions, so you can for example calculate the hash of the image based on mtmd_bitmap_get_data()
|
||||
@@ -284,14 +277,9 @@ struct bitmap {
|
||||
bitmap(uint32_t nx, uint32_t ny, const unsigned char * data) {
|
||||
ptr.reset(mtmd_bitmap_init(nx, ny, data));
|
||||
}
|
||||
bitmap(uint32_t nx, uint32_t ny, uint32_t nt, const unsigned char * data) {
|
||||
ptr.reset(mtmd_bitmap_init_from_seq(nx, ny, nt, data));
|
||||
}
|
||||
~bitmap() = default;
|
||||
uint32_t nx() const { return mtmd_bitmap_get_nx(ptr.get()); }
|
||||
uint32_t ny() const { return mtmd_bitmap_get_ny(ptr.get()); }
|
||||
uint32_t nt() const { return mtmd_bitmap_get_nt(ptr.get()); }
|
||||
bool is_seq() const { return mtmd_bitmap_is_seq(ptr.get()); }
|
||||
uint32_t nx() const { return mtmd_bitmap_get_nx(ptr.get()); }
|
||||
uint32_t ny() const { return mtmd_bitmap_get_ny(ptr.get()); }
|
||||
const unsigned char * data() const { return mtmd_bitmap_get_data(ptr.get()); }
|
||||
size_t n_bytes() const { return mtmd_bitmap_get_n_bytes(ptr.get()); }
|
||||
std::string id() const { return mtmd_bitmap_get_id(ptr.get()); }
|
||||
|
||||
@@ -98,6 +98,7 @@ add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
|
||||
add_test_audio "ggml-org/Voxtral-Mini-3B-2507-GGUF:Q4_K_M"
|
||||
add_test_audio "ggml-org/LFM2-Audio-1.5B-GGUF:Q8_0"
|
||||
add_test_audio "ggml-org/gemma-4-E2B-it-GGUF:Q8_0" --jinja
|
||||
add_test_audio "ggml-org/Qwen3-ASR-0.6B-GGUF:Q8_0"
|
||||
|
||||
# to test the big models, run: ./tests.sh big
|
||||
if [ "$RUN_BIG_TESTS" = true ]; then
|
||||
|
||||
@@ -1433,6 +1433,60 @@ json convert_responses_to_chatcmpl(const json & response_body) {
|
||||
return chatcmpl_body;
|
||||
}
|
||||
|
||||
json convert_transcriptions_to_chatcmpl(
|
||||
const json & inp_body,
|
||||
const std::map<std::string, raw_buffer> & in_files,
|
||||
std::vector<raw_buffer> & out_files) {
|
||||
// TODO @ngxson : this function may need to be improved in the future
|
||||
// handle input files
|
||||
out_files.clear();
|
||||
auto it = in_files.find("file");
|
||||
if (it != in_files.end()) {
|
||||
out_files.push_back(it->second);
|
||||
} else {
|
||||
throw std::invalid_argument("No input file found for transcription");
|
||||
}
|
||||
|
||||
// handle input data
|
||||
std::string prompt = json_value(inp_body, "prompt", std::string());
|
||||
std::string language = json_value(inp_body, "language", std::string());
|
||||
std::string response_format = json_value(inp_body, "response_format", std::string("json"));
|
||||
if (response_format != "json") {
|
||||
throw std::invalid_argument("Only 'json' response_format is supported for transcription");
|
||||
}
|
||||
if (prompt.empty()) {
|
||||
prompt = "Transcribe audio to text";
|
||||
}
|
||||
if (!language.empty()) {
|
||||
prompt += string_format(" (language: %s)", language.c_str());
|
||||
}
|
||||
prompt += mtmd_default_marker();
|
||||
|
||||
json chatcmpl_body = inp_body; // copy all fields
|
||||
chatcmpl_body["messages"] = json::array({
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", prompt},
|
||||
},
|
||||
});
|
||||
|
||||
// because input from form-data, everything is string, we need to correct the types here
|
||||
std::string stream = json_value(inp_body, "stream", std::string("false"));
|
||||
chatcmpl_body["stream"] = stream == "true";
|
||||
|
||||
if (inp_body.contains("max_tokens")) {
|
||||
std::string inp = inp_body["max_tokens"].get<std::string>();
|
||||
chatcmpl_body["max_tokens"] = std::stoul(inp);
|
||||
}
|
||||
|
||||
if (inp_body.contains("temperature")) {
|
||||
std::string inp = inp_body["temperature"].get<std::string>();
|
||||
chatcmpl_body["temperature"] = std::stof(inp);
|
||||
}
|
||||
|
||||
return chatcmpl_body;
|
||||
}
|
||||
|
||||
json convert_anthropic_to_oai(const json & body) {
|
||||
json oai_body;
|
||||
|
||||
|
||||
@@ -305,6 +305,12 @@ json oaicompat_chat_params_parse(
|
||||
// convert OpenAI Responses API format to OpenAI Chat Completions API format
|
||||
json convert_responses_to_chatcmpl(const json & body);
|
||||
|
||||
// convert OpenAI transcriptions API format to OpenAI Chat Completions API format
|
||||
json convert_transcriptions_to_chatcmpl(
|
||||
const json & body,
|
||||
const std::map<std::string, raw_buffer> & in_files,
|
||||
std::vector<raw_buffer> & out_files);
|
||||
|
||||
// convert Anthropic Messages API format to OpenAI Chat Completions API format
|
||||
json convert_anthropic_to_oai(const json & body);
|
||||
|
||||
|
||||
@@ -3732,6 +3732,33 @@ void server_routes::init_routes() {
|
||||
TASK_RESPONSE_TYPE_OAI_RESP);
|
||||
};
|
||||
|
||||
this->post_transcriptions_oai = [this](const server_http_req & req) {
|
||||
auto res = create_response();
|
||||
|
||||
if (!meta->has_mtmd || !meta->chat_params.allow_audio) {
|
||||
res->error(format_error_response("The current model does not support audio input.", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<raw_buffer> files;
|
||||
json body = convert_transcriptions_to_chatcmpl(
|
||||
json::parse(req.body),
|
||||
req.files,
|
||||
files);
|
||||
SRV_DBG("%s\n", "Request converted: OpenAI Transcriptions -> OpenAI Chat Completions");
|
||||
SRV_DBG("converted request: %s\n", body.dump().c_str());
|
||||
json body_parsed = oaicompat_chat_params_parse(
|
||||
body,
|
||||
meta->chat_params,
|
||||
files);
|
||||
return handle_completions_impl(
|
||||
req,
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
body_parsed,
|
||||
files,
|
||||
TASK_RESPONSE_TYPE_OAI_ASR);
|
||||
};
|
||||
|
||||
this->post_anthropic_messages = [this](const server_http_req & req) {
|
||||
auto res = create_response();
|
||||
std::vector<raw_buffer> files;
|
||||
|
||||
@@ -111,6 +111,7 @@ struct server_routes {
|
||||
server_http_context::handler_t post_completions_oai;
|
||||
server_http_context::handler_t post_chat_completions;
|
||||
server_http_context::handler_t post_responses_oai;
|
||||
server_http_context::handler_t post_transcriptions_oai;
|
||||
server_http_context::handler_t post_anthropic_messages;
|
||||
server_http_context::handler_t post_anthropic_count_tokens;
|
||||
server_http_context::handler_t post_apply_template;
|
||||
|
||||
@@ -428,6 +428,7 @@ void server_http_context::get(const std::string & path, const server_http_contex
|
||||
req.path,
|
||||
build_query_string(req),
|
||||
req.body,
|
||||
{},
|
||||
req.is_connection_closed
|
||||
});
|
||||
server_http_res_ptr response = handler(*request);
|
||||
@@ -437,12 +438,39 @@ void server_http_context::get(const std::string & path, const server_http_contex
|
||||
|
||||
void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
|
||||
pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
|
||||
std::string body = req.body;
|
||||
std::map<std::string, raw_buffer> files;
|
||||
|
||||
if (req.is_multipart_form_data()) {
|
||||
// translate text fields to a JSON object and use it as the body
|
||||
json form_json = json::object();
|
||||
for (const auto & [key, field] : req.form.fields) {
|
||||
if (form_json.contains(key)) {
|
||||
// if the key already exists, convert it to an array
|
||||
if (!form_json[key].is_array()) {
|
||||
json existing_value = form_json[key];
|
||||
form_json[key] = json::array({existing_value});
|
||||
}
|
||||
form_json[key].push_back(field.content);
|
||||
} else {
|
||||
form_json[key] = field.content;
|
||||
}
|
||||
}
|
||||
body = form_json.dump();
|
||||
|
||||
// populate files from multipart form
|
||||
for (const auto & [key, file] : req.form.files) {
|
||||
files[key] = raw_buffer(file.content.begin(), file.content.end());
|
||||
}
|
||||
}
|
||||
|
||||
server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
|
||||
get_params(req),
|
||||
get_headers(req),
|
||||
req.path,
|
||||
build_query_string(req),
|
||||
req.body,
|
||||
body,
|
||||
std::move(files),
|
||||
req.is_connection_closed
|
||||
});
|
||||
server_http_res_ptr response = handler(*request);
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
|
||||
struct common_params;
|
||||
|
||||
@@ -32,6 +34,7 @@ struct server_http_res {
|
||||
// unique pointer, used by set_chunked_content_provider
|
||||
// httplib requires the stream provider to be stored in heap
|
||||
using server_http_res_ptr = std::unique_ptr<server_http_res>;
|
||||
using raw_buffer = std::vector<uint8_t>;
|
||||
|
||||
struct server_http_req {
|
||||
std::map<std::string, std::string> params; // path_params + query_params
|
||||
@@ -39,6 +42,7 @@ struct server_http_req {
|
||||
std::string path;
|
||||
std::string query_string; // query parameters string (e.g. "action=save")
|
||||
std::string body;
|
||||
std::map<std::string, raw_buffer> files; // used for file uploads (form data)
|
||||
const std::function<bool()> & should_stop;
|
||||
|
||||
std::string get_param(const std::string & key, const std::string & def = "") const {
|
||||
|
||||
@@ -725,6 +725,8 @@ json server_task_result_cmpl_final::to_json() {
|
||||
return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
|
||||
case TASK_RESPONSE_TYPE_OAI_RESP:
|
||||
return stream ? to_json_oaicompat_resp_stream() : to_json_oaicompat_resp();
|
||||
case TASK_RESPONSE_TYPE_OAI_ASR:
|
||||
return to_json_oaicompat_asr();
|
||||
case TASK_RESPONSE_TYPE_ANTHROPIC:
|
||||
return stream ? to_json_anthropic_stream() : to_json_anthropic();
|
||||
default:
|
||||
@@ -1102,6 +1104,21 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() {
|
||||
return server_sent_events;
|
||||
}
|
||||
|
||||
json server_task_result_cmpl_final::to_json_oaicompat_asr() {
|
||||
json event = json {
|
||||
{"type", "transcript.text.done"},
|
||||
{"text", content},
|
||||
{"usage", json {
|
||||
{"type", "tokens"},
|
||||
{"input_tokens", n_prompt_tokens},
|
||||
{"output_tokens", n_decoded},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens},
|
||||
{"input_tokens_details", json { {"cached_tokens", n_prompt_tokens_cache} }},
|
||||
}},
|
||||
};
|
||||
return event;
|
||||
}
|
||||
|
||||
json server_task_result_cmpl_final::to_json_anthropic() {
|
||||
std::string stop_reason = "max_tokens";
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
@@ -1400,6 +1417,8 @@ json server_task_result_cmpl_partial::to_json() {
|
||||
return to_json_oaicompat_chat();
|
||||
case TASK_RESPONSE_TYPE_OAI_RESP:
|
||||
return to_json_oaicompat_resp();
|
||||
case TASK_RESPONSE_TYPE_OAI_ASR:
|
||||
return to_json_oaicompat_asr();
|
||||
case TASK_RESPONSE_TYPE_ANTHROPIC:
|
||||
return to_json_anthropic();
|
||||
default:
|
||||
@@ -1650,6 +1669,14 @@ json server_task_result_cmpl_partial::to_json_oaicompat_resp() {
|
||||
return events;
|
||||
}
|
||||
|
||||
json server_task_result_cmpl_partial::to_json_oaicompat_asr() {
|
||||
json event = json {
|
||||
{"type", "transcript.text.delta"},
|
||||
{"delta", content},
|
||||
};
|
||||
return event;
|
||||
}
|
||||
|
||||
json server_task_result_cmpl_partial::to_json_anthropic() {
|
||||
json events = json::array();
|
||||
bool first = (n_decoded == 1);
|
||||
|
||||
@@ -34,6 +34,7 @@ enum task_response_type {
|
||||
TASK_RESPONSE_TYPE_OAI_CHAT,
|
||||
TASK_RESPONSE_TYPE_OAI_CMPL,
|
||||
TASK_RESPONSE_TYPE_OAI_RESP,
|
||||
TASK_RESPONSE_TYPE_OAI_ASR, // transcriptions API
|
||||
TASK_RESPONSE_TYPE_OAI_EMBD,
|
||||
TASK_RESPONSE_TYPE_ANTHROPIC,
|
||||
};
|
||||
@@ -401,6 +402,8 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
|
||||
json to_json_oaicompat_resp_stream();
|
||||
|
||||
json to_json_oaicompat_asr();
|
||||
|
||||
json to_json_anthropic();
|
||||
|
||||
json to_json_anthropic_stream();
|
||||
@@ -457,6 +460,8 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||
|
||||
json to_json_oaicompat_resp();
|
||||
|
||||
json to_json_oaicompat_asr();
|
||||
|
||||
json to_json_anthropic();
|
||||
};
|
||||
|
||||
|
||||
@@ -145,6 +145,7 @@ int main(int argc, char ** argv) {
|
||||
routes.post_completions_oai = models_routes->proxy_post;
|
||||
routes.post_chat_completions = models_routes->proxy_post;
|
||||
routes.post_responses_oai = models_routes->proxy_post;
|
||||
routes.post_transcriptions_oai = models_routes->proxy_post;
|
||||
routes.post_anthropic_messages = models_routes->proxy_post;
|
||||
routes.post_anthropic_count_tokens = models_routes->proxy_post;
|
||||
routes.post_infill = models_routes->proxy_post;
|
||||
@@ -160,48 +161,51 @@ int main(int argc, char ** argv) {
|
||||
routes.post_slots = models_routes->proxy_post;
|
||||
|
||||
// custom routes for router
|
||||
routes.get_props = models_routes->get_router_props;
|
||||
routes.get_models = models_routes->get_router_models;
|
||||
ctx_http.post("/models/load", ex_wrapper(models_routes->post_router_models_load));
|
||||
ctx_http.post("/models/unload", ex_wrapper(models_routes->post_router_models_unload));
|
||||
routes.get_props = models_routes->get_router_props;
|
||||
routes.get_models = models_routes->get_router_models;
|
||||
|
||||
ctx_http.post("/models/load", ex_wrapper(models_routes->post_router_models_load));
|
||||
ctx_http.post("/models/unload", ex_wrapper(models_routes->post_router_models_unload));
|
||||
}
|
||||
|
||||
ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
|
||||
ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
|
||||
ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics));
|
||||
ctx_http.get ("/props", ex_wrapper(routes.get_props));
|
||||
ctx_http.post("/props", ex_wrapper(routes.post_props));
|
||||
ctx_http.post("/api/show", ex_wrapper(routes.get_api_show));
|
||||
ctx_http.get ("/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check)
|
||||
ctx_http.get ("/v1/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check)
|
||||
ctx_http.get ("/api/tags", ex_wrapper(routes.get_models)); // ollama specific endpoint. public endpoint (no API key check)
|
||||
ctx_http.post("/completion", ex_wrapper(routes.post_completions)); // legacy
|
||||
ctx_http.post("/completions", ex_wrapper(routes.post_completions));
|
||||
ctx_http.post("/v1/completions", ex_wrapper(routes.post_completions_oai));
|
||||
ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions));
|
||||
ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions));
|
||||
ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint
|
||||
ctx_http.post("/v1/responses", ex_wrapper(routes.post_responses_oai));
|
||||
ctx_http.post("/responses", ex_wrapper(routes.post_responses_oai));
|
||||
ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API
|
||||
ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
|
||||
ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
|
||||
ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics));
|
||||
ctx_http.get ("/props", ex_wrapper(routes.get_props));
|
||||
ctx_http.post("/props", ex_wrapper(routes.post_props));
|
||||
ctx_http.post("/api/show", ex_wrapper(routes.get_api_show));
|
||||
ctx_http.get ("/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check)
|
||||
ctx_http.get ("/v1/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check)
|
||||
ctx_http.get ("/api/tags", ex_wrapper(routes.get_models)); // ollama specific endpoint. public endpoint (no API key check)
|
||||
ctx_http.post("/completion", ex_wrapper(routes.post_completions)); // legacy
|
||||
ctx_http.post("/completions", ex_wrapper(routes.post_completions));
|
||||
ctx_http.post("/v1/completions", ex_wrapper(routes.post_completions_oai));
|
||||
ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions));
|
||||
ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions));
|
||||
ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint
|
||||
ctx_http.post("/v1/responses", ex_wrapper(routes.post_responses_oai));
|
||||
ctx_http.post("/responses", ex_wrapper(routes.post_responses_oai));
|
||||
ctx_http.post("/v1/audio/transcriptions", ex_wrapper(routes.post_transcriptions_oai));
|
||||
ctx_http.post("/audio/transcriptions", ex_wrapper(routes.post_transcriptions_oai));
|
||||
ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API
|
||||
ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting
|
||||
ctx_http.post("/infill", ex_wrapper(routes.post_infill));
|
||||
ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy
|
||||
ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings));
|
||||
ctx_http.post("/v1/embeddings", ex_wrapper(routes.post_embeddings_oai));
|
||||
ctx_http.post("/rerank", ex_wrapper(routes.post_rerank));
|
||||
ctx_http.post("/reranking", ex_wrapper(routes.post_rerank));
|
||||
ctx_http.post("/v1/rerank", ex_wrapper(routes.post_rerank));
|
||||
ctx_http.post("/v1/reranking", ex_wrapper(routes.post_rerank));
|
||||
ctx_http.post("/tokenize", ex_wrapper(routes.post_tokenize));
|
||||
ctx_http.post("/detokenize", ex_wrapper(routes.post_detokenize));
|
||||
ctx_http.post("/apply-template", ex_wrapper(routes.post_apply_template));
|
||||
ctx_http.post("/infill", ex_wrapper(routes.post_infill));
|
||||
ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy
|
||||
ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings));
|
||||
ctx_http.post("/v1/embeddings", ex_wrapper(routes.post_embeddings_oai));
|
||||
ctx_http.post("/rerank", ex_wrapper(routes.post_rerank));
|
||||
ctx_http.post("/reranking", ex_wrapper(routes.post_rerank));
|
||||
ctx_http.post("/v1/rerank", ex_wrapper(routes.post_rerank));
|
||||
ctx_http.post("/v1/reranking", ex_wrapper(routes.post_rerank));
|
||||
ctx_http.post("/tokenize", ex_wrapper(routes.post_tokenize));
|
||||
ctx_http.post("/detokenize", ex_wrapper(routes.post_detokenize));
|
||||
ctx_http.post("/apply-template", ex_wrapper(routes.post_apply_template));
|
||||
// LoRA adapters hotswap
|
||||
ctx_http.get ("/lora-adapters", ex_wrapper(routes.get_lora_adapters));
|
||||
ctx_http.post("/lora-adapters", ex_wrapper(routes.post_lora_adapters));
|
||||
ctx_http.get ("/lora-adapters", ex_wrapper(routes.get_lora_adapters));
|
||||
ctx_http.post("/lora-adapters", ex_wrapper(routes.post_lora_adapters));
|
||||
// Save & load slots
|
||||
ctx_http.get ("/slots", ex_wrapper(routes.get_slots));
|
||||
ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots));
|
||||
ctx_http.get ("/slots", ex_wrapper(routes.get_slots));
|
||||
ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots));
|
||||
// CORS proxy (EXPERIMENTAL, only used by the Web UI for MCP)
|
||||
if (params.webui_mcp_proxy) {
|
||||
SRV_WRN("%s", "-----------------\n");
|
||||
|
||||
Reference in New Issue
Block a user