mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-16 16:27:32 +03:00
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e21cdc11a0 | ||
|
|
e974923698 | ||
|
|
1c0d9081fd | ||
|
|
a8bad3842e | ||
|
|
75f3bc94e6 | ||
|
|
aa00911d12 | ||
|
|
ce8fd4b1a6 | ||
|
|
9f5e1edb10 | ||
|
|
920b3e78cb | ||
|
|
974c8c94cc | ||
|
|
227ed28e12 | ||
|
|
bafae27654 | ||
|
|
873c825611 | ||
|
|
82764d8f40 | ||
|
|
21a4933042 | ||
|
|
1e9d771e2c | ||
|
|
aa4695c5e5 | ||
|
|
547765a93e | ||
|
|
9e209c5aee | ||
|
|
6313acbef0 |
2
.github/workflows/close-issue.yml
vendored
2
.github/workflows/close-issue.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap"
|
||||
exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap,security"
|
||||
days-before-issue-stale: 30
|
||||
days-before-issue-close: 14
|
||||
stale-issue-label: "stale"
|
||||
|
||||
197
common/chat.cpp
197
common/chat.cpp
@@ -1091,6 +1091,14 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
|
||||
if (inputs.add_generation_prompt && string_ends_with(data.prompt, "<turn|>\n")) {
|
||||
// This may happen if the model generates content + tool_call, the
|
||||
// template does not add the model's next turn and confuses the model
|
||||
// from emitting its proper reasoning token sequence.
|
||||
data.prompt += "<|turn>model\n";
|
||||
}
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "<|channel>thought";
|
||||
@@ -1118,7 +1126,8 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
p.rule("thought", p.content(p.literal("<|channel>thought") + p.space() + p.until("<channel|>") + p.literal("<channel|>")));
|
||||
}
|
||||
|
||||
auto thought = (p.peek(p.literal("<|channel>")) + p.ref("thought")) | p.negate(p.literal("<|channel>"));
|
||||
auto consume_empty_channels = p.gbnf(p.zero_or_more(p.literal("<|channel>") + p.negate(p.literal("thought"))), "");
|
||||
auto thought = (p.peek(p.literal("<|channel>")) + consume_empty_channels + p.ref("thought")) | p.negate(p.literal("<|channel>"));
|
||||
|
||||
if (has_response_format) {
|
||||
auto response_format = p.literal("```json") <<
|
||||
@@ -1182,12 +1191,16 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
/* max = */ inputs.parallel_tool_calls ? -1 : 1
|
||||
));
|
||||
|
||||
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<|tool_call>"})));
|
||||
auto scan_to_toolcall = p.rule("scan-to-toolcall", p.until("<|tool_call>"));
|
||||
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<channel|>", "<|tool_call>"})));
|
||||
auto message = p.rule("message", thought + content);
|
||||
return start + p.zero_or_more(message) + tool_call;
|
||||
return start + p.zero_or_more(message) + scan_to_toolcall + tool_call;
|
||||
}
|
||||
|
||||
auto content = p.rule("content", p.content(p.until("<|channel>")));
|
||||
// Gemma 4 may emit an extra <|channel>thought\n<channel|> at the end of the content. It may
|
||||
// also emit a single trailing <channel|> token. Consume all complete reasoning blocks and
|
||||
// then stop at the first unmatched <channel|> token.
|
||||
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<channel|>"})));
|
||||
auto message = p.rule("message", thought + content);
|
||||
return start + p.one_or_more(message);
|
||||
});
|
||||
@@ -1656,6 +1669,173 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_deepseek_v3_2(const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "<think>";
|
||||
data.thinking_end_tag = "</think>";
|
||||
data.preserved_tokens = {
|
||||
"|DSML|",
|
||||
"<think>",
|
||||
"</think>",
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
|
||||
|
||||
const std::string DSML = "|DSML|";
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
const std::string FC_START = "<" + DSML + "function_calls>";
|
||||
const std::string FC_END = "</" + DSML + "function_calls>";
|
||||
const std::string INVOKE_START = "<" + DSML + "invoke";
|
||||
const std::string INVOKE_END = "</" + DSML + "invoke>";
|
||||
const std::string PARAM_START = "<" + DSML + "parameter";
|
||||
const std::string PARAM_END = "</" + DSML + "parameter>";
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
|
||||
auto end = p.end();
|
||||
|
||||
auto reasoning = p.eps();
|
||||
if (extract_reasoning && inputs.enable_thinking) {
|
||||
reasoning = p.optional(THINK_START + p.reasoning(p.until(THINK_END)) + THINK_END);
|
||||
} else if (extract_reasoning) {
|
||||
// Thinking disabled but reasoning extraction requested: the generation prompt
|
||||
// contains an empty <think></think> pair that must still be consumed.
|
||||
reasoning = p.optional(p.literal(THINK_START) + p.until(THINK_END) + p.literal(THINK_END));
|
||||
}
|
||||
|
||||
if (has_response_format) {
|
||||
auto response_format = p.rule("response-format",
|
||||
p.literal("```json") + p.space() +
|
||||
p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)) +
|
||||
p.space() + p.literal("```"));
|
||||
return generation_prompt + reasoning + response_format + end;
|
||||
}
|
||||
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return generation_prompt + reasoning + p.content(p.rest()) + end;
|
||||
}
|
||||
|
||||
auto tool_choice = p.choice();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto params = function.contains("parameters") ? function.at("parameters") : json::object();
|
||||
const auto & props = params.contains("properties") ? params.at("properties") : json::object();
|
||||
|
||||
std::set<std::string> required;
|
||||
if (params.contains("required")) {
|
||||
params.at("required").get_to(required);
|
||||
}
|
||||
|
||||
auto schema_info = common_schema_info();
|
||||
schema_info.resolve_refs(params);
|
||||
|
||||
std::vector<common_peg_parser> required_parsers;
|
||||
std::vector<common_peg_parser> optional_parsers;
|
||||
for (const auto & [param_name, param_schema] : props.items()) {
|
||||
bool is_required = required.find(param_name) != required.end();
|
||||
bool is_string = schema_info.resolves_to_string(param_schema);
|
||||
|
||||
auto arg = p.tool_arg(
|
||||
p.tool_arg_open(
|
||||
p.literal(PARAM_START + " name=\"") +
|
||||
p.tool_arg_name(p.literal(param_name)) +
|
||||
p.literal("\" string=\"" + std::string(is_string ? "true" : "false") + "\">")) +
|
||||
(is_string
|
||||
? p.tool_arg_string_value(p.until(PARAM_END))
|
||||
: p.tool_arg_json_value(p.schema(p.json(),
|
||||
"tool-" + name + "-arg-" + param_name + "-schema",
|
||||
param_schema, false))) +
|
||||
p.tool_arg_close(p.literal(PARAM_END)));
|
||||
|
||||
auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg);
|
||||
if (is_required) {
|
||||
required_parsers.push_back(named_arg);
|
||||
} else {
|
||||
optional_parsers.push_back(named_arg);
|
||||
}
|
||||
}
|
||||
|
||||
common_peg_parser args_seq = p.eps();
|
||||
for (size_t i = 0; i < required_parsers.size(); i++) {
|
||||
if (i > 0) {
|
||||
args_seq = args_seq + p.space();
|
||||
}
|
||||
args_seq = args_seq + required_parsers[i];
|
||||
}
|
||||
|
||||
if (!optional_parsers.empty()) {
|
||||
common_peg_parser any_opt = p.choice();
|
||||
for (const auto & opt : optional_parsers) {
|
||||
any_opt |= opt;
|
||||
}
|
||||
args_seq = args_seq + p.repeat(p.space() + any_opt, 0, -1);
|
||||
}
|
||||
|
||||
common_peg_parser invoke_body = args_seq;
|
||||
auto func_parser = p.tool(
|
||||
p.tool_open(p.literal(INVOKE_START + " name=\"") +
|
||||
p.tool_name(p.literal(name)) + p.literal("\">\n")) +
|
||||
invoke_body + p.space() +
|
||||
p.tool_close(p.literal(INVOKE_END)));
|
||||
|
||||
tool_choice |= p.rule("tool-" + name, func_parser);
|
||||
});
|
||||
|
||||
auto require_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
common_peg_parser tool_calls = p.eps();
|
||||
if (inputs.parallel_tool_calls) {
|
||||
tool_calls = p.trigger_rule("tool-call",
|
||||
p.literal(FC_START) + p.space() + tool_choice +
|
||||
p.zero_or_more(p.space() + tool_choice) + p.space() + p.literal(FC_END));
|
||||
} else {
|
||||
tool_calls = p.trigger_rule("tool-call",
|
||||
p.literal(FC_START) + p.space() + tool_choice + p.space() + p.literal(FC_END));
|
||||
}
|
||||
|
||||
if (!require_tools) {
|
||||
tool_calls = p.optional(tool_calls);
|
||||
}
|
||||
|
||||
auto content_before_tools = p.content(p.until(FC_START));
|
||||
return generation_prompt + reasoning + content_before_tools + tool_calls + end;
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
auto schema = function.contains("parameters") ? function.at("parameters") : json::object();
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
if (has_response_format) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
}
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
data.grammar_triggers = {
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, FC_START },
|
||||
};
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
namespace workaround {
|
||||
|
||||
static void map_developer_role_to_system(json & messages) {
|
||||
@@ -1927,6 +2107,15 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
return common_chat_params_init_gigachat_v3(tmpl, params);
|
||||
}
|
||||
|
||||
// DeepSeek V3.2 format detection: template defines dsml_token and uses it for tool calls.
|
||||
// The template source contains the token as a variable assignment, not as a literal in markup.
|
||||
if (src.find("dsml_token") != std::string::npos &&
|
||||
src.find("function_calls") != std::string::npos &&
|
||||
src.find("DSML") != std::string::npos) {
|
||||
LOG_DBG("Using specialized template: DeepSeek V3.2\n");
|
||||
return common_chat_params_init_deepseek_v3_2(tmpl, params);
|
||||
}
|
||||
|
||||
// Gemma4 format detection
|
||||
if (src.find("'<|tool_call>call:'") != std::string::npos) {
|
||||
if (src.find("{#- OpenAI Chat Completions:") == std::string::npos) {
|
||||
|
||||
@@ -258,6 +258,9 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
if (progress_step >= p.total / 1000 || p.downloaded == p.total) {
|
||||
if (callback) {
|
||||
callback->on_update(p);
|
||||
if (callback->is_cancelled()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
progress_step = 0;
|
||||
}
|
||||
@@ -373,6 +376,9 @@ static int common_download_file_single_online(const std::string & url,
|
||||
}
|
||||
|
||||
for (int i = 0; i < max_attempts; ++i) {
|
||||
if (opts.callback && opts.callback->is_cancelled()) {
|
||||
break;
|
||||
}
|
||||
if (i) {
|
||||
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
|
||||
std::this_thread::sleep_for(std::chrono::seconds(delay));
|
||||
@@ -412,6 +418,12 @@ static int common_download_file_single_online(const std::string & url,
|
||||
if (opts.callback) {
|
||||
opts.callback->on_done(p, success);
|
||||
}
|
||||
if (opts.callback && opts.callback->is_cancelled() &&
|
||||
std::filesystem::exists(path_temporary)) {
|
||||
if (remove(path_temporary.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, path_temporary.c_str());
|
||||
}
|
||||
}
|
||||
if (!success) {
|
||||
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
|
||||
return -1; // max attempts reached
|
||||
|
||||
@@ -21,6 +21,7 @@ public:
|
||||
virtual void on_start(const common_download_progress & p) = 0;
|
||||
virtual void on_update(const common_download_progress & p) = 0;
|
||||
virtual void on_done(const common_download_progress & p, bool ok) = 0;
|
||||
virtual bool is_cancelled() const { return false; }
|
||||
};
|
||||
|
||||
struct common_remote_params {
|
||||
|
||||
@@ -890,6 +890,10 @@ struct parser_executor {
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_gbnf_parser & p) {
|
||||
return arena.parse(p.child, ctx, start_pos);
|
||||
}
|
||||
};
|
||||
|
||||
common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const {
|
||||
@@ -957,7 +961,8 @@ void common_peg_arena::resolve_refs() {
|
||||
std::is_same_v<T, common_peg_and_parser> ||
|
||||
std::is_same_v<T, common_peg_not_parser> ||
|
||||
std::is_same_v<T, common_peg_tag_parser> ||
|
||||
std::is_same_v<T, common_peg_atomic_parser>) {
|
||||
std::is_same_v<T, common_peg_atomic_parser> ||
|
||||
std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
p.child = resolve_ref(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
p.child = resolve_ref(p.child);
|
||||
@@ -1036,6 +1041,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
|
||||
return "Not(" + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_atomic_parser>) {
|
||||
return "Atomic(" + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_any_parser>) {
|
||||
return "Any";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_space_parser>) {
|
||||
@@ -1565,6 +1572,7 @@ static std::unordered_set<std::string> collect_reachable_rules(
|
||||
std::is_same_v<T, common_peg_not_parser> ||
|
||||
std::is_same_v<T, common_peg_tag_parser> ||
|
||||
std::is_same_v<T, common_peg_atomic_parser> ||
|
||||
std::is_same_v<T, common_peg_gbnf_parser> ||
|
||||
std::is_same_v<T, common_peg_schema_parser>) {
|
||||
visit(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
@@ -1651,10 +1659,13 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
} else if constexpr (std::is_same_v<T, common_peg_sequence_parser>) {
|
||||
std::string s;
|
||||
for (const auto & child : p.children) {
|
||||
auto child_gbnf = to_gbnf(child);
|
||||
if (child_gbnf.empty()) {
|
||||
continue;
|
||||
}
|
||||
if (!s.empty()) {
|
||||
s += " ";
|
||||
}
|
||||
auto child_gbnf = to_gbnf(child);
|
||||
const auto & child_parser = effective_parser(child);
|
||||
if (std::holds_alternative<common_peg_choice_parser>(child_parser) ||
|
||||
std::holds_alternative<common_peg_sequence_parser>(child_parser)) {
|
||||
@@ -1754,6 +1765,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
return to_gbnf(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_atomic_parser>) {
|
||||
return to_gbnf(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return p.grammar;
|
||||
} else {
|
||||
static_assert(is_always_false_v<T>);
|
||||
}
|
||||
@@ -1888,6 +1901,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
|
||||
{"child", p.child},
|
||||
{"tag", p.tag}
|
||||
};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}};
|
||||
}
|
||||
}, variant);
|
||||
}
|
||||
@@ -2050,6 +2065,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
|
||||
};
|
||||
}
|
||||
|
||||
if (type == "gbnf") {
|
||||
if (!j.contains("child") || !j.contains("grammar")) {
|
||||
throw std::runtime_error("gbnf parser missing required fields");
|
||||
}
|
||||
return common_peg_gbnf_parser{
|
||||
j["child"].get<common_peg_parser_id>(),
|
||||
j["grammar"].get<std::string>(),
|
||||
};
|
||||
}
|
||||
|
||||
throw std::runtime_error("Unknown parser type: " + type);
|
||||
}
|
||||
|
||||
|
||||
@@ -270,6 +270,11 @@ struct common_peg_tag_parser {
|
||||
std::string tag;
|
||||
};
|
||||
|
||||
struct common_peg_gbnf_parser {
|
||||
common_peg_parser_id child;
|
||||
std::string grammar;
|
||||
};
|
||||
|
||||
// Variant holding all parser types
|
||||
using common_peg_parser_variant = std::variant<
|
||||
common_peg_epsilon_parser,
|
||||
@@ -290,7 +295,8 @@ using common_peg_parser_variant = std::variant<
|
||||
common_peg_rule_parser,
|
||||
common_peg_ref_parser,
|
||||
common_peg_atomic_parser,
|
||||
common_peg_tag_parser
|
||||
common_peg_tag_parser,
|
||||
common_peg_gbnf_parser
|
||||
>;
|
||||
|
||||
class common_peg_arena {
|
||||
@@ -504,6 +510,10 @@ class common_peg_parser_builder {
|
||||
// Unlike rules, you can tag multiple nodes with the same tag.
|
||||
common_peg_parser tag(const std::string & tag, const common_peg_parser & p) { return add(common_peg_tag_parser{p.id(), tag}); }
|
||||
|
||||
// Wraps a child parser but emits a custom GBNF grammar string instead of
|
||||
// the child's grammar. Parsing delegates entirely to the child.
|
||||
common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); }
|
||||
|
||||
void set_root(const common_peg_parser & p);
|
||||
|
||||
common_peg_arena build();
|
||||
|
||||
@@ -4258,9 +4258,7 @@ class Qwen2VLVisionModel(MmprojModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen2_5OmniModel")
|
||||
class Qwen25OmniModel(Qwen2VLVisionModel):
|
||||
has_vision_encoder = True
|
||||
class Qwen25AudioModel(MmprojModel):
|
||||
has_audio_encoder = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -4276,12 +4274,6 @@ class Qwen25OmniModel(Qwen2VLVisionModel):
|
||||
self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["num_mel_bins"])
|
||||
self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-5))
|
||||
|
||||
def get_vision_config(self) -> dict[str, Any] | None:
|
||||
return self.global_config["thinker_config"].get("vision_config")
|
||||
|
||||
def get_audio_config(self) -> dict[str, Any] | None:
|
||||
return self.global_config["thinker_config"].get("audio_config")
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
# SinusoidsPositionEmbedding
|
||||
assert self.hparams_audio is not None
|
||||
@@ -4312,7 +4304,32 @@ class Qwen25OmniModel(Qwen2VLVisionModel):
|
||||
# this tensor is left unused in transformers code
|
||||
# https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py#L1809
|
||||
return
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
yield from MmprojModel.modify_tensors(self, data_torch, name, bid)
|
||||
|
||||
return # skip other tensors
|
||||
|
||||
|
||||
@ModelBase.register("Qwen2_5OmniModel")
|
||||
class Qwen25OmniModel(Qwen2VLVisionModel, Qwen25AudioModel):
|
||||
has_audio_encoder = True
|
||||
has_vision_encoder = True
|
||||
|
||||
def get_vision_config(self) -> dict[str, Any] | None:
|
||||
return self.global_config["thinker_config"].get("vision_config")
|
||||
|
||||
def get_audio_config(self) -> dict[str, Any] | None:
|
||||
return self.global_config["thinker_config"].get("audio_config")
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25O)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if "visual." in name:
|
||||
yield from Qwen2VLVisionModel.modify_tensors(self, data_torch, name, bid)
|
||||
elif "audio_tower." in name:
|
||||
yield from Qwen25AudioModel.modify_tensors(self, data_torch, name, bid)
|
||||
return # skip other tensors
|
||||
|
||||
|
||||
@ModelBase.register("InternVisionModel")
|
||||
@@ -4816,7 +4833,10 @@ class RND1Model(Qwen2MoeModel):
|
||||
class Qwen3VLVisionModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.hparams_vision is not None
|
||||
if self.hparams_vision is None:
|
||||
logger.info("No vision config found, skipping vision tensor processing")
|
||||
return
|
||||
|
||||
# Compute image_size if not present
|
||||
if "image_size" not in self.hparams_vision:
|
||||
# For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings
|
||||
@@ -4837,7 +4857,9 @@ class Qwen3VLVisionModel(MmprojModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
|
||||
# in case mixed modalities, the arch will be handled by subclass
|
||||
if not self.has_audio_encoder:
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
|
||||
if self.hparams_vision is not None:
|
||||
@@ -4925,11 +4947,64 @@ class Qwen3VLVisionModel(MmprojModel):
|
||||
return
|
||||
|
||||
if name.startswith("visual."):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
return
|
||||
yield from MmprojModel.modify_tensors(self, data_torch, name, bid)
|
||||
return # skip other tensors
|
||||
|
||||
# Fall back to parent class for other tensors
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
@ModelBase.register("Qwen3OmniMoeForConditionalGeneration")
|
||||
class Qwen3OmniMmprojModel(Qwen3VLVisionModel, Qwen25AudioModel):
|
||||
has_audio_encoder = True
|
||||
has_vision_encoder = True
|
||||
|
||||
def get_vision_config(self) -> dict[str, Any] | None:
|
||||
if self.has_vision_encoder:
|
||||
return self.global_config["thinker_config"].get("vision_config")
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_audio_config(self) -> dict[str, Any] | None:
|
||||
if self.has_audio_encoder:
|
||||
return self.global_config["thinker_config"].get("audio_config")
|
||||
else:
|
||||
return None
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
if self.has_vision_encoder:
|
||||
Qwen3VLVisionModel.set_gguf_parameters(self)
|
||||
self.gguf_writer.add_clip_vision_projector_type(gguf.VisionProjectorType.QWEN3VL)
|
||||
if self.has_audio_encoder:
|
||||
Qwen25AudioModel.set_gguf_parameters(self)
|
||||
self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.QWEN3A)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if "visual." in name:
|
||||
if not self.has_vision_encoder:
|
||||
raise ValueError(f"Model does not have vision encoder, but found tensor {name}")
|
||||
# need to transform vision tensor naming, so that modify_tensors() logic can be used correctly
|
||||
name = name.replace("thinker.visual.", "model.visual.")
|
||||
if ".merger_list." in name:
|
||||
name = name.replace(".merger_list.", ".deepstack_merger_list.")
|
||||
name = name.replace(".ln_q", ".norm")
|
||||
name = name.replace(".mlp.0", ".linear_fc1")
|
||||
name = name.replace(".mlp.2", ".linear_fc2")
|
||||
elif ".merger." in name:
|
||||
name = name.replace(".ln_q", ".norm")
|
||||
name = name.replace(".mlp.0", ".linear_fc1")
|
||||
name = name.replace(".mlp.2", ".linear_fc2")
|
||||
yield from Qwen3VLVisionModel.modify_tensors(self, data_torch, name, bid)
|
||||
elif "audio_tower." in name:
|
||||
if not self.has_audio_encoder:
|
||||
raise ValueError(f"Model does not have audio encoder, but found tensor {name}")
|
||||
if "conv2d" in name and name.endswith(".bias"):
|
||||
# transform conv2d bias [n_embd] --> [1, 1, n_embd]
|
||||
data_torch = data_torch.unsqueeze(-1).unsqueeze(-1)
|
||||
yield from Qwen25AudioModel.modify_tensors(self, data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3ASRForConditionalGeneration")
|
||||
class Qwen3ASRMmprojModel(Qwen3OmniMmprojModel):
|
||||
has_audio_encoder = True
|
||||
has_vision_encoder = False
|
||||
|
||||
|
||||
@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration", "GlmOcrForConditionalGeneration")
|
||||
@@ -4992,6 +5067,8 @@ class Step3VLVisionModel(MmprojModel):
|
||||
def tensor_force_quant(self, name, new_name, bid, n_dims):
|
||||
if ".position_embd." in new_name:
|
||||
return gguf.GGMLQuantizationType.F32
|
||||
if ("mm.0." in new_name or "mm.1." in new_name) and new_name.endswith(".weight"):
|
||||
return gguf.GGMLQuantizationType.F16 if self.ftype == gguf.LlamaFileType.MOSTLY_F16 else gguf.GGMLQuantizationType.F32
|
||||
return super().tensor_force_quant(name, new_name, bid, n_dims)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
@@ -5030,9 +5107,10 @@ class Qwen3VLTextModel(Qwen3Model):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
|
||||
vision_config = self.hparams.get("vision_config", {})
|
||||
if "thinker_config" in self.hparams:
|
||||
vision_config = self.hparams["thinker_config"].get("vision_config", {})
|
||||
else:
|
||||
vision_config = self.hparams.get("vision_config", {})
|
||||
deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
|
||||
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
|
||||
|
||||
@@ -5101,6 +5179,70 @@ class Qwen3VLMoeTextModel(Qwen3MoeModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3OmniMoeForConditionalGeneration")
|
||||
class Qwen3OmniMoeTextModel(Qwen3VLMoeTextModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3VLMOE
|
||||
|
||||
def set_vocab(self):
|
||||
super().set_vocab()
|
||||
# correct BOS/EOS tokens
|
||||
with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f:
|
||||
tokenizer_config = json.load(f)
|
||||
added_tokens = tokenizer_config.get("added_tokens_decoder", {})
|
||||
for token_id, data in added_tokens.items():
|
||||
if data.get("content") == "<|im_end|>":
|
||||
self.gguf_writer.add_bos_token_id(int(token_id))
|
||||
self.gguf_writer.add_eos_token_id(int(token_id))
|
||||
break
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_num_deepstack_layers(0)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# Skip vision and audio tensors - they go in the mmproj file
|
||||
if "visual." in name or "audio_tower." in name \
|
||||
or "talker." in name or "code2wav." in name:
|
||||
return
|
||||
|
||||
name = name.replace("thinker.", "")
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3ASRForConditionalGeneration")
|
||||
class Qwen3ASRTextModel(Qwen3VLTextModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3VL
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_num_deepstack_layers(0)
|
||||
|
||||
def set_vocab(self):
|
||||
super().set_vocab()
|
||||
# fix chat template, use correct chatml format
|
||||
self.gguf_writer.add_chat_template("{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}")
|
||||
# correct BOS/EOS tokens
|
||||
with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f:
|
||||
tokenizer_config = json.load(f)
|
||||
added_tokens = tokenizer_config.get("added_tokens_decoder", {})
|
||||
for token_id, data in added_tokens.items():
|
||||
if data.get("content") == "<|im_end|>":
|
||||
self.gguf_writer.add_bos_token_id(int(token_id))
|
||||
self.gguf_writer.add_eos_token_id(int(token_id))
|
||||
break
|
||||
|
||||
def modify_tensors(self, data_torch, name, bid):
|
||||
# qwen3-omni
|
||||
name = name.replace("thinker.", "")
|
||||
|
||||
# Skip vision and audio tensors - they go in the mmproj file
|
||||
if "visual." in name or "audio_tower." in name \
|
||||
or "talker." in name or "code2wav." in name:
|
||||
return
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
class _LinearAttentionVReorderBase(Qwen3NextModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3NEXT # overridden by subclasses
|
||||
"""reorders V heads from grouped to tiled order for ggml broadcast
|
||||
|
||||
@@ -5,6 +5,7 @@ Adding a model requires few steps:
|
||||
1. Convert the model to GGUF
|
||||
2. Define the model architecture in `llama.cpp`
|
||||
3. Build the GGML graph implementation
|
||||
4. Optional: Add multimodal encoder implementation
|
||||
|
||||
After following these steps, you can open PR.
|
||||
|
||||
@@ -114,6 +115,21 @@ Some `ggml` backends do not support all operations. Backend implementations can
|
||||
|
||||
Note: to debug the inference graph: you can use [llama-eval-callback](/examples/eval-callback/).
|
||||
|
||||
### 4. Optional: Add multimodal encoder implementation
|
||||
|
||||
If the new model supports multimodal inputs, you will need to add a new encoder definition in `libmtmd`. You can find more information about llama.cpp's multimodal support in [the docs](../multimodal.md) and in the `tools/mtmd` source directory.
|
||||
|
||||
1. In the conversion script, make sure you add a subclass that extends `MmprojModel` or another class that inherits from the same base class.
|
||||
2. Add the encoder definition in `clip.cpp`.
|
||||
3. Implement the preprocessor in `mtmd.cpp`. In most cases, you can reuse an existing preprocessor.
|
||||
4. Implement the encoder GGML graph, either in a dedicated file if the model is truly different from existing ones, or by reusing an existing implementation (for example: siglip, pixtral, or qwen) and adding a model-specific projector.
|
||||
|
||||
Note:
|
||||
- Many multimodal encoders are based on models that are already supported. Make sure to read the existing encoder definitions in `tools/mtmd/models` before adding a new one. In `libmtmd`, it is generally better to extend an existing model than to duplicate code.
|
||||
- To debug the multimodal preprocessor and encoder, you can use [llama-mtmd-debug](tools/mtmd/debug/mtmd-debug.cpp).
|
||||
- Adding a model-specific API or CLI is an anti-pattern in `libmtmd`. The goal of `libmtmd` is to provide an easy-to-use, model-agnostic library for multimodal pipeline.
|
||||
- In most cases, `llama-mtmd-cli` should not be modified. If a model requires a specific prompt, either let the user provide it or bake it into the Jinja chat template.
|
||||
|
||||
## GGUF specification
|
||||
|
||||
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md
|
||||
|
||||
@@ -94,6 +94,11 @@ NOTE: some models may require large context window, for example: `-c 8192`
|
||||
# Moondream2 20250414 version
|
||||
(tool_name) -hf ggml-org/moondream2-20250414-GGUF
|
||||
|
||||
# Gemma 4
|
||||
(tool_name) -hf ggml-org/gemma-4-E2B-it-GGUF
|
||||
(tool_name) -hf ggml-org/gemma-4-E4B-it-GGUF
|
||||
(tool_name) -hf ggml-org/gemma-4-26B-A4B-it-GGUF
|
||||
(tool_name) -hf ggml-org/gemma-4-31B-it-GGUF
|
||||
```
|
||||
|
||||
**Audio models**:
|
||||
@@ -109,6 +114,10 @@ NOTE: some models may require large context window, for example: `-c 8192`
|
||||
|
||||
# Mistral's Voxtral
|
||||
(tool_name) -hf ggml-org/Voxtral-Mini-3B-2507-GGUF
|
||||
|
||||
# Qwen3-ASR
|
||||
(tool_name) -hf ggml-org/Qwen3-ASR-0.6B-GGUF
|
||||
(tool_name) -hf ggml-org/Qwen3-ASR-1.7B-GGUF
|
||||
```
|
||||
|
||||
**Mixed modalities**:
|
||||
@@ -118,6 +127,16 @@ NOTE: some models may require large context window, for example: `-c 8192`
|
||||
# Capabilities: audio input, vision input
|
||||
(tool_name) -hf ggml-org/Qwen2.5-Omni-3B-GGUF
|
||||
(tool_name) -hf ggml-org/Qwen2.5-Omni-7B-GGUF
|
||||
|
||||
# Qwen3 Omni
|
||||
# Capabilities: audio input, vision input
|
||||
(tool_name) -hf ggml-org/Qwen3-Omni-30B-A3B-Instruct-GGUF
|
||||
(tool_name) -hf ggml-org/Qwen3-Omni-30B-A3B-Thinking-GGUF
|
||||
|
||||
# Gemma 4
|
||||
# Capabilities: audio input, vision input
|
||||
(tool_name) -hf ggml-org/gemma-4-E2B-it-GGUF
|
||||
(tool_name) -hf ggml-org/gemma-4-E4B-it-GGUF
|
||||
```
|
||||
|
||||
## Finding more models:
|
||||
|
||||
@@ -58,26 +58,48 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
|
||||
size_t temp_storage_bytes = 0;
|
||||
|
||||
bool is_capturing = false;
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
// Currently (confirmed for CCCL <= 3.2) DeviceSegmentedSort does not support stream capture, while DeviceSegmentedRadixSort does.
|
||||
// See https://github.com/NVIDIA/cccl/issues/5661#issuecomment-3229037149
|
||||
// TODO: constrain this to the CCCL versions that have this issue once it's resolved in a future CCCL release.
|
||||
cudaStreamCaptureStatus capture_status;
|
||||
CUDA_CHECK(cudaStreamIsCapturing(stream, &capture_status));
|
||||
is_capturing = (capture_status != cudaStreamCaptureStatusNone);
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
if (nrows == 1) {
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else if (is_capturing) {
|
||||
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs(
|
||||
nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
offset_iterator, offset_iterator + 1, stream));
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys,
|
||||
temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
offset_iterator, offset_iterator + 1, stream));
|
||||
}
|
||||
} else {
|
||||
if (nrows == 1) {
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys,
|
||||
temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else if (is_capturing) {
|
||||
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending(
|
||||
nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows,
|
||||
offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
||||
dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
|
||||
stream));
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows,
|
||||
offset_iterator, offset_iterator + 1, stream));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,22 +108,33 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
if (nrows == 1) {
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys,
|
||||
temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else if (is_capturing) {
|
||||
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
||||
offset_iterator + 1, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
||||
ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream));
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
||||
offset_iterator + 1, stream));
|
||||
}
|
||||
} else {
|
||||
if (nrows == 1) {
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys,
|
||||
temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else if (is_capturing) {
|
||||
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending(
|
||||
d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows,
|
||||
offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
||||
offset_iterator + 1, stream));
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys,
|
||||
temp_keys, temp_indices, dst, ncols * nrows, nrows,
|
||||
offset_iterator, offset_iterator + 1, stream));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,8 +134,9 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
|
||||
switch (nc) {
|
||||
case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
|
||||
case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
|
||||
case 5: launch_kernel(std::integral_constant<int, 5>{}); break;
|
||||
case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
|
||||
default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now.");
|
||||
default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9 right now.");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -488,7 +488,7 @@ static void dequantize_row_nvfp4_sycl(const void * vx, dst_t * y, const int64_t
|
||||
const int nb = k / QK_NVFP4;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
dequantize_block_nvfp4(vx, y, k);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#define GGML_SYCL_DEQUANTIZE_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
#include "convert.hpp"
|
||||
|
||||
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
|
||||
typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs,
|
||||
|
||||
@@ -355,7 +355,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
||||
const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -176,14 +176,12 @@ static void launch_gated_delta_net(const float * q_d,
|
||||
const sycl::uint3 neqk1_magic = init_fastdiv_values(neqk1);
|
||||
const sycl::uint3 rq3_magic = init_fastdiv_values(rq3);
|
||||
|
||||
int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc;
|
||||
|
||||
switch (S_v) {
|
||||
case 16:
|
||||
{
|
||||
constexpr int sv = 16;
|
||||
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
|
||||
n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
|
||||
sb3, neqk1_magic, rq3_magic, scale);
|
||||
@@ -194,7 +192,7 @@ static void launch_gated_delta_net(const float * q_d,
|
||||
{
|
||||
constexpr int sv = 32;
|
||||
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
|
||||
n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
|
||||
sb3, neqk1_magic, rq3_magic, scale);
|
||||
@@ -205,7 +203,7 @@ static void launch_gated_delta_net(const float * q_d,
|
||||
{
|
||||
constexpr int sv = 64;
|
||||
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_delta_net_sycl<sv, KDA>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
|
||||
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
@@ -217,7 +215,7 @@ static void launch_gated_delta_net(const float * q_d,
|
||||
{
|
||||
constexpr int sv = 128;
|
||||
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_delta_net_sycl<sv, KDA>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
|
||||
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
|
||||
@@ -4727,12 +4727,19 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
struct ggml_tensor * a = op->src[0];
|
||||
struct ggml_tensor * b = op->src[1];
|
||||
|
||||
// disable Q1_0 until implementation
|
||||
if (a->type == GGML_TYPE_Q1_0 || b->type == GGML_TYPE_Q1_0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (a->ne[3] != b->ne[3]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
|
||||
|
||||
|
||||
// TODO: The configuration below needs more work to be supported with oneDNN
|
||||
if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&
|
||||
a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) {
|
||||
|
||||
@@ -272,7 +272,7 @@ static void upscale_f32_sycl(const float * x,
|
||||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
|
||||
});
|
||||
}
|
||||
@@ -304,7 +304,7 @@ static void upscale_f32_bilinear_sycl(const float * x,
|
||||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
upscale_f32_bilinear_antialias(
|
||||
x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst,
|
||||
ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
@@ -314,7 +314,7 @@ static void upscale_f32_bilinear_sycl(const float * x,
|
||||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
upscale_f32_bilinear(
|
||||
x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst,
|
||||
ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
@@ -349,7 +349,7 @@ static void upscale_f32_bicubic_sycl(const float * x,
|
||||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
upscale_f32_bicubic(
|
||||
x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst,
|
||||
ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
|
||||
@@ -2858,11 +2858,10 @@ struct vk_fa_tuning_params {
|
||||
}
|
||||
};
|
||||
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type);
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
|
||||
|
||||
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
||||
GGML_UNUSED(kv_type);
|
||||
|
||||
vk_fa_tuning_params result{};
|
||||
result.path = FA_SCALAR;
|
||||
@@ -2914,7 +2913,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
|
||||
|
||||
result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
|
||||
|
||||
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {
|
||||
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) {
|
||||
result.block_rows /= 2;
|
||||
}
|
||||
|
||||
@@ -3445,21 +3444,47 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
if (device->fp16) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product && device->subgroup_clustered) {
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8)
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
|
||||
}
|
||||
} else {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product && device->subgroup_clustered) {
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8)
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
|
||||
}
|
||||
}
|
||||
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (device->coopmat1_fa_support) {
|
||||
@@ -8780,7 +8805,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
|
||||
GGML_UNUSED(f32acc);
|
||||
// Needs to be kept up to date on shader changes
|
||||
const uint32_t wg_size = params.workgroup_size;
|
||||
@@ -8789,21 +8814,51 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
|
||||
|
||||
const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
||||
|
||||
const bool mmq = device->integer_dot_product && device->subgroup_clustered &&
|
||||
(kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 ||
|
||||
kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 ||
|
||||
kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL);
|
||||
|
||||
// tmpsh is overestimated slightly
|
||||
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||
const uint32_t tmpshv4 = wg_size * 4 * float_type_size;
|
||||
|
||||
const uint32_t masksh = Bc * (Br + 1) * float_type_size;
|
||||
|
||||
const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
|
||||
uint32_t Qf, kvsh, kblocksh_size;
|
||||
if (mmq) {
|
||||
// block_b_cache: int32_t qs[8] + FLOAT_TYPEV2 ds
|
||||
const uint32_t block_b_size = 8 * sizeof(int32_t) + 2 * float_type_size;
|
||||
Qf = Br * (hsk / 32) * block_b_size;
|
||||
|
||||
const uint32_t D = std::max(hsk, hsv);
|
||||
const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
||||
// kvsh uses D = HSV (K goes through kblocksh instead)
|
||||
kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
||||
|
||||
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
|
||||
// block_a_cache size depends on quant type
|
||||
uint32_t block_a_size;
|
||||
switch (kv_type) {
|
||||
case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break;
|
||||
case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break;
|
||||
case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break;
|
||||
case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break;
|
||||
default: block_a_size = 0; break;
|
||||
}
|
||||
kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size;
|
||||
} else {
|
||||
Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
|
||||
|
||||
const uint32_t D = std::max(hsk, hsv);
|
||||
kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
||||
|
||||
kblocksh_size = 0;
|
||||
}
|
||||
|
||||
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh + kblocksh_size;
|
||||
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", mmq=" << mmq << ", total_size=" << total_size << ", supported=" << supported);
|
||||
|
||||
return supported;
|
||||
}
|
||||
|
||||
@@ -10,6 +10,13 @@
|
||||
#extension GL_EXT_shader_subgroup_extended_types_float16 : require
|
||||
#endif
|
||||
|
||||
#ifdef MMQ
|
||||
#extension GL_EXT_integer_dot_product : require
|
||||
#extension GL_KHR_shader_subgroup_clustered : require
|
||||
|
||||
#include "mul_mmq_shmem_types.glsl"
|
||||
#endif
|
||||
|
||||
#extension GL_KHR_shader_subgroup_shuffle : enable
|
||||
#extension GL_KHR_shader_subgroup_vote : enable
|
||||
|
||||
@@ -41,15 +48,34 @@ shared FLOAT_TYPEV4 tmpshv4[tmpsh_size];
|
||||
const uint32_t masksh_stride = Br + 1;
|
||||
shared FLOAT_TYPE masksh[Bc * masksh_stride];
|
||||
|
||||
#ifndef MMQ
|
||||
const uint32_t qf_stride = HSK / 4 + 1;
|
||||
shared FLOAT_TYPEV4 Qf[Br * qf_stride];
|
||||
#else
|
||||
|
||||
const uint32_t qf_stride = HSK / 32;
|
||||
shared block_b_cache Qf[Br * qf_stride];
|
||||
#endif
|
||||
|
||||
#ifndef MMQ
|
||||
const uint32_t D = HSK > HSV ? HSK : HSV;
|
||||
#else
|
||||
const uint32_t D = HSV;
|
||||
#endif
|
||||
const uint32_t kvsh_stride = D / 4 + 1;
|
||||
shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];
|
||||
|
||||
#ifdef MMQ
|
||||
|
||||
shared block_a_cache kblocksh[SHMEM_STAGING != 0 ? Bc * qf_stride : 1];
|
||||
#endif
|
||||
|
||||
shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1];
|
||||
|
||||
#ifdef MMQ
|
||||
#include "flash_attn_mmq_funcs.glsl"
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
@@ -82,10 +108,39 @@ void main() {
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t r = (idx + tid) / (HSK / 4);
|
||||
if (r < Br && d < HSK / 4 &&
|
||||
i * Br + r < N) {
|
||||
const bool is_in_bounds = r < Br && d < HSK / 4 && i * Br + r < N;
|
||||
#ifndef MMQ
|
||||
if (is_in_bounds) {
|
||||
Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
|
||||
}
|
||||
#else
|
||||
const uint buf_ib = r * qf_stride + d / 8;
|
||||
const uint buf_iqs = d % 8;
|
||||
|
||||
FLOAT_TYPEV4 vals = is_in_bounds ? FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale) : FLOAT_TYPEV4(0.0f);
|
||||
const FLOAT_TYPEV4 abs_vals = abs(vals);
|
||||
|
||||
const FLOAT_TYPE thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
|
||||
const FLOAT_TYPE amax = subgroupClusteredMax(thread_max, 8);
|
||||
const FLOAT_TYPE qd = amax / FLOAT_TYPE(127.0);
|
||||
const FLOAT_TYPE qd_inv = qd != FLOAT_TYPE(0.0) ? FLOAT_TYPE(1.0) / qd : FLOAT_TYPE(0.0);
|
||||
vals = round(vals * qd_inv);
|
||||
|
||||
Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals));
|
||||
|
||||
#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
|
||||
if (buf_iqs == 0) {
|
||||
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
|
||||
}
|
||||
#else // Q4_0, Q4_1, Q5_0, Q5_1
|
||||
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
|
||||
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
|
||||
|
||||
if (buf_iqs == 0) {
|
||||
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
barrier();
|
||||
|
||||
@@ -195,6 +250,7 @@ void main() {
|
||||
|
||||
if (SHMEM_STAGING != 0) {
|
||||
barrier();
|
||||
#ifndef MMQ
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t c = (idx + tid) / (HSK / 4);
|
||||
@@ -214,9 +270,29 @@ void main() {
|
||||
kvsh[c * kvsh_stride + d] = K_Tf;
|
||||
}
|
||||
}
|
||||
#else // MMQ
|
||||
const uint ints_per_block = 8 / QUANT_R_MMQ;
|
||||
const uint quant_iters = Bc * HSK / 32 * ints_per_block;
|
||||
[[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) {
|
||||
const uint32_t iqs = (idx + tid) % ints_per_block;
|
||||
const uint32_t ib = (idx + tid) / ints_per_block;
|
||||
const uint32_t c = ib / (HSK / 32);
|
||||
const uint32_t block = ib % (HSK / 32);
|
||||
if (idx + gl_WorkGroupSize.x <= quant_iters || c < Bc) {
|
||||
const uint buf_ib = c * qf_stride + block;
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
const uint global_ib = (j * Bc + c) * k_stride + block;
|
||||
k_block_to_shmem(buf_ib, global_ib, iqs, k_offset);
|
||||
} else {
|
||||
k_block_to_shmem_zero(buf_ib, iqs);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // MMQ
|
||||
barrier();
|
||||
}
|
||||
|
||||
#ifndef MMQ
|
||||
// More d iterations means Q register caching becomes relevant
|
||||
// Few iterations means the additional registers needed are worse than the speed-up from caching
|
||||
if (HSK_per_thread / 4 > 4) {
|
||||
@@ -275,6 +351,110 @@ void main() {
|
||||
}
|
||||
}
|
||||
}
|
||||
#else // MMQ
|
||||
const uint hsk4 = HSK_per_thread / 4;
|
||||
const uint d_per_step = (hsk4 % 8 == 0) ? 8 :
|
||||
(hsk4 % 4 == 0) ? 4 :
|
||||
(hsk4 % 2 == 0) ? 2 : 1;
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d_block = 0; d_block < HSK_per_thread / 4; d_block += d_per_step) {
|
||||
int32_t k_quants[d_per_step];
|
||||
ACC_TYPEV2 k_dm;
|
||||
|
||||
if (SHMEM_STAGING != 0) {
|
||||
const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8;
|
||||
const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx;
|
||||
#if QUANT_AUXF == 1
|
||||
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0);
|
||||
#else
|
||||
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm);
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
if (d_per_step == 8) {
|
||||
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
|
||||
uint vui = kblocksh[buf_ib].qs[d];
|
||||
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
|
||||
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
|
||||
uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
|
||||
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
|
||||
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
|
||||
#endif
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
|
||||
k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
|
||||
const uint ib = coord / BLOCK_SIZE;
|
||||
const uint iqs = (coord % BLOCK_SIZE);
|
||||
|
||||
#if QUANT_AUXF == 1
|
||||
k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0);
|
||||
#else
|
||||
k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset));
|
||||
#endif
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
if (d_per_step == 8) {
|
||||
#if defined(DATA_A_Q5_0)
|
||||
uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0],
|
||||
k_packed.k_data_packed16[k_offset + ib].qh[1]));
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
uint qh = k_packed.k_data_packed16[k_offset + ib].qh;
|
||||
#endif
|
||||
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d];
|
||||
#else
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0],
|
||||
k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1]));
|
||||
#endif
|
||||
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
|
||||
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
uint qh_lo = (qh >> (d * 4)) & 0xF;
|
||||
uint qh_hi = (qh >> (d * 4 + 16)) & 0xF;
|
||||
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
|
||||
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
|
||||
#endif
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
|
||||
k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
const uint qib = tile_row(r) * qf_stride + (d_tid * (HSK_per_thread / 4) + d_block) / 8;
|
||||
const uint qiqs = (d_tid * (HSK_per_thread / 4) + d_block) % 8;
|
||||
|
||||
int32_t acc = 0;
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
|
||||
acc += dotPacked4x8EXT(Qf[qib].qs[qiqs + d], k_quants[d]);
|
||||
}
|
||||
|
||||
Sf[r][c] += ACC_TYPE(acc) * ACC_TYPE(Qf[qib].ds.x) * k_dm.x;
|
||||
if ((d_tid * (HSK_per_thread / 4) + d_block) % 8 == 0) {
|
||||
Sf[r][c] += k_dot_correction(qib, k_dm);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // MMQ
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
// Compute sum across the D_split
|
||||
|
||||
@@ -89,6 +89,11 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16
|
||||
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
|
||||
#endif
|
||||
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32;
|
||||
layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32;
|
||||
#endif
|
||||
|
||||
#ifndef BLOCK_SIZE
|
||||
#define BLOCK_SIZE 1
|
||||
#endif
|
||||
|
||||
149
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl
Normal file
149
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl
Normal file
@@ -0,0 +1,149 @@
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
#ifdef DATA_A_Q4_0
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
#else
|
||||
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
|
||||
#endif
|
||||
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
|
||||
return int32_t(vui & 0x0F0F0F0F);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
#ifdef DATA_A_Q5_0
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qh[1]));
|
||||
#else
|
||||
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
|
||||
uint qh = k_packed.k_data_packed16[a_offset + ib].qh;
|
||||
#endif
|
||||
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
|
||||
uint qh_bits = (qh >> iqs) & 0xF;
|
||||
return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1]));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
|
||||
u8vec4 idx = unpack8(vui & 0x0F0F0F0F);
|
||||
return pack32(i8vec4(kvalues_iq4nl_const[idx.x],
|
||||
kvalues_iq4nl_const[idx.y],
|
||||
kvalues_iq4nl_const[idx.z],
|
||||
kvalues_iq4nl_const[idx.w]));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if QUANT_AUXF == 1
|
||||
FLOAT_TYPE get_k_d(uint ib, uint a_offset) {
|
||||
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d);
|
||||
}
|
||||
#else
|
||||
FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) {
|
||||
return FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + ib].dm);
|
||||
}
|
||||
#endif
|
||||
|
||||
void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) {
|
||||
#if defined(DATA_A_Q4_0)
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
#elif defined(DATA_A_Q4_1)
|
||||
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
if (iqs == 0) {
|
||||
kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qh[1]));
|
||||
}
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
|
||||
if (iqs == 0) {
|
||||
kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh;
|
||||
}
|
||||
#elif defined(DATA_A_Q8_0)
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
|
||||
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
|
||||
kblocksh[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y],
|
||||
kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w]));
|
||||
kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y],
|
||||
kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w]));
|
||||
#endif
|
||||
|
||||
if (iqs == 0) {
|
||||
#if QUANT_AUXF == 1
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d);
|
||||
#else
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) {
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
|
||||
uint sub = pos % 4;
|
||||
uint shift = ((pos % 8) >= 4) ? 4 : 0;
|
||||
return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
|
||||
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
uint sub = pos % 4;
|
||||
uint shift = ((pos % 8) >= 4) ? 4 : 0;
|
||||
int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
|
||||
uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF;
|
||||
return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
|
||||
#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
|
||||
return kblocksh[buf_ib].qs[pos];
|
||||
#endif
|
||||
}
|
||||
|
||||
ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) {
|
||||
#if defined(DATA_A_Q4_0)
|
||||
return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
|
||||
#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
|
||||
return ACC_TYPE(Qf[qib].ds.y) * k_dm.y;
|
||||
#else
|
||||
return ACC_TYPE(0.0);
|
||||
#endif
|
||||
}
|
||||
|
||||
void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) {
|
||||
kblocksh[buf_ib].qs[iqs] = 0;
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
kblocksh[buf_ib].qs[iqs + 4] = 0;
|
||||
#endif
|
||||
if (iqs == 0) {
|
||||
#if QUANT_AUXF == 1
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f);
|
||||
#else
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -32,6 +32,12 @@ struct block_a_cache {
|
||||
int32_t qs[32/4];
|
||||
FLOAT_TYPE dm;
|
||||
};
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
int32_t qs[8];
|
||||
FLOAT_TYPE dm;
|
||||
};
|
||||
#elif defined(DATA_A_MXFP4)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
|
||||
@@ -1692,6 +1692,7 @@ struct block_iq4_nl_packed16
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
#define QUANT_K QUANT_K_IQ4_NL
|
||||
#define QUANT_R QUANT_R_IQ4_NL
|
||||
#define QUANT_AUXF 1
|
||||
#define A_TYPE block_iq4_nl
|
||||
#define A_TYPE_PACKED16 block_iq4_nl_packed16
|
||||
#endif
|
||||
|
||||
@@ -406,8 +406,8 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
|
||||
}
|
||||
|
||||
static std::vector<std::future<void>> compiles;
|
||||
void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
|
||||
name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
|
||||
void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false, const std::string& suffix = "") {
|
||||
name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")) + suffix;
|
||||
std::string out_path = join_paths(output_dir, name + ".spv");
|
||||
|
||||
if (input_filepath == "") {
|
||||
@@ -625,15 +625,16 @@ void process_shaders() {
|
||||
for (const bool& fp16 : {false, true}) {
|
||||
std::map<std::string, std::string> base_dict;
|
||||
if (fp16) {
|
||||
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
|
||||
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV2", "f16vec2"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
|
||||
} else {
|
||||
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}};
|
||||
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"FLOAT_TYPEV4", "vec4"}};
|
||||
}
|
||||
|
||||
// flash attention
|
||||
for (const bool& f16acc : {false, true}) {
|
||||
std::map<std::string, std::string> fa_base_dict = base_dict;
|
||||
fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float";
|
||||
fa_base_dict["ACC_TYPEV2"] = fp16 && f16acc ? "f16vec2" : "vec2";
|
||||
fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4";
|
||||
if (fp16 && f16acc) {
|
||||
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
|
||||
@@ -672,6 +673,12 @@ void process_shaders() {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (tname != "f32") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -534,11 +534,7 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
|
||||
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
|
||||
wgpu::CommandBuffer commands = encoder.Finish();
|
||||
ctx->queue.Submit(1, &commands);
|
||||
if (!ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0,
|
||||
ctx->debug_host_buf.GetSize())) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Debug buffer map failed\n");
|
||||
return;
|
||||
}
|
||||
ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
|
||||
const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
|
||||
std::cout << "debug[0]: " << debug_data[0] << "\n";
|
||||
ctx->debug_host_buf.Unmap();
|
||||
|
||||
@@ -798,6 +798,8 @@ class MODEL_TENSOR(IntEnum):
|
||||
A_ENC_INP_PROJ = auto() # gemma4
|
||||
A_ENC_CONV1D = auto()
|
||||
A_ENC_CONV1D_NORM = auto() # gemma3n
|
||||
A_ENC_CONV2D = auto()
|
||||
A_ENC_CONV_OUT = auto()
|
||||
A_PRE_NORM = auto()
|
||||
A_POST_NORM = auto()
|
||||
A_ENC_LAYER_PRE_NORM = auto() # gemma3n
|
||||
@@ -1280,6 +1282,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS: "a.embd_to_logits",
|
||||
MODEL_TENSOR.A_ENC_INP_PROJ: "a.input_projection",
|
||||
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
|
||||
MODEL_TENSOR.A_ENC_CONV2D: "a.conv2d.{bid}",
|
||||
MODEL_TENSOR.A_ENC_CONV_OUT: "a.conv_out",
|
||||
MODEL_TENSOR.A_ENC_CONV1D_NORM: "a.conv1d.{bid}.norm",
|
||||
MODEL_TENSOR.A_PRE_NORM: "a.pre_ln",
|
||||
MODEL_TENSOR.A_POST_NORM: "a.post_ln",
|
||||
@@ -1426,6 +1430,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS,
|
||||
MODEL_TENSOR.A_ENC_INP_PROJ,
|
||||
MODEL_TENSOR.A_ENC_CONV1D,
|
||||
MODEL_TENSOR.A_ENC_CONV2D,
|
||||
MODEL_TENSOR.A_ENC_CONV_OUT,
|
||||
MODEL_TENSOR.A_ENC_CONV1D_NORM,
|
||||
MODEL_TENSOR.A_PRE_NORM,
|
||||
MODEL_TENSOR.A_POST_NORM,
|
||||
@@ -4112,6 +4118,7 @@ class VisionProjectorType:
|
||||
ULTRAVOX = "ultravox"
|
||||
INTERNVL = "internvl"
|
||||
QWEN2A = "qwen2a" # audio
|
||||
QWEN3A = "qwen3a" # audio
|
||||
GLMA = "glma" # audio
|
||||
QWEN25O = "qwen2.5o" # omni
|
||||
VOXTRAL = "voxtral"
|
||||
|
||||
@@ -1892,6 +1892,14 @@ class TensorNameMap:
|
||||
"conformer.subsample_conv_projection.input_proj_linear", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_CONV2D: (
|
||||
"audio_tower.conv2d{bid}", # qwen3omni
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_CONV_OUT: (
|
||||
"audio_tower.conv_out", # qwen3omni
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_PRE_NORM: (),
|
||||
|
||||
MODEL_TENSOR.A_POST_NORM: (
|
||||
@@ -2042,7 +2050,8 @@ class TensorNameMap:
|
||||
|
||||
MODEL_TENSOR.A_MMPROJ: (
|
||||
"audio.multi_modal_projector.linear_{bid}", # ultravox, meralion
|
||||
"audio_adapter.model.{bid}" # lfm2
|
||||
"audio_adapter.model.{bid}", # lfm2
|
||||
"audio_tower.proj{bid}", # qwen3omni
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_MMPROJ_FC: (
|
||||
|
||||
141
models/templates/deepseek-ai-DeepSeek-V3.2.jinja
Normal file
141
models/templates/deepseek-ai-DeepSeek-V3.2.jinja
Normal file
@@ -0,0 +1,141 @@
|
||||
{%- if not add_generation_prompt is defined -%}
|
||||
{%- set add_generation_prompt = false -%}
|
||||
{%- endif -%}
|
||||
{%- if not thinking is defined -%}
|
||||
{%- if enable_thinking is defined -%}
|
||||
{%- set thinking = enable_thinking -%}
|
||||
{%- else -%}
|
||||
{%- set thinking = false -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- set dsml_token = '|DSML|' -%}
|
||||
{%- set thinking_start_token = '<think>' -%}
|
||||
{%- set thinking_end_token = '</think>' -%}
|
||||
{%- set tools_header = '## Tools\n\nYou have access to a set of tools you can use to answer the user\'s question.\nYou can invoke functions by writing a "<' + dsml_token + 'function_calls>" block like the following as part of your reply to the user:\n<' + dsml_token + 'function_calls>\n<' + dsml_token + 'invoke name="$FUNCTION_NAME">\n<' + dsml_token + 'parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</' + dsml_token + 'parameter>\n...\n</' + dsml_token + 'invoke>\n<' + dsml_token + 'invoke name="$FUNCTION_NAME2">\n...\n</' + dsml_token + 'invoke>\n</' + dsml_token + 'function_calls>\n\nString and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects).\n\nIf the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example:\n\n<' + dsml_token + 'function_calls>\n...\n</' + dsml_token + 'function_calls>\n\n<function_results>\n...\n</function_results>\n\n' + thinking_start_token + '...thinking about results' + thinking_end_token + '\n\nHere are the functions available in JSONSchema format:\n<functions>\n' -%}
|
||||
{%- set tools_footer = '</functions>\n' -%}
|
||||
{%- set ns = namespace(system_prompt='', is_first_sp=true) -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'system' -%}
|
||||
{%- if ns.is_first_sp -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + (message['content'] or '') -%}
|
||||
{%- set ns.is_first_sp = false -%}
|
||||
{%- else -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + '\n\n' + (message['content'] or '') -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if tools is defined and tools -%}
|
||||
{%- set ts = namespace(schemas='') -%}
|
||||
{%- for tool in tools -%}
|
||||
{%- if tool['type'] == 'function' -%}
|
||||
{%- set ts.schemas = ts.schemas + (tool['function'] | tojson) + '\n' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if ns.system_prompt -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + '\n\n' + tools_header + ts.schemas + tools_footer -%}
|
||||
{%- else -%}
|
||||
{%- set ns.system_prompt = tools_header + ts.schemas + tools_footer -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{{- bos_token -}}
|
||||
{{- ns.system_prompt -}}
|
||||
{%- set last_user_idx = namespace(value=-1) -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'user' or message['role'] == 'developer' -%}
|
||||
{%- set last_user_idx.value = loop.index0 -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- set state = namespace(pending_asst_marker=false, pending_tool_marker=false) -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
{{- '<|User|>' + (message['content'] or '') -}}
|
||||
{%- set state.pending_asst_marker = true -%}
|
||||
{%- set state.pending_tool_marker = false -%}
|
||||
{%- elif message['role'] == 'assistant' -%}
|
||||
{%- set is_after_last_user = loop.index0 > last_user_idx.value -%}
|
||||
{%- if state.pending_asst_marker -%}
|
||||
{{- '<|Assistant|>' -}}
|
||||
{%- if is_after_last_user and thinking -%}
|
||||
{{- thinking_start_token -}}
|
||||
{%- if message['reasoning_content'] is defined and message['reasoning_content'] -%}
|
||||
{{- message['reasoning_content'] -}}
|
||||
{%- endif -%}
|
||||
{{- thinking_end_token -}}
|
||||
{%- else -%}
|
||||
{{- thinking_end_token -}}
|
||||
{%- endif -%}
|
||||
{%- elif state.pending_tool_marker -%}
|
||||
{%- if is_after_last_user and thinking -%}
|
||||
{{- '\n\n' + thinking_start_token -}}
|
||||
{%- if message['reasoning_content'] is defined and message['reasoning_content'] -%}
|
||||
{{- message['reasoning_content'] -}}
|
||||
{%- endif -%}
|
||||
{{- thinking_end_token -}}
|
||||
{%- else -%}
|
||||
{{- '\n\n' + thinking_end_token -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- set state.pending_asst_marker = false -%}
|
||||
{%- set state.pending_tool_marker = false -%}
|
||||
{%- if message['content'] is defined and message['content'] -%}
|
||||
{{- message['content'] -}}
|
||||
{%- endif -%}
|
||||
{%- if message['tool_calls'] -%}
|
||||
{{- '\n\n<' + dsml_token + 'function_calls>\n' -}}
|
||||
{%- for tool in message['tool_calls'] -%}
|
||||
{%- set func = tool['function'] -%}
|
||||
{{- '<' + dsml_token + 'invoke name="' + func['name'] + '">\n' -}}
|
||||
{%- set args = func['arguments'] -%}
|
||||
{%- if args is string -%}
|
||||
{%- set args = args | from_json -%}
|
||||
{%- endif -%}
|
||||
{%- for key, val in args.items() -%}
|
||||
{%- if val is string -%}
|
||||
{{- '<' + dsml_token + 'parameter name="' + key + '" string="true">' + val + '</' + dsml_token + 'parameter>\n' -}}
|
||||
{%- else -%}
|
||||
{{- '<' + dsml_token + 'parameter name="' + key + '" string="false">' + (val | tojson) + '</' + dsml_token + 'parameter>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- '</' + dsml_token + 'invoke>\n' -}}
|
||||
{%- endfor -%}
|
||||
{{- '</' + dsml_token + 'function_calls>' -}}
|
||||
{%- endif -%}
|
||||
{{- '<|end▁of▁sentence|>' -}}
|
||||
{%- elif message['role'] == 'tool' -%}
|
||||
{%- set outer_index = loop.index0 -%}
|
||||
{%- set assistant_idx = namespace(value=-1) -%}
|
||||
{%- for prev_msg in messages -%}
|
||||
{%- if prev_msg['role'] == 'assistant' and prev_msg['tool_calls'] and loop.index0 < outer_index -%}
|
||||
{%- set assistant_idx.value = loop.index0 -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- set call_order = outer_index - assistant_idx.value -%}
|
||||
{%- set assistant_msg = messages[assistant_idx.value] -%}
|
||||
{%- set tool_call_count = assistant_msg['tool_calls'] | length -%}
|
||||
{%- if call_order == 1 -%}
|
||||
{{- '\n\n<function_results>' -}}
|
||||
{%- endif -%}
|
||||
{{- '\n<result>' + (message['content'] or '') + '</result>' -}}
|
||||
{%- if call_order == tool_call_count -%}
|
||||
{{- '\n</function_results>' -}}
|
||||
{%- set state.pending_asst_marker = false -%}
|
||||
{%- set state.pending_tool_marker = true -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{%- if state.pending_asst_marker -%}
|
||||
{{- '<|Assistant|>' -}}
|
||||
{%- if thinking -%}
|
||||
{{- thinking_start_token -}}
|
||||
{%- else -%}
|
||||
{{- thinking_start_token + thinking_end_token -}}
|
||||
{%- endif -%}
|
||||
{%- elif state.pending_tool_marker -%}
|
||||
{%- if thinking -%}
|
||||
{{- '\n\n' + thinking_start_token -}}
|
||||
{%- else -%}
|
||||
{{- '\n\n' + thinking_start_token + thinking_end_token -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
@@ -258,6 +258,66 @@ void test_gbnf_generation(testing &t) {
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("silent parser emits nothing in gbnf", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello") + p.gbnf(p.literal("world"), "");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "hello"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("silent choice inside sequence emits nothing", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("a") + p.gbnf(p.literal("b") | p.literal("c"), "") + p.literal("d");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "a" "d"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("silent wrapped in tag emits nothing", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("a") + p.tag("t", p.gbnf(p.literal("b"), ""));
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "a"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("gbnf parser emits custom grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("a") + p.gbnf(p.literal("b"), "[a-z]+");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "a" [a-z]+
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("nested transparent wrappers get parenthesized", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("x") + p.tag("outer", p.atomic(p.literal("a") | p.literal("b")));
|
||||
|
||||
@@ -8397,6 +8397,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 512, 1, 1}, order)); // test CUDA dispatching to radix sort for nrows > = 1 in graph mode
|
||||
}
|
||||
|
||||
for (int n = 1; n < 5; ++n) {
|
||||
@@ -8579,7 +8580,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
for (int nb : { 1, 3, 32, 75, }) {
|
||||
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
|
||||
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
|
||||
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0, GGML_TYPE_IQ4_NL}) {
|
||||
if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72) continue;
|
||||
test_cases.emplace_back(new test_flash_attn_ext(
|
||||
hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
|
||||
|
||||
@@ -2118,6 +2118,31 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.tools({ amount_tool })
|
||||
.expect(message_with_tool_calls("amount", R"({"orig": 1.5e10})"))
|
||||
.run();
|
||||
|
||||
// Edge cases
|
||||
tst.test(
|
||||
"<|channel>thought\n<channel|>Hello, world!\nWhat's up?<channel|>")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
"<|channel>thought\n<channel|>Hello, world!\nWhat's up?<|channel>thought\n<channel|>")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
"<|channel>thought\n<channel|>Hello, world!\nWhat's up?<|channel>thought\n<channel|><channel|>")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
"<|channel><|channel>thought\n<channel|>Hello, world!\nWhat's up?")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
}
|
||||
|
||||
{
|
||||
@@ -2576,6 +2601,215 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
expect(simple_assist_msg("CONTENT", "")).run();
|
||||
}
|
||||
|
||||
// DeepSeek V3.2 tests - format uses DSML markup:
|
||||
// <|DSML|function_calls>
|
||||
// <|DSML|invoke name="foo">
|
||||
// <|DSML|parameter name="bar" string="true|false">value</|DSML|parameter>
|
||||
// </|DSML|invoke>
|
||||
// </|DSML|function_calls>
|
||||
// Reasoning uses <think>...</think>. The generation prompt ends in <think> (thinking mode)
|
||||
// or <think></think> (non-thinking mode).
|
||||
{
|
||||
auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-V3.2.jinja", detailed_debug);
|
||||
|
||||
// Pure content (non-thinking mode)
|
||||
tst.test("Hello, world!\nWhat's up?")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
|
||||
// Thinking + content
|
||||
tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.expect(message_assist_thoughts)
|
||||
.run();
|
||||
|
||||
// Thinking + tool call (single, string param)
|
||||
tst.test(
|
||||
"Let me check the time</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"get_time\">\n"
|
||||
"<|DSML|parameter name=\"city\" string=\"true\">Tokyo</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ get_time_tool })
|
||||
.expect(message_with_tool_calls_and_reasoning("get_time", R"({"city": "Tokyo"})", "Let me check the time"))
|
||||
.run();
|
||||
|
||||
// Tool call without reasoning (non-thinking mode), integer param (string="false")
|
||||
tst.test(
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"special_function\">\n"
|
||||
"<|DSML|parameter name=\"arg1\" string=\"false\">1</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.run();
|
||||
|
||||
// Multiple parallel tool calls with reasoning
|
||||
tst.test(
|
||||
"Calling both</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"get_time\">\n"
|
||||
"<|DSML|parameter name=\"city\" string=\"true\">Paris</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"<|DSML|invoke name=\"get_weather\">\n"
|
||||
"<|DSML|parameter name=\"city\" string=\"true\">Paris</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.parallel_tool_calls(true)
|
||||
.tools({ get_time_tool, get_weather_tool })
|
||||
.expect(message_with_reasoning_content_and_multiple_tool_calls(
|
||||
"Calling both", "",
|
||||
{ { "get_time", R"({"city": "Paris"})" }, { "get_weather", R"({"city": "Paris"})" } }))
|
||||
.run();
|
||||
|
||||
// Tool call with content before tool calls
|
||||
tst.test(
|
||||
"Thinking about it</think>"
|
||||
"Let me call the function.\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"special_function\">\n"
|
||||
"<|DSML|parameter name=\"arg1\" string=\"false\">1</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ special_function_tool })
|
||||
.expect_reasoning("Thinking about it")
|
||||
.expect_content("Let me call the function.")
|
||||
.expect_tool_calls({
|
||||
{ "special_function", R"({"arg1": 1})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Tool call with negative number
|
||||
tst.test(
|
||||
"Test negative</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"magic_int\">\n"
|
||||
"<|DSML|parameter name=\"ref\" string=\"false\">-14</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ magic_int_tool })
|
||||
.expect_reasoning("Test negative")
|
||||
.expect_tool_calls({
|
||||
{ "magic_int", R"({"ref": -14})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Tool call with decimal number
|
||||
tst.test(
|
||||
"Test decimal</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"amount\">\n"
|
||||
"<|DSML|parameter name=\"orig\" string=\"false\">3.14</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ amount_tool })
|
||||
.expect_reasoning("Test decimal")
|
||||
.expect_tool_calls({
|
||||
{ "amount", R"({"orig": 3.14})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Tool call with boolean
|
||||
tst.test(
|
||||
"Test boolean</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"toggle\">\n"
|
||||
"<|DSML|parameter name=\"enabled\" string=\"false\">true</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ toggle_tool })
|
||||
.expect_reasoning("Test boolean")
|
||||
.expect_tool_calls({
|
||||
{ "toggle", R"({"enabled": true})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Tool call with array parameter (JSON-formatted)
|
||||
tst.test(
|
||||
"Test array</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"todo_list\">\n"
|
||||
"<|DSML|parameter name=\"todos\" string=\"false\">[\"buy milk\",\"walk dog\"]</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ todo_list })
|
||||
.expect_reasoning("Test array")
|
||||
.expect_tool_calls({
|
||||
{ "todo_list", R"({"todos": ["buy milk", "walk dog"]})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Tool call with object parameter (JSON-formatted)
|
||||
tst.test(
|
||||
"Test object</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"set_config\">\n"
|
||||
"<|DSML|parameter name=\"config\" string=\"false\">{\"theme\":\"dark\",\"level\":2}</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ config_tool })
|
||||
.expect_reasoning("Test object")
|
||||
.expect_tool_calls({
|
||||
{ "set_config", R"({"config": {"theme": "dark", "level": 2}})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Edge case: empty reasoning
|
||||
tst.test(
|
||||
"</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"get_time\">\n"
|
||||
"<|DSML|parameter name=\"city\" string=\"true\">XYZCITY</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ get_time_tool })
|
||||
.expect(message_with_tool_calls("get_time", R"({"city": "XYZCITY"})"))
|
||||
.run();
|
||||
|
||||
// Edge case: tool call with multiple params (mixed types, string first)
|
||||
tst.test(
|
||||
"Multi-arg call</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"magic_int\">\n"
|
||||
"<|DSML|parameter name=\"ref\" string=\"false\">42</|DSML|parameter>\n"
|
||||
"<|DSML|parameter name=\"name\" string=\"true\">foo bar</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ magic_int_tool })
|
||||
.expect_reasoning("Multi-arg call")
|
||||
.expect_tool_calls({
|
||||
{ "magic_int", R"({"ref": 42, "name": "foo bar"})", {} },
|
||||
})
|
||||
.run();
|
||||
}
|
||||
|
||||
// GLM-4.6 tests - format: <tool_call>function_name\n<arg_key>...</arg_key>\n<arg_value>...</arg_value>\n</tool_call>
|
||||
{
|
||||
auto tst = peg_tester("models/templates/GLM-4.6.jinja", detailed_debug);
|
||||
|
||||
@@ -88,6 +88,11 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) {
|
||||
uint32_t n_layer = 2;
|
||||
if (arch == LLM_ARCH_LLAMA4) {
|
||||
n_layer = 4; // hparams.n_no_rope_layer_step is hard-coded to 4
|
||||
} else if (arch == LLM_ARCH_GEMMA4) {
|
||||
n_embd = 128;
|
||||
n_head = 2;
|
||||
n_ff = 192;
|
||||
n_layer = 5; // need at least 5 for swa_pattern (every 5th is full_attention)
|
||||
} else if (arch == LLM_ARCH_GEMMA3N) {
|
||||
n_embd = 64;
|
||||
n_head = 1;
|
||||
@@ -169,7 +174,15 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) {
|
||||
ms.add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, uint32_t(8));
|
||||
ms.add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW, n_ctx/8);
|
||||
|
||||
if (arch == LLM_ARCH_MIMO2 || arch == LLM_ARCH_STEP35) {
|
||||
if (arch == LLM_ARCH_GEMMA4) {
|
||||
ms.add_kv(LLM_KV_EMBEDDING_LENGTH_PER_LAYER, n_embd/2);
|
||||
ms.add_kv(LLM_KV_ATTENTION_SHARED_KV_LAYERS, uint32_t(0));
|
||||
ms.add_kv(LLM_KV_ATTENTION_KEY_LENGTH_SWA, n_embd_head);
|
||||
ms.add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, n_embd_head);
|
||||
ms.add_kv(LLM_KV_ROPE_FREQ_BASE_SWA, 10000.0f);
|
||||
// SWA pattern: every 5th layer is full attention (matches E2B layer_types)
|
||||
ms.add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, uint32_t(5));
|
||||
} else if (arch == LLM_ARCH_MIMO2 || arch == LLM_ARCH_STEP35) {
|
||||
std::vector<uint32_t> pattern;
|
||||
pattern.reserve(n_layer);
|
||||
for (uint32_t il = 0; il < n_layer; il++) {
|
||||
@@ -429,6 +442,9 @@ static int save_models(const llm_arch target_arch, const size_t seed, const ggml
|
||||
if (target_arch != LLM_ARCH_UNKNOWN && arch != target_arch) {
|
||||
continue;
|
||||
}
|
||||
if (arch == LLM_ARCH_GEMMA4) {
|
||||
continue; // FIXME: ISWA KV cache initialization needs more fixture params
|
||||
}
|
||||
for (bool moe : {false, true}) {
|
||||
if (moe && !moe_implemented(arch)) {
|
||||
continue;
|
||||
@@ -510,6 +526,9 @@ static int test_backends(const llm_arch target_arch, const size_t seed, const gg
|
||||
if (target_arch != LLM_ARCH_UNKNOWN && arch != target_arch) {
|
||||
continue;
|
||||
}
|
||||
if (arch == LLM_ARCH_GEMMA4) {
|
||||
continue; // FIXME: ISWA KV cache initialization needs more fixture params
|
||||
}
|
||||
|
||||
const bool encode = arch == LLM_ARCH_T5 || arch == LLM_ARCH_DREAM || arch == LLM_ARCH_LLADA || arch == LLM_ARCH_LLADA_MOE || arch == LLM_ARCH_RND1;
|
||||
for (bool moe : {false, true}) {
|
||||
|
||||
@@ -18,6 +18,7 @@ add_library(mtmd
|
||||
models/cogvlm.cpp
|
||||
models/conformer.cpp
|
||||
models/dotsocr.cpp
|
||||
models/gemma4a.cpp
|
||||
models/gemma4v.cpp
|
||||
models/glm4v.cpp
|
||||
models/hunyuanocr.cpp
|
||||
@@ -32,6 +33,7 @@ add_library(mtmd
|
||||
models/pixtral.cpp
|
||||
models/qwen2vl.cpp
|
||||
models/qwen3vl.cpp
|
||||
models/qwen3a.cpp
|
||||
models/step3vl.cpp
|
||||
models/siglip.cpp
|
||||
models/whisper-enc.cpp
|
||||
|
||||
@@ -135,6 +135,8 @@
|
||||
|
||||
// ultravox
|
||||
#define TN_CONV1D "a.conv1d.%d.%s"
|
||||
#define TN_CONV2D "a.conv2d.%d.%s"
|
||||
#define TN_CONV_OUT "a.conv_out.%s"
|
||||
#define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s"
|
||||
#define TN_MM_AUDIO_FC "mm.a.fc.%s" // fully connected layer
|
||||
#define TN_MM_NORM_PRE "mm.a.norm_pre.%s"
|
||||
@@ -181,6 +183,21 @@
|
||||
#define TN_CONV_PW1 "%s.blk.%d.conv_pw1.%s"
|
||||
#define TN_CONV_PW2 "%s.blk.%d.conv_pw2.%s"
|
||||
|
||||
// gemma4 audio conformer
|
||||
#define TN_A_MM_INP_PROJ "mm.a.input_projection.%s"
|
||||
#define TN_A_MM_SOFT_EMB_N "mm.a.soft_emb_norm.%s"
|
||||
#define TN_A_INP_PROJ "a.input_projection.%s"
|
||||
#define TN_A_CONV1D "a.conv1d.%d.%s"
|
||||
#define TN_A_CONV1D_NORM "a.conv1d.%d.norm.%s"
|
||||
#define TN_A_OUT_PROJ "a.pre_encode.out.%s"
|
||||
#define TN_A_ATTN_PRE_NORM "%s.blk.%d.attn_pre_norm.%s"
|
||||
#define TN_A_ATTN_POST_NORM "%s.blk.%d.attn_post_norm.%s"
|
||||
#define TN_A_ATTN_K_REL "%s.blk.%d.attn_k_rel.%s"
|
||||
#define TN_A_PER_DIM_SCALE "%s.blk.%d.per_dim_scale.%s"
|
||||
#define TN_A_PER_DIM_K_SCALE "%s.blk.%d.per_dim_k_scale.%s"
|
||||
#define TN_A_FFN_POST_NORM "%s.blk.%d.ffn_post_norm.%s"
|
||||
#define TN_A_FFN_POST_NORM_1 "%s.blk.%d.ffn_post_norm_1.%s"
|
||||
|
||||
// mobilenetv5 (gemma3n) definitions
|
||||
#define TN_MNV5_STEM_CONV "v.conv_stem.conv.weight"
|
||||
#define TN_MNV5_STEM_BIAS "v.conv_stem.conv.bias"
|
||||
@@ -256,6 +273,7 @@ enum projector_type {
|
||||
PROJECTOR_TYPE_INTERNVL,
|
||||
PROJECTOR_TYPE_LLAMA4,
|
||||
PROJECTOR_TYPE_QWEN2A,
|
||||
PROJECTOR_TYPE_QWEN3A,
|
||||
PROJECTOR_TYPE_GLMA,
|
||||
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
|
||||
PROJECTOR_TYPE_VOXTRAL,
|
||||
@@ -300,6 +318,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
|
||||
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
|
||||
{ PROJECTOR_TYPE_QWEN2A, "qwen2a"},
|
||||
{ PROJECTOR_TYPE_QWEN3A, "qwen3a"},
|
||||
{ PROJECTOR_TYPE_GLMA, "glma"},
|
||||
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
|
||||
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
|
||||
|
||||
@@ -217,6 +217,13 @@ struct clip_layer {
|
||||
ggml_tensor * conv_pw2_w = nullptr;
|
||||
ggml_tensor * conv_pw2_b = nullptr;
|
||||
|
||||
// gemma4 audio conformer per-layer
|
||||
ggml_tensor * attn_pre_norm_w = nullptr;
|
||||
ggml_tensor * attn_k_rel_w = nullptr;
|
||||
ggml_tensor * per_dim_scale_w = nullptr;
|
||||
ggml_tensor * per_dim_k_scale_w = nullptr;
|
||||
ggml_tensor * ff_post_norm_1_w = nullptr;
|
||||
|
||||
bool has_deepstack() const {
|
||||
return deepstack_fc1_w != nullptr;
|
||||
}
|
||||
@@ -406,10 +413,20 @@ struct clip_model {
|
||||
ggml_tensor * conv1d_1_b = nullptr;
|
||||
ggml_tensor * conv1d_2_w = nullptr;
|
||||
ggml_tensor * conv1d_2_b = nullptr;
|
||||
ggml_tensor * conv_out_w = nullptr;
|
||||
ggml_tensor * conv_out_b = nullptr;
|
||||
ggml_tensor * mm_norm_pre_w = nullptr;
|
||||
ggml_tensor * mm_norm_pre_b = nullptr;
|
||||
ggml_tensor * mm_norm_mid_w = nullptr;
|
||||
|
||||
// qwen3a
|
||||
ggml_tensor * conv2d_1_w = nullptr;
|
||||
ggml_tensor * conv2d_1_b = nullptr;
|
||||
ggml_tensor * conv2d_2_w = nullptr;
|
||||
ggml_tensor * conv2d_2_b = nullptr;
|
||||
ggml_tensor * conv2d_3_w = nullptr;
|
||||
ggml_tensor * conv2d_3_b = nullptr;
|
||||
|
||||
// cogvlm
|
||||
ggml_tensor * mm_post_fc_norm_w = nullptr;
|
||||
ggml_tensor * mm_post_fc_norm_b = nullptr;
|
||||
@@ -459,6 +476,15 @@ struct clip_model {
|
||||
};
|
||||
std::map<std::string, clamp_info> clamp_info_map;
|
||||
|
||||
// gemma4 audio conformer
|
||||
std::array<ggml_tensor *, 2> sscp_conv_w = {nullptr};
|
||||
std::array<ggml_tensor *, 2> sscp_conv_b = {nullptr};
|
||||
std::array<ggml_tensor *, 2> sscp_norm_w = {nullptr};
|
||||
ggml_tensor * sscp_inp_proj_w = nullptr;
|
||||
ggml_tensor * sscp_inp_proj_b = nullptr;
|
||||
ggml_tensor * audio_out_proj_w = nullptr;
|
||||
ggml_tensor * audio_out_proj_b = nullptr;
|
||||
|
||||
bool audio_has_avgpool() const {
|
||||
return proj_type == PROJECTOR_TYPE_QWEN2A
|
||||
|| proj_type == PROJECTOR_TYPE_VOXTRAL
|
||||
|
||||
@@ -931,10 +931,18 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
{
|
||||
builder = std::make_unique<clip_graph_conformer>(ctx, img);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA4A:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_gemma4a>(ctx, img);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GLM4V:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_glm4v>(ctx, img);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_qwen3a>(ctx, img);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_YOUTUVL:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_youtuvl>(ctx, img);
|
||||
@@ -1398,6 +1406,7 @@ struct clip_model_loader {
|
||||
} break;
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
@@ -1459,6 +1468,16 @@ struct clip_model_loader {
|
||||
hparams.audio_window_len = 400;
|
||||
hparams.audio_hop_len = 160;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA4A:
|
||||
{
|
||||
// Gemma4 feature_extraction_gemma4.py:
|
||||
// frame_length_ms=20 -> 320 samples, n_fft=512, hop=10ms -> 160
|
||||
hparams.audio_chunk_len = 0; // no fixed-length padding
|
||||
hparams.audio_sample_rate = 16000;
|
||||
hparams.audio_n_fft = 512;
|
||||
hparams.audio_window_len = 320; // 20ms frame (NOT 25ms/400)
|
||||
hparams.audio_hop_len = 160;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_JANUS_PRO:
|
||||
{
|
||||
hparams.image_pad_color = {127, 127, 127};
|
||||
@@ -1561,16 +1580,21 @@ struct clip_model_loader {
|
||||
}
|
||||
|
||||
// helper function
|
||||
std::unordered_set<std::string> loaded_tensor_names;
|
||||
auto get_tensor = [&](const std::string & name, bool required = true) {
|
||||
// Each tensor should only be loaded once; duplicates indicate a bug
|
||||
if (loaded_tensor_names.count(name)) {
|
||||
throw std::runtime_error(string_format("%s: tensor already loaded: %s\n", __func__, name.c_str()));
|
||||
}
|
||||
ggml_tensor * cur = ggml_get_tensor(ctx_meta.get(), name.c_str());
|
||||
if (!cur && required) {
|
||||
throw std::runtime_error(string_format("%s: unable to find tensor %s\n", __func__, name.c_str()));
|
||||
}
|
||||
if (cur) {
|
||||
tensors_to_load.push_back(cur);
|
||||
// add tensors to context
|
||||
ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data.get(), cur);
|
||||
ggml_set_name(data_tensor, cur->name);
|
||||
loaded_tensor_names.insert(name);
|
||||
cur = data_tensor;
|
||||
}
|
||||
return cur;
|
||||
@@ -2053,6 +2077,20 @@ struct clip_model_loader {
|
||||
model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
|
||||
model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
{
|
||||
model.conv2d_1_w = get_tensor(string_format(TN_CONV2D, 1, "weight"));
|
||||
model.conv2d_1_b = get_tensor(string_format(TN_CONV2D, 1, "bias"));
|
||||
model.conv2d_2_w = get_tensor(string_format(TN_CONV2D, 2, "weight"));
|
||||
model.conv2d_2_b = get_tensor(string_format(TN_CONV2D, 2, "bias"));
|
||||
model.conv2d_3_w = get_tensor(string_format(TN_CONV2D, 3, "weight"));
|
||||
model.conv2d_3_b = get_tensor(string_format(TN_CONV2D, 3, "bias"));
|
||||
model.conv_out_w = get_tensor(string_format(TN_CONV_OUT, "weight")); // no bias
|
||||
model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
|
||||
model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias"));
|
||||
model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
|
||||
model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
{
|
||||
model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
|
||||
@@ -2186,6 +2224,76 @@ struct clip_model_loader {
|
||||
model.mm_fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
|
||||
model.mm_fc_b = get_tensor(string_format(TN_MM_PROJECTOR, "bias"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA4A:
|
||||
{
|
||||
for (int i = 0; i < 2; i++) {
|
||||
model.sscp_conv_w[i] = get_tensor(string_format(TN_A_CONV1D, i, "weight"));
|
||||
model.sscp_conv_b[i] = get_tensor(string_format(TN_A_CONV1D, i, "bias"), false);
|
||||
model.sscp_norm_w[i] = get_tensor(string_format(TN_A_CONV1D_NORM, i, "weight"), false);
|
||||
}
|
||||
model.sscp_inp_proj_w = get_tensor(string_format(TN_A_INP_PROJ, "weight"));
|
||||
model.sscp_inp_proj_b = get_tensor(string_format(TN_A_INP_PROJ, "bias"), false);
|
||||
model.audio_out_proj_w = get_tensor(string_format(TN_A_OUT_PROJ, "weight"), false);
|
||||
model.audio_out_proj_b = get_tensor(string_format(TN_A_OUT_PROJ, "bias"), false);
|
||||
// audio multimodal embedder (mm.a.* namespace, not mm.*)
|
||||
model.mm_soft_emb_norm_w = get_tensor(string_format(TN_A_MM_SOFT_EMB_N, "weight"), false);
|
||||
model.mm_input_proj_w = get_tensor(string_format(TN_A_MM_INP_PROJ, "weight"), false);
|
||||
|
||||
// Per-layer tensors NOT loaded by the generic loop above
|
||||
for (int il = 0; il < hparams.n_layer; ++il) {
|
||||
auto & layer = model.layers[il];
|
||||
|
||||
// Gemma4 audio conformer-specific tensors
|
||||
layer.ff_norm_w = get_tensor(string_format(TN_FFN_NORM, prefix, il, "weight"));
|
||||
layer.attn_pre_norm_w = get_tensor(string_format(TN_A_ATTN_PRE_NORM, prefix, il, "weight"), false);
|
||||
layer.per_dim_scale_w = get_tensor(string_format(TN_A_PER_DIM_SCALE, prefix, il, "weight"), false);
|
||||
layer.per_dim_k_scale_w = get_tensor(string_format(TN_A_PER_DIM_K_SCALE, prefix, il, "weight"), false);
|
||||
layer.attn_k_rel_w = get_tensor(string_format(TN_A_ATTN_K_REL, prefix, il, "weight"), false);
|
||||
|
||||
// Convolution module
|
||||
// Note: conv_norm / norm_conv are swapped in GGUF due to
|
||||
// upstream tensor_mapping.py, so we load them in reverse order
|
||||
layer.norm_conv_w = get_tensor(string_format(TN_CONV_NORM, prefix, il, "weight"), false);
|
||||
layer.norm_conv_b = get_tensor(string_format(TN_CONV_NORM, prefix, il, "bias"), false);
|
||||
layer.conv_pw1_w = get_tensor(string_format(TN_CONV_PW1, prefix, il, "weight"));
|
||||
layer.conv_pw1_b = get_tensor(string_format(TN_CONV_PW1, prefix, il, "bias"), false);
|
||||
layer.conv_dw_w = get_tensor(string_format(TN_CONV_DW, prefix, il, "weight"));
|
||||
layer.conv_dw_b = get_tensor(string_format(TN_CONV_DW, prefix, il, "bias"), false);
|
||||
layer.conv_norm_w = get_tensor(string_format(TN_NORM_CONV, prefix, il, "weight"), false);
|
||||
layer.conv_norm_b = get_tensor(string_format(TN_NORM_CONV, prefix, il, "bias"), false);
|
||||
layer.conv_pw2_w = get_tensor(string_format(TN_CONV_PW2, prefix, il, "weight"));
|
||||
layer.conv_pw2_b = get_tensor(string_format(TN_CONV_PW2, prefix, il, "bias"), false);
|
||||
|
||||
// FFN2 (second half-step)
|
||||
layer.ff_norm_1_w = get_tensor(string_format(TN_FFN_NORM_1, prefix, il, "weight"));
|
||||
layer.ff_up_1_w = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "weight"));
|
||||
layer.ff_up_1_b = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "bias"), false);
|
||||
layer.ff_down_1_w = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "weight"));
|
||||
layer.ff_down_1_b = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "bias"), false);
|
||||
layer.ff_post_norm_1_w = get_tensor(string_format(TN_A_FFN_POST_NORM_1, prefix, il, "weight"), false);
|
||||
}
|
||||
|
||||
// Load clamp info for ClippableLinear AFTER all tensors are loaded
|
||||
for (auto * tensor : tensors_to_load) {
|
||||
std::string name = tensor->name;
|
||||
if (string_ends_with(name, ".weight")) {
|
||||
std::string name_inp_max = name;
|
||||
std::string name_inp_min = name;
|
||||
std::string name_out_max = name;
|
||||
std::string name_out_min = name;
|
||||
string_replace_all(name_inp_max, ".weight", ".input_max");
|
||||
string_replace_all(name_inp_min, ".weight", ".input_min");
|
||||
string_replace_all(name_out_max, ".weight", ".output_max");
|
||||
string_replace_all(name_out_min, ".weight", ".output_min");
|
||||
model.clamp_info_map[name] = {
|
||||
get_scalar(name_inp_max, FLT_MAX),
|
||||
get_scalar(name_inp_min, -FLT_MAX),
|
||||
get_scalar(name_out_max, FLT_MAX),
|
||||
get_scalar(name_out_min, -FLT_MAX)
|
||||
};
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LFM2A:
|
||||
{
|
||||
for (int i : {0, 2, 3, 5, 6}) {
|
||||
@@ -2246,7 +2354,10 @@ struct clip_model_loader {
|
||||
ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
||||
for (auto & t : tensors_to_load) {
|
||||
ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
|
||||
const size_t offset = tensor_offset[t->name];
|
||||
GGML_ASSERT(cur && "tensor not found in ctx_data");
|
||||
auto it_off = tensor_offset.find(t->name);
|
||||
GGML_ASSERT(it_off != tensor_offset.end() && "no offset for tensor");
|
||||
const size_t offset = it_off->second;
|
||||
fin.seekg(offset, std::ios::beg);
|
||||
if (!fin) {
|
||||
throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name));
|
||||
@@ -2266,6 +2377,7 @@ struct clip_model_loader {
|
||||
|
||||
LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
struct support_info_op {
|
||||
@@ -2538,8 +2650,7 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
|
||||
|
||||
// TODO: we don't support audio for Gemma 3N, but GGUF contains audio tensors
|
||||
// we can remove this check when we implement audio support for Gemma 3N
|
||||
skip_audio = ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA3NV
|
||||
|| ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA4V;
|
||||
skip_audio = ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA3NV;
|
||||
}
|
||||
|
||||
if (loader.has_audio && !skip_audio) {
|
||||
@@ -2856,6 +2967,15 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
||||
n_patches /= 2;
|
||||
}
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
{
|
||||
// 3x stride-2 conv2d: each step is floor((n-1)/2)+1
|
||||
int n = img->nx;
|
||||
n = (n - 1) / 2 + 1;
|
||||
n = (n - 1) / 2 + 1;
|
||||
n = (n - 1) / 2 + 1;
|
||||
n_patches = n;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
{
|
||||
n_patches = img->nx;
|
||||
@@ -2893,6 +3013,16 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
||||
{
|
||||
n_patches = ((((img->nx + 1) / 2) + 1) / 2 + 1) / 2;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA4A:
|
||||
{
|
||||
// Two Conv2D stride-2: O = floor((I + 2p - k) / s) + 1, p=1, k=3, s=2
|
||||
// O = floor((I - 1) / 2) + 1
|
||||
int n = img->nx;
|
||||
for (int i = 0; i < 2; i++) {
|
||||
n = (n - 1) / 2 + 1;
|
||||
}
|
||||
n_patches = n;
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("unsupported projector type");
|
||||
}
|
||||
@@ -3322,6 +3452,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
case PROJECTOR_TYPE_INTERNVL:
|
||||
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_LFM2:
|
||||
@@ -3352,6 +3483,56 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
}
|
||||
set_input_i32("pos_w", pos_data);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA4A:
|
||||
{
|
||||
GGML_ASSERT(imgs.entries.size() == 1);
|
||||
const auto & img0 = imgs.entries.front();
|
||||
// Compute n_pos matching SSCP output: two stride-2 convs
|
||||
int n_pos = img0->nx;
|
||||
for (int i = 0; i < 2; i++) { n_pos = (n_pos - 1) / 2 + 1; }
|
||||
|
||||
// Chunked local attention: blocked causal mask and RPE
|
||||
const int chunk_size = 12;
|
||||
const int max_past = 12;
|
||||
const int context_size = chunk_size + max_past;
|
||||
const int num_blocks = (n_pos + chunk_size - 1) / chunk_size;
|
||||
|
||||
// Blocked causal attention mask: [context_size, chunk_size, num_blocks]
|
||||
{
|
||||
std::vector<float> mask(context_size * chunk_size * num_blocks, -1e9f);
|
||||
for (int b = 0; b < num_blocks; b++) {
|
||||
for (int q = 0; q < chunk_size; q++) {
|
||||
int gq = b * chunk_size + q;
|
||||
for (int k = 0; k < context_size; k++) {
|
||||
int gk = b * chunk_size - max_past + k;
|
||||
if (gq < n_pos && gk >= 0 && gk < n_pos && gk <= gq && (gq - gk) < max_past) {
|
||||
mask[k + q * context_size + b * context_size * chunk_size] = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
set_input_f32("kq_mask", mask);
|
||||
}
|
||||
|
||||
// Sinusoidal RPE: 13 positions [12, 11, ..., 0]
|
||||
{
|
||||
const int n_embd = ctx->model.hparams.n_embd;
|
||||
const int num_timescales = n_embd / 2;
|
||||
const float log_timescale_increment = logf(10000.0f) / std::max(num_timescales - 1, 1);
|
||||
const int rpe_len = max_past + 1;
|
||||
std::vector<float> pos_emb(n_embd * rpe_len, 0.0f);
|
||||
for (int p = 0; p < rpe_len; p++) {
|
||||
float position = (float)(max_past - p);
|
||||
for (int i = 0; i < num_timescales; i++) {
|
||||
float inv_ts = expf(-(float)i * log_timescale_increment);
|
||||
float scaled = position * inv_ts;
|
||||
pos_emb[p * n_embd + i] = sinf(scaled);
|
||||
pos_emb[p * n_embd + i + num_timescales] = cosf(scaled);
|
||||
}
|
||||
}
|
||||
set_input_f32("pos_emb", pos_emb);
|
||||
}
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LFM2A:
|
||||
{
|
||||
GGML_ASSERT(imgs.entries.size() == 1);
|
||||
@@ -3501,8 +3682,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
return ctx->model.mm_model_proj->ne[1];
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
return ctx->model.mm_fc_w->ne[1];
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
return ctx->model.mm_2_w->ne[1];
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_LFM2:
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
case PROJECTOR_TYPE_PADDLEOCR:
|
||||
@@ -3516,6 +3698,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
return ctx->model.mm_fc_w->ne[1];
|
||||
case PROJECTOR_TYPE_LFM2A:
|
||||
return ctx->model.position_embeddings->ne[0];
|
||||
case PROJECTOR_TYPE_GEMMA4A:
|
||||
return ctx->model.hparams.projection_dim;
|
||||
case PROJECTOR_TYPE_GLM4V:
|
||||
return ctx->model.mm_ffn_down_w->ne[1];
|
||||
default:
|
||||
@@ -3552,6 +3736,7 @@ bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
|
||||
switch (ctx->proj_type()) {
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
|
||||
288
tools/mtmd/models/gemma4a.cpp
Normal file
288
tools/mtmd/models/gemma4a.cpp
Normal file
@@ -0,0 +1,288 @@
|
||||
/**
|
||||
* Gemma 4 Audio Conformer Encoder (clip_graph_gemma4a)
|
||||
*
|
||||
* Architecture: Conformer with dual half-step FFN, full self-attention
|
||||
* with sinusoidal RPE, depthwise light conv, and output projection.
|
||||
*/
|
||||
|
||||
#include "models.h"
|
||||
#include <cmath>
|
||||
|
||||
ggml_cgraph * clip_graph_gemma4a::build() {
|
||||
const float res_weight = 0.5f;
|
||||
const float norm_eps = 1e-6f;
|
||||
|
||||
// 1. Input
|
||||
ggml_tensor * inp = build_inp_raw(1);
|
||||
auto * cur = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
|
||||
|
||||
// 2. Subsampling Conv2D (symmetric padding=1, matching PyTorch)
|
||||
{
|
||||
for (int i = 0; i < 2; i++) {
|
||||
cur = ggml_conv_2d(ctx0, model.sscp_conv_w[i], cur, 2, 2, 1, 1, 1, 1);
|
||||
if (model.sscp_conv_b[i]) {
|
||||
cur = ggml_add(ctx0, cur, model.sscp_conv_b[i]);
|
||||
}
|
||||
// nn.LayerNorm(channels): permute ch to ne[0], normalize, permute back
|
||||
if (model.sscp_norm_w[i]) {
|
||||
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
|
||||
cur = ggml_norm(ctx0, cur, norm_eps);
|
||||
cur = ggml_mul(ctx0, cur, model.sscp_norm_w[i]);
|
||||
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
|
||||
}
|
||||
cur = ggml_relu(ctx0, cur);
|
||||
}
|
||||
// Flatten [freq, time, ch, 1] -> [ch*freq, time]
|
||||
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
|
||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2]);
|
||||
if (model.sscp_inp_proj_w) {
|
||||
cur = build_mm(model.sscp_inp_proj_w, cur);
|
||||
if (model.sscp_inp_proj_b) {
|
||||
cur = ggml_add(ctx0, cur, model.sscp_inp_proj_b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t n_pos = cur->ne[1];
|
||||
|
||||
// Chunked local attention parameters
|
||||
const int64_t C = 12; // chunk_size
|
||||
const int64_t P = 12; // max_past_horizon (context_left - 1)
|
||||
const int64_t S = C + P; // context_size = 24
|
||||
const int64_t R = P + 1; // RPE positions = 13
|
||||
const int64_t B = (n_pos + C - 1) / C; // num_blocks
|
||||
const int64_t Np = B * C; // padded sequence length
|
||||
const int64_t pad_seq = Np - n_pos;
|
||||
|
||||
// Input tensors: blocked RPE and blocked attention mask
|
||||
ggml_tensor * pos_emb = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_head * d_head, R);
|
||||
ggml_set_name(pos_emb, "pos_emb");
|
||||
ggml_set_input(pos_emb);
|
||||
|
||||
ggml_tensor * kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, S, C, B);
|
||||
ggml_set_name(kq_mask, "kq_mask");
|
||||
ggml_set_input(kq_mask);
|
||||
|
||||
// 3. Conformer Blocks
|
||||
for (int il = 0; il < hparams.n_layer; il++) {
|
||||
const auto & layer = model.layers[il];
|
||||
auto * residual = cur;
|
||||
|
||||
// FFN 1 (half-step)
|
||||
if (layer.ff_norm_w && layer.ff_up_w && layer.ff_down_w) {
|
||||
cur = build_norm(cur, layer.ff_norm_w, nullptr, NORM_TYPE_RMS, norm_eps, il);
|
||||
cur = build_ffn(cur,
|
||||
layer.ff_up_w, nullptr, nullptr, nullptr,
|
||||
layer.ff_down_w, nullptr, FFN_SILU, il);
|
||||
if (layer.ff_post_norm_w) {
|
||||
cur = build_norm(cur, layer.ff_post_norm_w, nullptr, NORM_TYPE_RMS, norm_eps, il);
|
||||
}
|
||||
residual = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, res_weight));
|
||||
}
|
||||
|
||||
// Chunked local self-attention with RPE
|
||||
if (layer.q_w && layer.k_w && layer.v_w && layer.o_w) {
|
||||
const float q_scale = (1.0f / sqrtf((float)d_head)) / logf(2.0f);
|
||||
const float k_scale = logf(1.0f + expf(1.0f)) / logf(2.0f);
|
||||
const float softcap = 50.0f;
|
||||
|
||||
ggml_tensor * attn_norm_w = layer.attn_pre_norm_w ? layer.attn_pre_norm_w : layer.ln_1_w;
|
||||
cur = attn_norm_w
|
||||
? build_norm(residual, attn_norm_w, nullptr, NORM_TYPE_RMS, norm_eps, il)
|
||||
: residual;
|
||||
|
||||
ggml_tensor * Qcur = build_mm(layer.q_w, cur);
|
||||
ggml_tensor * Kcur = build_mm(layer.k_w, cur);
|
||||
ggml_tensor * Vcur = build_mm(layer.v_w, cur);
|
||||
|
||||
// [n_embd, n_pos] -> [D, H, N]
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
|
||||
|
||||
// Q/K scaling
|
||||
Qcur = ggml_scale(ctx0, Qcur, q_scale);
|
||||
if (layer.per_dim_scale_w) {
|
||||
Qcur = ggml_mul(ctx0, Qcur, ggml_reshape_3d(ctx0, layer.per_dim_scale_w, d_head, 1, 1));
|
||||
}
|
||||
Kcur = ggml_scale(ctx0, Kcur, k_scale);
|
||||
if (layer.per_dim_k_scale_w) {
|
||||
Kcur = ggml_mul(ctx0, Kcur, ggml_reshape_3d(ctx0, layer.per_dim_k_scale_w, d_head, 1, 1));
|
||||
}
|
||||
|
||||
// Q blocking: [D, H, N] -> pad to Np -> reshape [D, H, C, B]
|
||||
// ggml permute: ne[ax_i] = src->ne[i], so (0,3,1,2) sends H->3, C->1, B->2
|
||||
Qcur = ggml_pad(ctx0, Qcur, 0, 0, pad_seq, 0); // [D, H, Np]
|
||||
Qcur = ggml_reshape_4d(ctx0, Qcur, d_head, n_head, C, B); // [D, H, C, B]
|
||||
Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 3, 1, 2)); // [D, C, B, H]
|
||||
|
||||
// K/V block context extraction via overlapping view:
|
||||
// Pad to S*B elements, roll right by P to create left-padding,
|
||||
// then view with stride C in the block dimension (overlapping windows).
|
||||
auto extract_blocks = [&](ggml_tensor * t) -> ggml_tensor * {
|
||||
// [D, H, N] -> pad to S*B -> roll right by P -> cont (materialize)
|
||||
const int64_t pad_kv = S * B - n_pos;
|
||||
t = ggml_pad(ctx0, t, 0, 0, pad_kv, 0); // [D, H, S*B]
|
||||
t = ggml_roll(ctx0, t, 0, 0, P, 0); // left-pad by P
|
||||
t = ggml_cont(ctx0, t); // materialize roll (removes view offset)
|
||||
// Overlapping view: stride for B dim is C positions, not S
|
||||
// ne = [D, H, S, B], data_size = D*H*S*B*sizeof = source_nbytes (exact fit)
|
||||
// nb1=D*sizeof, nb2=D*H*sizeof, nb3=C*D*H*sizeof (overlap: C < S)
|
||||
t = ggml_view_4d(ctx0, t, d_head, n_head, S, B,
|
||||
t->nb[1], t->nb[2], C * t->nb[2], 0);
|
||||
t = ggml_cont(ctx0, t); // materialize overlapping windows
|
||||
return t;
|
||||
};
|
||||
|
||||
ggml_tensor * Kblk = extract_blocks(Kcur);
|
||||
// [D, H, S, B] -> [D, S, B, H] via permute(0,3,1,2)
|
||||
Kblk = ggml_cont(ctx0, ggml_permute(ctx0, Kblk, 0, 3, 1, 2));
|
||||
|
||||
ggml_tensor * Vblk = extract_blocks(Vcur);
|
||||
// [D, H, S, B] -> [S, D, B, H] via permute(1,3,0,2)
|
||||
Vblk = ggml_cont(ctx0, ggml_permute(ctx0, Vblk, 1, 3, 0, 2));
|
||||
|
||||
// Content attention: Q @ K^T
|
||||
// Kblk=[D,S,B,H], Qcur=[D,C,B,H] -> mul_mat contracts on D -> [S,C,B,H]
|
||||
ggml_tensor * matrix_ac = ggml_mul_mat(ctx0, Kblk, Qcur);
|
||||
|
||||
// Relative position attention
|
||||
if (layer.attn_k_rel_w) {
|
||||
// RPE: [n_embd, R] -> project -> [D, H, R] -> [D, R, H]
|
||||
auto * p = ggml_mul_mat(ctx0, layer.attn_k_rel_w, pos_emb);
|
||||
p = ggml_reshape_3d(ctx0, p, d_head, n_head, R);
|
||||
p = ggml_cont(ctx0, ggml_permute(ctx0, p, 0, 2, 1, 3)); // [D, R, H]
|
||||
|
||||
// Q_flat @ RPE^T: [D, C*B, H] @ [D, R, H] -> [R, C*B, H]
|
||||
auto * Q_flat = ggml_reshape_3d(ctx0, Qcur, d_head, C * B, n_head);
|
||||
auto * matrix_bd = ggml_mul_mat(ctx0, p, Q_flat); // [R, C*B, H]
|
||||
matrix_bd = ggml_reshape_4d(ctx0, matrix_bd, R, C, B, n_head); // [R, C, B, H]
|
||||
|
||||
// Blocked relative shift (appendix B of Transformer-XL)
|
||||
{
|
||||
matrix_bd = ggml_pad(ctx0, matrix_bd, S + 1 - R, 0, 0, 0); // [S+1, C, B, H]
|
||||
matrix_bd = ggml_reshape_3d(ctx0, matrix_bd, (S + 1) * C, B, n_head);
|
||||
matrix_bd = ggml_view_3d(ctx0, matrix_bd,
|
||||
C * S, B, n_head,
|
||||
matrix_bd->nb[1], matrix_bd->nb[2], 0);
|
||||
matrix_bd = ggml_cont(ctx0, matrix_bd); // [C*S, B, H]
|
||||
matrix_bd = ggml_reshape_4d(ctx0, matrix_bd, S, C, B, n_head); // [S, C, B, H]
|
||||
}
|
||||
|
||||
matrix_ac = ggml_add(ctx0, matrix_ac, matrix_bd);
|
||||
}
|
||||
|
||||
auto * scores = matrix_ac; // [S, C, B, H]
|
||||
|
||||
// Softcap
|
||||
scores = ggml_scale(ctx0, scores, 1.0f / softcap);
|
||||
scores = ggml_tanh(ctx0, scores);
|
||||
scores = ggml_scale(ctx0, scores, softcap);
|
||||
|
||||
// Blocked attention mask: [S, C, B] broadcasts over H
|
||||
scores = ggml_add(ctx0, scores, kq_mask);
|
||||
|
||||
ggml_tensor * attn = ggml_soft_max(ctx0, scores);
|
||||
|
||||
// attn @ V: [S,C,B,H] @ [S,D,B,H] -> [D,C,B,H]
|
||||
ggml_tensor * x = ggml_mul_mat(ctx0, Vblk, attn);
|
||||
|
||||
// [D,C,B,H] -> [D,H,C,B] via permute(0,2,3,1) -> flatten -> trim
|
||||
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 2, 3, 1));
|
||||
x = ggml_cont_2d(ctx0, x, d_head * n_head, C * B);
|
||||
if (pad_seq > 0) {
|
||||
x = ggml_view_2d(ctx0, x, d_head * n_head, n_pos, x->nb[1], 0);
|
||||
x = ggml_cont(ctx0, x);
|
||||
}
|
||||
|
||||
x = build_mm(layer.o_w, x);
|
||||
if (layer.o_b) { x = ggml_add(ctx0, x, layer.o_b); }
|
||||
|
||||
if (layer.attn_post_norm_w) {
|
||||
x = build_norm(x, layer.attn_post_norm_w, nullptr, NORM_TYPE_RMS, norm_eps, il);
|
||||
}
|
||||
residual = ggml_add(ctx0, residual, x);
|
||||
}
|
||||
|
||||
// Convolution Module
|
||||
if (layer.norm_conv_w && layer.conv_pw1_w && layer.conv_dw_w && layer.conv_pw2_w) {
|
||||
cur = build_norm(residual, layer.norm_conv_w, nullptr, NORM_TYPE_RMS, norm_eps, il);
|
||||
auto * x = build_mm(layer.conv_pw1_w, cur);
|
||||
|
||||
// GLU
|
||||
{
|
||||
int64_t d = x->ne[0] / 2;
|
||||
ggml_tensor * gate = ggml_sigmoid(ctx0,
|
||||
ggml_cont(ctx0, ggml_view_2d(ctx0, x, d, x->ne[1], x->nb[1], d * x->nb[0])));
|
||||
x = ggml_mul(ctx0,
|
||||
ggml_view_2d(ctx0, x, d, x->ne[1], x->nb[1], 0), gate);
|
||||
x = ggml_cont(ctx0, ggml_transpose(ctx0, x));
|
||||
}
|
||||
|
||||
// Causal depthwise Conv1D via ggml_ssm_conv (pad+roll for left-only padding).
|
||||
x = ggml_pad(ctx0, x, 4, 0, 0, 0);
|
||||
x = ggml_roll(ctx0, x, 4, 0, 0, 0);
|
||||
x = ggml_ssm_conv(ctx0, x, layer.conv_dw_w);
|
||||
if (layer.conv_dw_b) {
|
||||
x = ggml_add(ctx0, x, layer.conv_dw_b);
|
||||
}
|
||||
|
||||
if (layer.conv_norm_w) {
|
||||
x = ggml_rms_norm(ctx0, x, norm_eps);
|
||||
x = ggml_mul(ctx0, x, layer.conv_norm_w);
|
||||
}
|
||||
x = ggml_silu(ctx0, x);
|
||||
x = build_mm(layer.conv_pw2_w, x);
|
||||
residual = ggml_add(ctx0, residual, x);
|
||||
}
|
||||
|
||||
// FFN 2 (half-step)
|
||||
if (layer.ff_norm_1_w && layer.ff_up_1_w && layer.ff_down_1_w) {
|
||||
cur = build_norm(residual, layer.ff_norm_1_w, nullptr, NORM_TYPE_RMS, norm_eps, il);
|
||||
cur = build_ffn(cur,
|
||||
layer.ff_up_1_w, nullptr, nullptr, nullptr,
|
||||
layer.ff_down_1_w, nullptr, FFN_SILU, il);
|
||||
if (layer.ff_post_norm_1_w) {
|
||||
cur = build_norm(cur, layer.ff_post_norm_1_w, nullptr, NORM_TYPE_RMS, norm_eps, il);
|
||||
}
|
||||
residual = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, res_weight));
|
||||
}
|
||||
|
||||
// Layer output norm
|
||||
cur = layer.ln_2_w
|
||||
? build_norm(residual, layer.ln_2_w, nullptr, NORM_TYPE_RMS, norm_eps, il)
|
||||
: residual;
|
||||
|
||||
}
|
||||
|
||||
// 4. Output Projection
|
||||
if (model.audio_out_proj_w) {
|
||||
cur = build_mm(model.audio_out_proj_w, cur);
|
||||
if (model.audio_out_proj_b) {
|
||||
cur = ggml_add(ctx0, cur, model.audio_out_proj_b);
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Audio Multimodal Embedder
|
||||
cur = ggml_rms_norm(ctx0, cur, norm_eps);
|
||||
if (model.mm_soft_emb_norm_w) {
|
||||
cur = ggml_mul(ctx0, cur, model.mm_soft_emb_norm_w);
|
||||
}
|
||||
if (model.mm_input_proj_w) {
|
||||
cur = build_mm(model.mm_input_proj_w, cur);
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
return gf;
|
||||
}
|
||||
|
||||
ggml_tensor * clip_graph_gemma4a::build_mm(ggml_tensor * w, ggml_tensor * x) const {
    // Matrix multiply with optional clamping of the input and output ranges.
    // Weights registered in clamp_info_map carry per-tensor (min, max) bounds;
    // all other weights go through a plain ggml_mul_mat.
    const auto entry = model.clamp_info_map.find(w->name);
    if (entry == model.clamp_info_map.end()) {
        // no clamp info recorded for this weight -> plain matmul
        return ggml_mul_mat(ctx0, w, x);
    }

    const auto & info = entry->second;
    ggml_tensor * res = ggml_mul_mat(ctx0, w, ggml_clamp(ctx0, x, info.inp_min, info.inp_max));
    return ggml_clamp(ctx0, res, info.out_min, info.out_max);
}
|
||||
@@ -103,6 +103,12 @@ struct clip_graph_conformer : clip_graph {
|
||||
ggml_cgraph * build() override;
|
||||
};
|
||||
|
||||
struct clip_graph_gemma4a : clip_graph {
|
||||
clip_graph_gemma4a(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||
ggml_cgraph * build() override;
|
||||
ggml_tensor * build_mm(ggml_tensor * w, ggml_tensor * x) const override;
|
||||
};
|
||||
|
||||
struct clip_graph_glm4v : clip_graph {
|
||||
clip_graph_glm4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||
ggml_cgraph * build() override;
|
||||
@@ -146,6 +152,11 @@ struct clip_graph_mobilenetv5 : clip_graph {
|
||||
const mobilenetv5_block & block);
|
||||
};
|
||||
|
||||
struct clip_graph_qwen3a : clip_graph {
|
||||
clip_graph_qwen3a(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||
ggml_cgraph * build() override;
|
||||
};
|
||||
|
||||
struct clip_graph_kimik25 : clip_graph {
|
||||
clip_graph_kimik25(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||
ggml_cgraph * build() override;
|
||||
|
||||
68
tools/mtmd/models/qwen3a.cpp
Normal file
68
tools/mtmd/models/qwen3a.cpp
Normal file
@@ -0,0 +1,68 @@
|
||||
#include "models.h"
|
||||
|
||||
// Build the Qwen3 audio encoder graph: conv2d frontend (time/mel downsampling)
// -> linear projection to n_embd -> ViT transformer with learned position
// embeddings -> 2-layer MLP projector producing the final audio embeddings.
ggml_cgraph * clip_graph_qwen3a::build() {
    ggml_tensor * inp = build_inp_raw(1); // raw mel spectrogram input, 1 channel

    // conv2d block
    // TODO: do we need to split by chunks of n_window each like on transformers impl?
    {
        // three stride-2 convolutions: each halves both axes (overall /8 on mel axis)
        inp = ggml_conv_2d(ctx0, model.conv2d_1_w, inp, 2, 2, 1, 1, 1, 1);
        inp = ggml_add(ctx0, inp, model.conv2d_1_b);
        inp = ggml_gelu_erf(ctx0, inp);

        inp = ggml_conv_2d(ctx0, model.conv2d_2_w, inp, 2, 2, 1, 1, 1, 1);
        inp = ggml_add(ctx0, inp, model.conv2d_2_b);
        inp = ggml_gelu_erf(ctx0, inp);

        inp = ggml_conv_2d(ctx0, model.conv2d_3_w, inp, 2, 2, 1, 1, 1, 1);
        inp = ggml_add(ctx0, inp, model.conv2d_3_b);
        inp = ggml_gelu_erf(ctx0, inp);

        // inp [n_pos, n_mels/8, channels, 1] (W, H, C, N)
        cb(inp, "after_conv_blocks", -1);

        const int64_t n_pos_after_conv = inp->ne[0];
        const int64_t n_mel_after_conv = inp->ne[1]; // 128/8 = 16

        // flatten the channel and mel axes into one feature vector per time step
        // NOTE(review): assumes permute(0,2,3,1) places the mel axis at ne[3]
        // before the reshape — confirm against the transformers reference impl
        inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 3, 1));
        inp = ggml_reshape_2d(ctx0, inp, n_pos_after_conv, n_mel_after_conv * inp->ne[3]); // [n_pos, 7680]
        inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); // [7680, n_pos]

        // project to n_embd
        inp = ggml_mul_mat(ctx0, model.conv_out_w, inp);
        if (model.conv_out_b) {
            inp = ggml_add(ctx0, inp, model.conv_out_b);
        }
        cb(inp, "after_conv_out", -1);
    }

    auto n_pos = inp->ne[1];

    // take the first n_pos rows of the learned position embedding table
    // (view, no copy; assumes the table has at least n_pos entries)
    ggml_tensor * pos_embd_selected = ggml_view_2d(
        ctx0, model.position_embeddings,
        model.position_embeddings->ne[0], n_pos,
        model.position_embeddings->nb[1], 0
    );
    ggml_tensor * cur = build_vit(
        inp, n_pos,
        NORM_TYPE_NORMAL,
        hparams.ffn_op,
        pos_embd_selected,
        nullptr);

    cb(cur, "after_transformer", -1);

    // projector: 2-layer MLP with GELU(erf) activation mapping encoder output
    // to the language model's embedding space
    cur = build_ffn(cur,
        model.mm_1_w, model.mm_1_b,
        nullptr, nullptr,
        model.mm_2_w, model.mm_2_b,
        FFN_GELU_ERF,
        -1);

    cb(cur, "projected", -1);

    ggml_build_forward_expand(gf, cur);

    return gf;
}
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
|
||||
// some of the code here is copied from whisper.cpp
|
||||
|
||||
@@ -37,23 +38,36 @@ void mtmd_audio_cache::fill_mel_filterbank_matrix(int n_mel,
|
||||
float fmin,
|
||||
float fmax,
|
||||
bool slaney_area_norm,
|
||||
float scale) {
|
||||
float scale,
|
||||
bool use_htk) {
|
||||
GGML_ASSERT(n_mel > 0 && n_fft > 1);
|
||||
if (fmax <= 0.0f) {
|
||||
fmax = 0.5f * sample_rate;
|
||||
}
|
||||
|
||||
// Slaney scale (matches librosa default)
|
||||
const double min_log_hz = 1000.0;
|
||||
const double lin_slope = 3 / 200.;
|
||||
const double min_log_mel = min_log_hz * lin_slope;
|
||||
const double log_step = log(6.4) / 27.0;
|
||||
auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
|
||||
return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
|
||||
};
|
||||
auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
|
||||
return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
|
||||
};
|
||||
std::function<double(double)> hz_to_mel;
|
||||
std::function<double(double)> mel_to_hz;
|
||||
|
||||
if (use_htk) {
|
||||
hz_to_mel = [](const double f_hz) -> double {
|
||||
return 2595.0 * log10(1.0 + f_hz / 700.0);
|
||||
};
|
||||
mel_to_hz = [](const double m) -> double {
|
||||
return 700.0 * (pow(10.0, m / 2595.0) - 1.0);
|
||||
};
|
||||
} else {
|
||||
// Slaney scale (matches librosa default)
|
||||
const double min_log_hz = 1000.0;
|
||||
const double lin_slope = 3 / 200.;
|
||||
const double min_log_mel = min_log_hz * lin_slope;
|
||||
const double log_step = log(6.4) / 27.0;
|
||||
hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
|
||||
return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
|
||||
};
|
||||
mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
|
||||
return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
|
||||
};
|
||||
}
|
||||
|
||||
// infer N_fft from n_fft_bins
|
||||
const double bin_hz_step = double(sample_rate) / double(n_fft);
|
||||
@@ -257,10 +271,13 @@ struct filter_params {
|
||||
int32_t hann_window_size;
|
||||
int32_t hop_length;
|
||||
int32_t sample_rate;
|
||||
bool center_padding = false;
|
||||
float preemph = 0.f;
|
||||
bool no_padding = false;
|
||||
bool center_padding = false;
|
||||
float preemph = 0.f;
|
||||
bool use_natural_log = false;
|
||||
bool norm_per_feature = false;
|
||||
bool use_magnitude = false; // |X| instead of |X|^2
|
||||
float mel_floor = 5.960464477539063e-08f;
|
||||
};
|
||||
|
||||
static void log_mel_spectrogram_worker_thread(int ith,
|
||||
@@ -301,10 +318,10 @@ static void log_mel_spectrogram_worker_thread(int ith,
|
||||
// FFT
|
||||
fft(cache, fft_in.data(), frame_size, fft_out.data());
|
||||
|
||||
// Calculate modulus^2 of complex numbers
|
||||
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
|
||||
// Calculate modulus^2 (power) or modulus (magnitude)
|
||||
for (int j = 0; j < n_fft_bins; j++) {
|
||||
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
|
||||
float power = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
|
||||
fft_out[j] = params.use_magnitude ? sqrtf(power) : power;
|
||||
}
|
||||
|
||||
// mel spectrogram
|
||||
@@ -324,9 +341,10 @@ static void log_mel_spectrogram_worker_thread(int ith,
|
||||
for (; k < n_fft_bins; k++) {
|
||||
sum += fft_out[k] * filters.data[j * n_fft_bins + k];
|
||||
}
|
||||
sum = std::max(sum, (double)params.mel_floor);
|
||||
sum = params.use_natural_log
|
||||
? log(sum + 5.960464477539063e-08)
|
||||
: log10(std::max(sum, 1e-10));
|
||||
? log(sum)
|
||||
: log10(sum);
|
||||
out.data[j * out.n_len + i] = sum;
|
||||
}
|
||||
}
|
||||
@@ -360,7 +378,12 @@ static bool log_mel_spectrogram(
|
||||
|
||||
// Padding
|
||||
std::vector<float> samples_padded;
|
||||
if (params.center_padding) {
|
||||
if (params.no_padding) {
|
||||
// no padding, use samples as-is
|
||||
samples_padded = std::vector<float>(samples, samples + n_samples);
|
||||
samples = samples_padded.data();
|
||||
n_samples = samples_padded.size();
|
||||
} else if (params.center_padding) {
|
||||
const auto pad_amount = frame_size / 2;
|
||||
samples_padded = std::vector<float>(n_samples + 2 * pad_amount, 0);
|
||||
std::copy(samples, samples + n_samples, samples_padded.data() + pad_amount);
|
||||
@@ -464,8 +487,8 @@ static bool log_mel_spectrogram(
|
||||
out.data[i * out.n_len + j] = 0.0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// clamping and normalization
|
||||
} else if (!params.no_padding) {
|
||||
// Whisper-style clamping and normalization (NOT used by Gemma4)
|
||||
double mmax = -1e20;
|
||||
for (int i = 0; i < out.n_mel*out.n_len; i++) {
|
||||
if (out.data[i] > mmax) {
|
||||
@@ -627,6 +650,87 @@ bool mtmd_audio_preprocessor_conformer::preprocess(const float *
|
||||
return true;
|
||||
}
|
||||
|
||||
//
|
||||
// mtmd_audio_preprocessor_gemma4a
|
||||
//
|
||||
|
||||
// Precompute the FFT twiddle tables, the analysis window, and the mel
// filterbank used by preprocess(). Must run once before the first call to
// preprocess() (preprocess asserts that these caches are non-empty).
void mtmd_audio_preprocessor_gemma4a::initialize() {
    cache.fill_sin_cos_table(hparams.audio_n_fft);

    // Standard periodic Hann window, zero-padded to FFT size
    // (only the first audio_window_len entries are non-zero)
    cache.hann_window.assign(hparams.audio_n_fft, 0.0f);
    for (uint32_t i = 0; i < (uint32_t)hparams.audio_window_len; i++) {
        cache.hann_window[i] = 0.5f - 0.5f * cosf((2.0f * (float)M_PI * i) / hparams.audio_window_len);
    }

    // HTK mel scale, no Slaney area normalization
    cache.fill_mel_filterbank_matrix(
        hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate,
        0.0f, hparams.audio_sample_rate / 2.0f,
        /*slaney_area_norm=*/ false,
        /*scale=*/ 1.0f,
        /*use_htk=*/ true
    );
}
|
||||
|
||||
// Convert a raw PCM waveform into log-mel spectrogram chunks for the Gemma
// audio encoder. The waveform is split into 30-second chunks; each chunk is
// manually padded (semicausal left pad + right pad) so the frame count matches
// the PyTorch reference, then trimmed back to that exact count.
// Returns false on empty input or if spectrogram computation fails.
bool mtmd_audio_preprocessor_gemma4a::preprocess(const float * samples,
                                                 size_t n_samples,
                                                 std::vector<mtmd_audio_mel> & output) {
    if (n_samples == 0) {
        return false;
    }

    // initialize() must have been called to fill these caches
    GGML_ASSERT(!cache.sin_vals.empty());
    GGML_ASSERT(!cache.cos_vals.empty());
    GGML_ASSERT(!cache.filters.data.empty());

    filter_params params;
    params.n_mel = hparams.n_mel_bins;
    params.n_fft_bins = 1 + (hparams.audio_n_fft / 2);
    params.hann_window_size = hparams.audio_n_fft; // window is zero-padded to FFT size
    params.hop_length = hparams.audio_hop_len;
    params.sample_rate = hparams.audio_sample_rate;
    params.no_padding = true;      // padding is done manually below, not by log_mel_spectrogram
    params.center_padding = false;
    params.preemph = 0.0f;         // no pre-emphasis filter
    params.use_natural_log = true; // ln() rather than log10()
    params.use_magnitude = true;   // |X| rather than |X|^2
    params.mel_floor = 0.001f;
    params.norm_per_feature = false;

    // Split into 30-second chunks (model context limit, ~750 tokens each)
    const size_t chunk_samples = 30 * hparams.audio_sample_rate;
    for (size_t off = 0; off < n_samples; off += chunk_samples) {
        const float * chunk_ptr = samples + off;
        size_t chunk_len = std::min(chunk_samples, n_samples - off);

        // Semicausal left-padding + right-padding to match PyTorch frame count
        const int pad_left = hparams.audio_window_len / 2;
        const int fft_size = hparams.audio_n_fft;
        const int hop = hparams.audio_hop_len;
        const int n_with_left = (int)chunk_len + pad_left;
        // PyTorch: unfold(size=frame_length+1, step=hop) on semicausal-padded waveform
        // NOTE(review): assumes chunk_len + pad_left >= audio_window_len + 1;
        // a very short final chunk would make pt_frames non-positive — confirm
        // upstream guarantees a minimum chunk length
        const int pt_frames = (n_with_left - (hparams.audio_window_len + 1)) / hop + 1;
        // total samples needed so the spectrogram yields at least pt_frames frames
        const int n_padded_needed = (pt_frames - 1) * hop + fft_size;
        const int total_pad = std::max((int)(n_padded_needed - (int)chunk_len), pad_left);
        std::vector<float> padded_samples(total_pad + chunk_len, 0.0f);
        std::copy(chunk_ptr, chunk_ptr + chunk_len, padded_samples.data() + pad_left);

        mtmd_audio_mel out_chunk;
        // 4 = worker thread count — presumably; confirm against other preprocessors
        bool ok = log_mel_spectrogram(padded_samples.data(), padded_samples.size(), 4, params, cache, out_chunk);
        if (!ok) {
            return false;
        }

        // Trim to PyTorch frame count
        out_chunk.n_len = std::min(out_chunk.n_len, pt_frames);

        output.push_back(std::move(out_chunk));
    }

    return true;
}
|
||||
|
||||
//
|
||||
// mtmd_audio_streaming_istft implementation
|
||||
//
|
||||
|
||||
@@ -45,7 +45,8 @@ struct mtmd_audio_cache {
|
||||
float fmin = 0.0f, // e.g. 0.0
|
||||
float fmax = -1.0f, // e.g. sr/2; pass -1 for auto
|
||||
bool slaney_area_norm = true,
|
||||
float scale = 1.0f // optional extra scaling
|
||||
float scale = 1.0f,
|
||||
bool use_htk = false
|
||||
);
|
||||
};
|
||||
|
||||
@@ -77,6 +78,15 @@ struct mtmd_audio_preprocessor_conformer : mtmd_audio_preprocessor {
|
||||
mtmd_audio_cache cache;
|
||||
};
|
||||
|
||||
struct mtmd_audio_preprocessor_gemma4a : mtmd_audio_preprocessor {
|
||||
mtmd_audio_preprocessor_gemma4a(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {}
|
||||
void initialize() override;
|
||||
bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) override;
|
||||
|
||||
private:
|
||||
mtmd_audio_cache cache;
|
||||
};
|
||||
|
||||
//
|
||||
// streaming ISTFT - converts spectrogram frames back to audio one frame at a time
|
||||
//
|
||||
|
||||
@@ -274,7 +274,8 @@ int32_t mtmd_helper_decode_image_chunk(
|
||||
batch_embd.set_position_normal(n_past, seq_id);
|
||||
}
|
||||
|
||||
if (mtmd_decode_use_non_causal(ctx)) {
|
||||
const bool use_non_causal = mtmd_decode_use_non_causal(ctx, chunk);
|
||||
if (use_non_causal) {
|
||||
llama_set_causal_attn(lctx, false);
|
||||
// TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
|
||||
}
|
||||
@@ -302,7 +303,7 @@ int32_t mtmd_helper_decode_image_chunk(
|
||||
n_past += mtmd_input_chunk_get_n_pos(chunk);
|
||||
*new_n_past = n_past;
|
||||
|
||||
if (mtmd_decode_use_non_causal(ctx)) {
|
||||
if (use_non_causal) {
|
||||
llama_set_causal_attn(lctx, true);
|
||||
}
|
||||
return 0;
|
||||
|
||||
@@ -198,35 +198,38 @@ struct img_tool {
|
||||
private:
|
||||
// Bilinear resize function
|
||||
static void resize_bilinear(const clip_image_u8 & src, clip_image_u8 & dst, int target_width, int target_height) {
|
||||
GGML_ASSERT(src.nx >= 2 && src.ny >= 2);
|
||||
if (src.nx == 0 || src.ny == 0) { dst.nx = dst.ny = 0; dst.buf.clear(); return; }
|
||||
if (target_width <= 0) target_width = 1;
|
||||
if (target_height <= 0) target_height = 1;
|
||||
|
||||
dst.nx = target_width;
|
||||
dst.ny = target_height;
|
||||
dst.buf.resize(3 * target_width * target_height);
|
||||
|
||||
float x_ratio = static_cast<float>(src.nx - 1) / target_width;
|
||||
float y_ratio = static_cast<float>(src.ny - 1) / target_height;
|
||||
float x_ratio = target_width > 1 ? static_cast<float>(src.nx - 1) / (target_width - 1) : 0.0f;
|
||||
float y_ratio = target_height > 1 ? static_cast<float>(src.ny - 1) / (target_height - 1) : 0.0f;
|
||||
|
||||
for (int y = 0; y < target_height; y++) {
|
||||
for (int x = 0; x < target_width; x++) {
|
||||
float px = x_ratio * x;
|
||||
float py = y_ratio * y;
|
||||
int x_floor = std::min(static_cast<int>(px), src.nx - 2);
|
||||
int y_floor = std::min(static_cast<int>(py), src.ny - 2);
|
||||
float x_lerp = px - x_floor;
|
||||
float y_lerp = py - y_floor;
|
||||
for (int y = 0; y < target_height; ++y) {
|
||||
for (int x = 0; x < target_width; ++x) {
|
||||
float px = x * x_ratio;
|
||||
float py = y * y_ratio;
|
||||
|
||||
for (int c = 0; c < 3; c++) {
|
||||
float top = lerp(
|
||||
static_cast<float>(src.buf[3 * (y_floor * src.nx + x_floor) + c]),
|
||||
static_cast<float>(src.buf[3 * (y_floor * src.nx + (x_floor + 1)) + c]),
|
||||
x_lerp
|
||||
);
|
||||
float bottom = lerp(
|
||||
static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + x_floor) + c]),
|
||||
static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + (x_floor + 1)) + c]),
|
||||
x_lerp
|
||||
);
|
||||
dst.buf[3 * (y * target_width + x) + c] = static_cast<uint8_t>(lerp(top, bottom, y_lerp));
|
||||
int x0 = std::min(static_cast<int>(px), src.nx - 1);
|
||||
int y0 = std::min(static_cast<int>(py), src.ny - 1);
|
||||
int x1 = std::min(x0 + 1, src.nx - 1);
|
||||
int y1 = std::min(y0 + 1, src.ny - 1);
|
||||
|
||||
float xf = px - x0;
|
||||
float yf = py - y0;
|
||||
|
||||
for (int c = 0; c < 3; ++c) {
|
||||
float top = lerp(static_cast<float>(src.buf[3 * (y0 * src.nx + x0) + c]),
|
||||
static_cast<float>(src.buf[3 * (y0 * src.nx + x1) + c]),
|
||||
xf);
|
||||
float bottom = lerp(static_cast<float>(src.buf[3 * (y1 * src.nx + x0) + c]),
|
||||
static_cast<float>(src.buf[3 * (y1 * src.nx + x1) + c]),
|
||||
xf);
|
||||
dst.buf[3 * (y * target_width + x) + c] = static_cast<uint8_t>(lerp(top, bottom, yf));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -455,6 +455,7 @@ struct mtmd_context {
|
||||
// set preprocessor
|
||||
switch (proj) {
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
case PROJECTOR_TYPE_QWEN25O:
|
||||
{
|
||||
// <|audio_bos|> ... (embeddings) ... <|audio_eos|>
|
||||
@@ -484,6 +485,12 @@ struct mtmd_context {
|
||||
{
|
||||
audio_preproc = std::make_unique<mtmd_audio_preprocessor_conformer>(ctx_a);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA4A:
|
||||
{
|
||||
aud_beg = "<|audio>";
|
||||
aud_end = "<audio|>";
|
||||
audio_preproc = std::make_unique<mtmd_audio_preprocessor_gemma4a>(ctx_a);
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error(string_format("%s: unexpected audio projector type %d\n", __func__, proj));
|
||||
}
|
||||
@@ -1010,8 +1017,12 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
|
||||
return ctx->image_embd_v.data();
|
||||
}
|
||||
|
||||
bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
|
||||
switch (ctx->proj_type_v()) {
|
||||
bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
|
||||
auto proj_type = ctx->proj_type_v();
|
||||
if (chunk && chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
|
||||
proj_type = ctx->proj_type_a();
|
||||
}
|
||||
switch (proj_type) {
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
case PROJECTOR_TYPE_GEMMA4V:
|
||||
return true;
|
||||
@@ -1021,6 +1032,10 @@ bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
|
||||
}
|
||||
|
||||
bool mtmd_decode_use_mrope(mtmd_context * ctx) {
|
||||
if (ctx->ctx_v == nullptr && ctx->proj_type_a() == PROJECTOR_TYPE_QWEN3A) {
|
||||
// qwen3-asr
|
||||
return true;
|
||||
}
|
||||
switch (ctx->proj_type_v()) {
|
||||
case PROJECTOR_TYPE_QWEN2VL:
|
||||
case PROJECTOR_TYPE_QWEN25VL:
|
||||
|
||||
@@ -114,7 +114,8 @@ MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
|
||||
MTMD_API void mtmd_free(mtmd_context * ctx);
|
||||
|
||||
// whether we need to set non-causal mask before llama_decode
|
||||
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
|
||||
// if chunk is nullptr, we assume the default case where chunk is an image chunk
|
||||
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk);
|
||||
|
||||
// whether the current model use M-RoPE for llama_decode
|
||||
MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);
|
||||
|
||||
@@ -91,11 +91,14 @@ add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0"
|
||||
add_test_vision "ggml-org/DeepSeek-OCR-GGUF:Q8_0" -p "Free OCR." --chat-template deepseek-ocr
|
||||
add_test_vision "ggml-org/dots.ocr-GGUF:Q8_0" -p "OCR"
|
||||
add_test_vision "ggml-org/HunyuanOCR-GGUF:Q8_0" -p "OCR"
|
||||
add_test_vision "ggml-org/gemma-4-E2B-it-GGUF:Q8_0" --jinja
|
||||
|
||||
add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
|
||||
add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
|
||||
add_test_audio "ggml-org/Voxtral-Mini-3B-2507-GGUF:Q4_K_M"
|
||||
add_test_audio "ggml-org/LFM2-Audio-1.5B-GGUF:Q8_0"
|
||||
add_test_audio "ggml-org/gemma-4-E2B-it-GGUF:Q8_0" --jinja
|
||||
add_test_audio "ggml-org/Qwen3-ASR-0.6B-GGUF:Q8_0"
|
||||
|
||||
# to test the big models, run: ./tests.sh big
|
||||
if [ "$RUN_BIG_TESTS" = true ]; then
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
@@ -926,7 +926,8 @@ void server_models_routes::init_routes() {
|
||||
res_ok(res, {
|
||||
// TODO: add support for this on web UI
|
||||
{"role", "router"},
|
||||
{"max_instances", 4}, // dummy value for testing
|
||||
{"max_instances", params.models_max},
|
||||
{"models_autoload", params.models_autoload},
|
||||
// this is a dummy response to make sure webui doesn't break
|
||||
{"model_alias", "llama-server"},
|
||||
{"model_path", "none"},
|
||||
@@ -935,6 +936,7 @@ void server_models_routes::init_routes() {
|
||||
{"n_ctx", 0},
|
||||
}},
|
||||
{"webui_settings", webui_settings},
|
||||
{"build_info", build_info},
|
||||
});
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -9,6 +9,19 @@ def create_server():
|
||||
server = ServerPreset.router()
|
||||
|
||||
|
||||
def test_router_props():
    """Router /props endpoint reports role, instance limits, and build info."""
    global server
    # Configure the router before starting: cap concurrent instances at 2 and
    # disable automatic model loading, so the /props response must reflect
    # these settings rather than hard-coded values.
    server.models_max = 2
    server.no_models_autoload = True
    server.start()
    res = server.make_request("GET", "/props")
    assert res.status_code == 200
    assert res.body["role"] == "router"
    # max_instances must mirror the configured models_max
    assert res.body["max_instances"] == 2
    assert res.body["models_autoload"] is False
    # build_info is expected to start with "b" followed by the build number
    assert res.body["build_info"].startswith("b")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,success",
|
||||
[
|
||||
|
||||
@@ -9,7 +9,14 @@
|
||||
import { getMessageEditContext } from '$lib/contexts';
|
||||
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
|
||||
import { isLoading, isChatStreaming } from '$lib/stores/chat.svelte';
|
||||
import { autoResizeTextarea, copyToClipboard, isIMEComposing } from '$lib/utils';
|
||||
import {
|
||||
autoResizeTextarea,
|
||||
copyToClipboard,
|
||||
isIMEComposing,
|
||||
deriveAgenticSections
|
||||
} from '$lib/utils';
|
||||
import { AgenticSectionType } from '$lib/enums';
|
||||
import { REASONING_TAGS } from '$lib/constants/agentic';
|
||||
import { tick } from 'svelte';
|
||||
import { fade } from 'svelte/transition';
|
||||
import { Check, X } from '@lucide/svelte';
|
||||
@@ -95,6 +102,49 @@
|
||||
let currentConfig = $derived(config());
|
||||
let isRouter = $derived(isRouterMode());
|
||||
let showRawOutput = $state(false);
|
||||
|
||||
let rawOutputContent = $derived.by(() => {
|
||||
const sections = deriveAgenticSections(message, toolMessages, [], false);
|
||||
const parts: string[] = [];
|
||||
|
||||
for (const section of sections) {
|
||||
switch (section.type) {
|
||||
case AgenticSectionType.REASONING:
|
||||
case AgenticSectionType.REASONING_PENDING:
|
||||
parts.push(`${REASONING_TAGS.START}\n${section.content}\n${REASONING_TAGS.END}`);
|
||||
break;
|
||||
|
||||
case AgenticSectionType.TEXT:
|
||||
parts.push(section.content);
|
||||
break;
|
||||
|
||||
case AgenticSectionType.TOOL_CALL:
|
||||
case AgenticSectionType.TOOL_CALL_PENDING:
|
||||
case AgenticSectionType.TOOL_CALL_STREAMING: {
|
||||
const callObj: Record<string, unknown> = { name: section.toolName };
|
||||
|
||||
if (section.toolArgs) {
|
||||
try {
|
||||
callObj.arguments = JSON.parse(section.toolArgs);
|
||||
} catch {
|
||||
callObj.arguments = section.toolArgs;
|
||||
}
|
||||
}
|
||||
|
||||
parts.push(JSON.stringify(callObj, null, 2));
|
||||
|
||||
if (section.toolResult) {
|
||||
parts.push(`[Tool Result]\n${section.toolResult}`);
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parts.join('\n\n\n');
|
||||
});
|
||||
|
||||
let activeStatsView = $state<ChatMessageStatsView>(ChatMessageStatsView.GENERATION);
|
||||
let statsContainerEl: HTMLDivElement | undefined = $state();
|
||||
|
||||
@@ -252,7 +302,7 @@
|
||||
</div>
|
||||
{:else if message.role === MessageRole.ASSISTANT}
|
||||
{#if showRawOutput}
|
||||
<pre class="raw-output">{messageContent || ''}</pre>
|
||||
<pre class="raw-output">{rawOutputContent || ''}</pre>
|
||||
{:else}
|
||||
<ChatMessageAgenticContent
|
||||
{message}
|
||||
|
||||
@@ -89,6 +89,11 @@
|
||||
key: SETTINGS_KEYS.ASK_FOR_TITLE_CONFIRMATION,
|
||||
label: 'Ask for confirmation before changing conversation title',
|
||||
type: SettingsFieldType.CHECKBOX
|
||||
},
|
||||
{
|
||||
key: SETTINGS_KEYS.TITLE_GENERATION_USE_FIRST_LINE,
|
||||
label: 'Use first non-empty line for conversation title',
|
||||
type: SettingsFieldType.CHECKBOX
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@@ -15,6 +15,18 @@
|
||||
let { logs, connectionTimeMs, defaultExpanded = false, class: className }: Props = $props();
|
||||
|
||||
let isExpanded = $derived(defaultExpanded);
|
||||
|
||||
function formatLogDetails(details: unknown): string {
|
||||
if (details == null) {
|
||||
return '';
|
||||
}
|
||||
|
||||
try {
|
||||
return JSON.stringify(details, null, 2);
|
||||
} catch {
|
||||
return String(details);
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
{#if logs.length > 0}
|
||||
@@ -53,6 +65,16 @@
|
||||
|
||||
<span class="break-all">{log.message}</span>
|
||||
</div>
|
||||
|
||||
{#if log.details !== undefined}
|
||||
<details class="ml-11">
|
||||
<summary class="cursor-pointer text-[10px] text-muted-foreground"> details </summary>
|
||||
|
||||
<pre
|
||||
class="mt-1 overflow-x-auto rounded bg-background/70 p-2 text-[10px] break-all whitespace-pre-wrap text-foreground/80">
|
||||
{formatLogDetails(log.details)}</pre>
|
||||
</details>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
</Collapsible.Content>
|
||||
|
||||
@@ -15,6 +15,11 @@ export const DEFAULT_AGENTIC_CONFIG: AgenticConfig = {
|
||||
maxToolPreviewLines: 25
|
||||
} as const;
|
||||
|
||||
export const REASONING_TAGS = {
|
||||
START: '<think>',
|
||||
END: '</think>'
|
||||
} as const;
|
||||
|
||||
/**
|
||||
* @deprecated Legacy marker tags - only used for migration of old stored messages.
|
||||
* New messages use structured fields (reasoningContent, toolCalls, toolCallId).
|
||||
|
||||
@@ -48,6 +48,26 @@ export const EXPECTED_THEMED_ICON_PAIR_COUNT = 2;
|
||||
/** CORS proxy URL query parameter name */
|
||||
export const CORS_PROXY_URL_PARAM = 'url';
|
||||
|
||||
/** Number of trailing characters to keep visible when partially redacting mcp-session-id */
|
||||
export const MCP_SESSION_ID_VISIBLE_CHARS = 5;
|
||||
|
||||
/** Partial-redaction rules for MCP headers: header name -> visible trailing chars */
|
||||
export const MCP_PARTIAL_REDACT_HEADERS = new Map<string, number>([
|
||||
['mcp-session-id', MCP_SESSION_ID_VISIBLE_CHARS]
|
||||
]);
|
||||
|
||||
/** Header names whose values should be redacted in diagnostic logs */
|
||||
export const REDACTED_HEADERS = new Set([
|
||||
'authorization',
|
||||
'api-key',
|
||||
'cookie',
|
||||
'mcp-session-id',
|
||||
'proxy-authorization',
|
||||
'set-cookie',
|
||||
'x-auth-token',
|
||||
'x-api-key'
|
||||
]);
|
||||
|
||||
/** Human-readable labels for MCP transport types */
|
||||
export const MCP_TRANSPORT_LABELS: Record<MCPTransportType, string> = {
|
||||
[MCPTransportType.WEBSOCKET]: 'WebSocket',
|
||||
|
||||
@@ -15,6 +15,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean |
|
||||
keepStatsVisible: false,
|
||||
showMessageStats: true,
|
||||
askForTitleConfirmation: false,
|
||||
titleGenerationUseFirstLine: false,
|
||||
pasteLongTextToFileLen: 2500,
|
||||
copyTextAttachmentsAsPlainText: false,
|
||||
pdfAsImage: false,
|
||||
@@ -118,6 +119,8 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
|
||||
'Display generation statistics (tokens/second, token count, duration) below each assistant message.',
|
||||
askForTitleConfirmation:
|
||||
'Ask for confirmation before automatically changing conversation title when editing the first message.',
|
||||
titleGenerationUseFirstLine:
|
||||
'Use only the first non-empty line of the prompt to generate the conversation title.',
|
||||
pdfAsImage:
|
||||
'Parse PDF as image instead of text. Automatically falls back to text processing for non-vision models.',
|
||||
disableAutoScroll:
|
||||
|
||||
@@ -15,6 +15,7 @@ export const SETTINGS_KEYS = {
|
||||
ENABLE_CONTINUE_GENERATION: 'enableContinueGeneration',
|
||||
PDF_AS_IMAGE: 'pdfAsImage',
|
||||
ASK_FOR_TITLE_CONFIRMATION: 'askForTitleConfirmation',
|
||||
TITLE_GENERATION_USE_FIRST_LINE: 'titleGenerationUseFirstLine',
|
||||
// Display
|
||||
SHOW_MESSAGE_STATS: 'showMessageStats',
|
||||
SHOW_THOUGHT_IN_PROGRESS: 'showThoughtInProgress',
|
||||
|
||||
@@ -15,7 +15,8 @@ import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||
import {
|
||||
DEFAULT_MCP_CONFIG,
|
||||
DEFAULT_CLIENT_VERSION,
|
||||
DEFAULT_IMAGE_MIME_TYPE
|
||||
DEFAULT_IMAGE_MIME_TYPE,
|
||||
MCP_PARTIAL_REDACT_HEADERS
|
||||
} from '$lib/constants';
|
||||
import {
|
||||
MCPConnectionPhase,
|
||||
@@ -43,9 +44,17 @@ import {
|
||||
buildProxiedUrl,
|
||||
buildProxiedHeaders,
|
||||
getAuthHeaders,
|
||||
sanitizeHeaders,
|
||||
throwIfAborted,
|
||||
isAbortError,
|
||||
createBase64DataUrl
|
||||
createBase64DataUrl,
|
||||
getRequestUrl,
|
||||
getRequestMethod,
|
||||
getRequestBody,
|
||||
summarizeRequestBody,
|
||||
formatDiagnosticErrorMessage,
|
||||
extractJsonRpcMethods,
|
||||
type RequestBodySummary
|
||||
} from '$lib/utils';
|
||||
|
||||
interface ToolResultContentItem {
|
||||
@@ -62,6 +71,16 @@ interface ToolCallResult {
|
||||
_meta?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
interface DiagnosticRequestDetails {
|
||||
url: string;
|
||||
method: string;
|
||||
credentials?: RequestCredentials;
|
||||
mode?: RequestMode;
|
||||
headers: Record<string, string>;
|
||||
body: RequestBodySummary;
|
||||
jsonRpcMethods?: string[];
|
||||
}
|
||||
|
||||
export class MCPService {
|
||||
/**
|
||||
* Create a connection log entry for phase tracking.
|
||||
@@ -87,6 +106,225 @@ export class MCPService {
|
||||
};
|
||||
}
|
||||
|
||||
private static createDiagnosticRequestDetails(
|
||||
input: RequestInfo | URL,
|
||||
init: RequestInit | undefined,
|
||||
baseInit: RequestInit,
|
||||
requestHeaders: Headers,
|
||||
extraRedactedHeaders?: Iterable<string>
|
||||
): DiagnosticRequestDetails {
|
||||
const body = getRequestBody(input, init);
|
||||
const details: DiagnosticRequestDetails = {
|
||||
url: getRequestUrl(input),
|
||||
method: getRequestMethod(input, init, baseInit).toUpperCase(),
|
||||
credentials: init?.credentials ?? baseInit.credentials,
|
||||
mode: init?.mode ?? baseInit.mode,
|
||||
headers: sanitizeHeaders(requestHeaders, extraRedactedHeaders, MCP_PARTIAL_REDACT_HEADERS),
|
||||
body: summarizeRequestBody(body)
|
||||
};
|
||||
const jsonRpcMethods = extractJsonRpcMethods(body);
|
||||
|
||||
if (jsonRpcMethods) {
|
||||
details.jsonRpcMethods = jsonRpcMethods;
|
||||
}
|
||||
|
||||
return details;
|
||||
}
|
||||
|
||||
private static summarizeError(error: unknown): Record<string, unknown> {
|
||||
if (error instanceof Error) {
|
||||
return {
|
||||
name: error.name,
|
||||
message: error.message,
|
||||
cause:
|
||||
error.cause instanceof Error
|
||||
? { name: error.cause.name, message: error.cause.message }
|
||||
: error.cause,
|
||||
stack: error.stack?.split('\n').slice(0, 6).join('\n')
|
||||
};
|
||||
}
|
||||
|
||||
return { value: String(error) };
|
||||
}
|
||||
|
||||
private static getBrowserContext(
|
||||
targetUrl: URL,
|
||||
useProxy: boolean
|
||||
): Record<string, unknown> | undefined {
|
||||
if (typeof window === 'undefined') {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return {
|
||||
location: window.location.href,
|
||||
origin: window.location.origin,
|
||||
protocol: window.location.protocol,
|
||||
isSecureContext: window.isSecureContext,
|
||||
targetOrigin: targetUrl.origin,
|
||||
targetProtocol: targetUrl.protocol,
|
||||
sameOrigin: window.location.origin === targetUrl.origin,
|
||||
useProxy
|
||||
};
|
||||
}
|
||||
|
||||
private static getConnectionHints(
|
||||
targetUrl: URL,
|
||||
config: MCPServerConfig,
|
||||
error: unknown
|
||||
): string[] {
|
||||
const hints: string[] = [];
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
const headerNames = Object.keys(config.headers ?? {});
|
||||
|
||||
if (typeof window !== 'undefined') {
|
||||
if (
|
||||
window.location.protocol === 'https:' &&
|
||||
targetUrl.protocol === 'http:' &&
|
||||
!config.useProxy
|
||||
) {
|
||||
hints.push(
|
||||
'The page is running over HTTPS but the MCP server is HTTP. Browsers often block this as mixed content; enable the proxy or use HTTPS/WSS for the MCP server.'
|
||||
);
|
||||
}
|
||||
|
||||
if (window.location.origin !== targetUrl.origin && !config.useProxy) {
|
||||
hints.push(
|
||||
'This is a cross-origin browser request. If the server is reachable from curl or Node but not from the browser, missing CORS headers are the most likely cause.'
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (headerNames.length > 0) {
|
||||
hints.push(
|
||||
`Custom request headers are configured (${headerNames.join(', ')}). That triggers a CORS preflight, so the server must allow OPTIONS and include the matching Access-Control-Allow-Headers response.`
|
||||
);
|
||||
}
|
||||
|
||||
if (config.credentials && config.credentials !== 'omit') {
|
||||
hints.push(
|
||||
'Credentials are enabled for this connection. Cross-origin credentialed requests need Access-Control-Allow-Credentials: true and cannot use a wildcard Access-Control-Allow-Origin.'
|
||||
);
|
||||
}
|
||||
|
||||
if (message.includes('Failed to fetch')) {
|
||||
hints.push(
|
||||
'"Failed to fetch" is a browser-level network failure. Common causes are CORS rejection, mixed-content blocking, certificate/TLS errors, DNS failures, or nothing listening on the target port.'
|
||||
);
|
||||
}
|
||||
|
||||
return hints;
|
||||
}
|
||||
|
||||
private static createDiagnosticFetch(
|
||||
serverName: string,
|
||||
config: MCPServerConfig,
|
||||
baseInit: RequestInit,
|
||||
targetUrl: URL,
|
||||
useProxy: boolean,
|
||||
onLog?: (log: MCPConnectionLog) => void
|
||||
): {
|
||||
fetch: typeof fetch;
|
||||
disable: () => void;
|
||||
} {
|
||||
let enabled = true;
|
||||
const logIfEnabled = (log: MCPConnectionLog) => {
|
||||
if (enabled) {
|
||||
onLog?.(log);
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
fetch: async (input, init) => {
|
||||
const startedAt = performance.now();
|
||||
const requestHeaders = new Headers(baseInit.headers);
|
||||
|
||||
if (typeof Request !== 'undefined' && input instanceof Request) {
|
||||
for (const [key, value] of input.headers.entries()) {
|
||||
requestHeaders.set(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
if (init?.headers) {
|
||||
for (const [key, value] of new Headers(init.headers).entries()) {
|
||||
requestHeaders.set(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
const request = this.createDiagnosticRequestDetails(
|
||||
input,
|
||||
init,
|
||||
baseInit,
|
||||
requestHeaders,
|
||||
Object.keys(config.headers ?? {})
|
||||
);
|
||||
const { method, url } = request;
|
||||
|
||||
logIfEnabled(
|
||||
this.createLog(
|
||||
MCPConnectionPhase.INITIALIZING,
|
||||
`HTTP ${method} ${url}`,
|
||||
MCPLogLevel.INFO,
|
||||
{
|
||||
serverName,
|
||||
request
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
try {
|
||||
const response = await fetch(input, {
|
||||
...baseInit,
|
||||
...init,
|
||||
headers: requestHeaders
|
||||
});
|
||||
const durationMs = Math.round(performance.now() - startedAt);
|
||||
|
||||
logIfEnabled(
|
||||
this.createLog(
|
||||
MCPConnectionPhase.INITIALIZING,
|
||||
`HTTP ${response.status} ${method} ${url} (${durationMs}ms)`,
|
||||
response.ok ? MCPLogLevel.INFO : MCPLogLevel.WARN,
|
||||
{
|
||||
response: {
|
||||
url,
|
||||
status: response.status,
|
||||
statusText: response.statusText,
|
||||
headers: sanitizeHeaders(response.headers, undefined, MCP_PARTIAL_REDACT_HEADERS),
|
||||
durationMs
|
||||
}
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
return response;
|
||||
} catch (error) {
|
||||
const durationMs = Math.round(performance.now() - startedAt);
|
||||
|
||||
logIfEnabled(
|
||||
this.createLog(
|
||||
MCPConnectionPhase.ERROR,
|
||||
`HTTP ${method} ${url} failed: ${formatDiagnosticErrorMessage(error)}`,
|
||||
MCPLogLevel.ERROR,
|
||||
{
|
||||
serverName,
|
||||
request,
|
||||
error: this.summarizeError(error),
|
||||
browser: this.getBrowserContext(targetUrl, useProxy),
|
||||
hints: this.getConnectionHints(targetUrl, config, error),
|
||||
durationMs
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
disable: () => {
|
||||
enabled = false;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect if an error indicates an expired/invalidated MCP session.
|
||||
* Per MCP spec 2025-11-25: HTTP 404 means session invalidated, client MUST
|
||||
@@ -113,9 +351,14 @@ export class MCPService {
|
||||
* @returns Object containing the created transport and the transport type used
|
||||
* @throws {Error} If url is missing, WebSocket + proxy combination, or all transports fail
|
||||
*/
|
||||
static createTransport(config: MCPServerConfig): {
|
||||
static createTransport(
|
||||
serverName: string,
|
||||
config: MCPServerConfig,
|
||||
onLog?: (log: MCPConnectionLog) => void
|
||||
): {
|
||||
transport: Transport;
|
||||
type: MCPTransportType;
|
||||
stopPhaseLogging: () => void;
|
||||
} {
|
||||
if (!config.url) {
|
||||
throw new Error('MCP server configuration is missing url');
|
||||
@@ -154,11 +397,20 @@ export class MCPService {
|
||||
|
||||
return {
|
||||
transport: new WebSocketClientTransport(url),
|
||||
type: MCPTransportType.WEBSOCKET
|
||||
type: MCPTransportType.WEBSOCKET,
|
||||
stopPhaseLogging: () => {}
|
||||
};
|
||||
}
|
||||
|
||||
const url = useProxy ? buildProxiedUrl(config.url) : new URL(config.url);
|
||||
const { fetch: diagnosticFetch, disable: stopPhaseLogging } = this.createDiagnosticFetch(
|
||||
serverName,
|
||||
config,
|
||||
requestInit,
|
||||
url,
|
||||
useProxy,
|
||||
onLog
|
||||
);
|
||||
|
||||
if (useProxy && import.meta.env.DEV) {
|
||||
console.log(`[MCPService] Using CORS proxy for ${config.url} -> ${url.href}`);
|
||||
@@ -171,17 +423,24 @@ export class MCPService {
|
||||
|
||||
return {
|
||||
transport: new StreamableHTTPClientTransport(url, {
|
||||
requestInit
|
||||
requestInit,
|
||||
fetch: diagnosticFetch
|
||||
}),
|
||||
type: MCPTransportType.STREAMABLE_HTTP
|
||||
type: MCPTransportType.STREAMABLE_HTTP,
|
||||
stopPhaseLogging
|
||||
};
|
||||
} catch (httpError) {
|
||||
console.warn(`[MCPService] StreamableHTTP failed, trying SSE transport...`, httpError);
|
||||
|
||||
try {
|
||||
return {
|
||||
transport: new SSEClientTransport(url, { requestInit }),
|
||||
type: MCPTransportType.SSE
|
||||
transport: new SSEClientTransport(url, {
|
||||
requestInit,
|
||||
fetch: diagnosticFetch,
|
||||
eventSourceInit: { fetch: diagnosticFetch }
|
||||
}),
|
||||
type: MCPTransportType.SSE,
|
||||
stopPhaseLogging
|
||||
};
|
||||
} catch (sseError) {
|
||||
const httpMsg = httpError instanceof Error ? httpError.message : String(httpError);
|
||||
@@ -263,7 +522,11 @@ export class MCPService {
|
||||
console.log(`[MCPService][${serverName}] Creating transport...`);
|
||||
}
|
||||
|
||||
const { transport, type: transportType } = this.createTransport(serverConfig);
|
||||
const {
|
||||
transport,
|
||||
type: transportType,
|
||||
stopPhaseLogging
|
||||
} = this.createTransport(serverName, serverConfig, (log) => onPhase?.(log.phase, log));
|
||||
|
||||
// Setup WebSocket reconnection handler
|
||||
if (transportType === MCPTransportType.WEBSOCKET) {
|
||||
@@ -294,6 +557,24 @@ export class MCPService {
|
||||
}
|
||||
);
|
||||
|
||||
const runtimeErrorHandler = (error: Error) => {
|
||||
console.error(`[MCPService][${serverName}] Protocol error after initialize:`, error);
|
||||
};
|
||||
|
||||
client.onerror = (error) => {
|
||||
onPhase?.(
|
||||
MCPConnectionPhase.ERROR,
|
||||
this.createLog(
|
||||
MCPConnectionPhase.ERROR,
|
||||
`Protocol error: ${error.message}`,
|
||||
MCPLogLevel.ERROR,
|
||||
{
|
||||
error: this.summarizeError(error)
|
||||
}
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
// Phase: Initializing
|
||||
onPhase?.(
|
||||
MCPConnectionPhase.INITIALIZING,
|
||||
@@ -301,7 +582,49 @@ export class MCPService {
|
||||
);
|
||||
|
||||
console.log(`[MCPService][${serverName}] Connecting to server...`);
|
||||
await client.connect(transport);
|
||||
try {
|
||||
await client.connect(transport);
|
||||
// Transport diagnostics are only for the initial handshake, not long-lived traffic.
|
||||
stopPhaseLogging();
|
||||
client.onerror = runtimeErrorHandler;
|
||||
} catch (error) {
|
||||
client.onerror = runtimeErrorHandler;
|
||||
const url =
|
||||
(serverConfig.useProxy ?? false)
|
||||
? buildProxiedUrl(serverConfig.url)
|
||||
: new URL(serverConfig.url);
|
||||
|
||||
onPhase?.(
|
||||
MCPConnectionPhase.ERROR,
|
||||
this.createLog(
|
||||
MCPConnectionPhase.ERROR,
|
||||
`Connection failed during initialize: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
MCPLogLevel.ERROR,
|
||||
{
|
||||
error: this.summarizeError(error),
|
||||
config: {
|
||||
serverName,
|
||||
configuredUrl: serverConfig.url,
|
||||
effectiveUrl: url.href,
|
||||
transportType,
|
||||
useProxy: serverConfig.useProxy ?? false,
|
||||
headers: sanitizeHeaders(
|
||||
serverConfig.headers,
|
||||
Object.keys(serverConfig.headers ?? {}),
|
||||
MCP_PARTIAL_REDACT_HEADERS
|
||||
),
|
||||
credentials: serverConfig.credentials
|
||||
},
|
||||
browser: this.getBrowserContext(url, serverConfig.useProxy ?? false),
|
||||
hints: this.getConnectionHints(url, serverConfig, error)
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
throw error;
|
||||
}
|
||||
|
||||
const serverVersion = client.getServerVersion();
|
||||
const serverCapabilities = client.getServerCapabilities();
|
||||
|
||||
@@ -130,6 +130,12 @@ export const SYNCABLE_PARAMETERS: SyncableParameter[] = [
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'titleGenerationUseFirstLine',
|
||||
serverKey: 'titleGenerationUseFirstLine',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'disableAutoScroll',
|
||||
serverKey: 'disableAutoScroll',
|
||||
|
||||
@@ -30,7 +30,8 @@ import {
|
||||
findDescendantMessages,
|
||||
findLeafNode,
|
||||
findMessageById,
|
||||
isAbortError
|
||||
isAbortError,
|
||||
generateConversationTitle
|
||||
} from '$lib/utils';
|
||||
import {
|
||||
MAX_INACTIVE_CONVERSATION_STATES,
|
||||
@@ -504,7 +505,10 @@ class ChatStore {
|
||||
allExtras
|
||||
);
|
||||
if (isNewConversation && content)
|
||||
await conversationsStore.updateConversationName(currentConv.id, content.trim());
|
||||
await conversationsStore.updateConversationName(
|
||||
currentConv.id,
|
||||
generateConversationTitle(content, Boolean(config().titleGenerationUseFirstLine))
|
||||
);
|
||||
const assistantMessage = await this.createAssistantMessage(userMessage.id);
|
||||
conversationsStore.addMessageToActive(assistantMessage);
|
||||
await this.streamChatCompletion(
|
||||
@@ -896,7 +900,7 @@ class ChatStore {
|
||||
if (isFirstUserMessage && newContent.trim())
|
||||
await conversationsStore.updateConversationTitleWithConfirmation(
|
||||
activeConv.id,
|
||||
newContent.trim()
|
||||
generateConversationTitle(newContent, Boolean(config().titleGenerationUseFirstLine))
|
||||
);
|
||||
const messagesToRemove = conversationsStore.activeMessages.slice(messageIndex + 1);
|
||||
for (const message of messagesToRemove) await DatabaseService.deleteMessage(message.id);
|
||||
@@ -1317,7 +1321,7 @@ class ChatStore {
|
||||
if (rootMessage && msg.parent === rootMessage.id && newContent.trim()) {
|
||||
await conversationsStore.updateConversationTitleWithConfirmation(
|
||||
activeConv.id,
|
||||
newContent.trim()
|
||||
generateConversationTitle(newContent, Boolean(config().titleGenerationUseFirstLine))
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1391,7 +1395,7 @@ class ChatStore {
|
||||
if (isFirstUserMessage && newContent.trim())
|
||||
await conversationsStore.updateConversationTitleWithConfirmation(
|
||||
activeConv.id,
|
||||
newContent.trim()
|
||||
generateConversationTitle(newContent, Boolean(config().titleGenerationUseFirstLine))
|
||||
);
|
||||
await conversationsStore.refreshActiveMessages();
|
||||
if (msg.role === MessageRole.USER)
|
||||
|
||||
@@ -23,7 +23,12 @@ import { browser } from '$app/environment';
|
||||
import { toast } from 'svelte-sonner';
|
||||
import { DatabaseService } from '$lib/services/database.service';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { filterByLeafNodeId, findLeafNode, runLegacyMigration } from '$lib/utils';
|
||||
import {
|
||||
filterByLeafNodeId,
|
||||
findLeafNode,
|
||||
runLegacyMigration,
|
||||
generateConversationTitle
|
||||
} from '$lib/utils';
|
||||
import type { McpServerOverride } from '$lib/types/database';
|
||||
import { MessageRole } from '$lib/enums';
|
||||
import {
|
||||
@@ -548,7 +553,10 @@ class ConversationsStore {
|
||||
) {
|
||||
await this.updateConversationTitleWithConfirmation(
|
||||
this.activeConversation.id,
|
||||
newFirstUserMessage.content.trim()
|
||||
generateConversationTitle(
|
||||
newFirstUserMessage.content,
|
||||
Boolean(config().titleGenerationUseFirstLine)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1460,12 +1460,14 @@ class MCPStore {
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : 'Unknown error occurred';
|
||||
|
||||
logs.push({
|
||||
timestamp: new Date(),
|
||||
phase: MCPConnectionPhase.ERROR,
|
||||
message: `Connection failed: ${message}`,
|
||||
level: MCPLogLevel.ERROR
|
||||
});
|
||||
if (logs.at(-1)?.phase !== MCPConnectionPhase.ERROR) {
|
||||
logs.push({
|
||||
timestamp: new Date(),
|
||||
phase: MCPConnectionPhase.ERROR,
|
||||
message: `Connection failed: ${message}`,
|
||||
level: MCPLogLevel.ERROR
|
||||
});
|
||||
}
|
||||
|
||||
this.updateHealthCheck(server.id, {
|
||||
status: HealthCheckStatus.ERROR,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { REDACTED_HEADERS } from '$lib/constants';
|
||||
import { redactValue } from './redact';
|
||||
|
||||
/**
|
||||
* Get authorization headers for API requests
|
||||
@@ -20,3 +22,46 @@ export function getJsonHeaders(): Record<string, string> {
|
||||
...getAuthHeaders()
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Sanitize HTTP headers by redacting sensitive values.
|
||||
* Known sensitive headers (from REDACTED_HEADERS) and any extra headers
|
||||
* specified by the caller are fully redacted. Headers listed in
|
||||
* `partialRedactHeaders` are partially redacted, showing only the
|
||||
* specified number of trailing characters.
|
||||
*
|
||||
* @param headers - Headers to sanitize
|
||||
* @param extraRedactedHeaders - Additional header names to fully redact
|
||||
* @param partialRedactHeaders - Map of header name -> number of trailing chars to keep visible
|
||||
* @returns Object with header names as keys and (possibly redacted) values
|
||||
*/
|
||||
export function sanitizeHeaders(
|
||||
headers?: HeadersInit,
|
||||
extraRedactedHeaders?: Iterable<string>,
|
||||
partialRedactHeaders?: Map<string, number>
|
||||
): Record<string, string> {
|
||||
if (!headers) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const normalized = new Headers(headers);
|
||||
const sanitized: Record<string, string> = {};
|
||||
const redactedHeaders = new Set(
|
||||
Array.from(extraRedactedHeaders ?? [], (header) => header.toLowerCase())
|
||||
);
|
||||
|
||||
for (const [key, value] of normalized.entries()) {
|
||||
const normalizedKey = key.toLowerCase();
|
||||
const partialChars = partialRedactHeaders?.get(normalizedKey);
|
||||
|
||||
if (partialChars !== undefined) {
|
||||
sanitized[key] = redactValue(value, partialChars);
|
||||
} else if (REDACTED_HEADERS.has(normalizedKey) || redactedHeaders.has(normalizedKey)) {
|
||||
sanitized[key] = redactValue(value);
|
||||
} else {
|
||||
sanitized[key] = value;
|
||||
}
|
||||
}
|
||||
|
||||
return sanitized;
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
*/
|
||||
|
||||
// API utilities
|
||||
export { getAuthHeaders, getJsonHeaders } from './api-headers';
|
||||
export { getAuthHeaders, getJsonHeaders, sanitizeHeaders } from './api-headers';
|
||||
export { apiFetch, apiFetchWithParams, apiPost, type ApiFetchOptions } from './api-fetch';
|
||||
export { validateApiKey } from './api-key-validation';
|
||||
|
||||
@@ -55,7 +55,7 @@ export {
|
||||
|
||||
// File preview utilities
|
||||
export { getFileTypeLabel } from './file-preview';
|
||||
export { getPreviewText } from './text';
|
||||
export { getPreviewText, generateConversationTitle } from './text';
|
||||
|
||||
// File type utilities
|
||||
export {
|
||||
@@ -164,6 +164,20 @@ export { runLegacyMigration, isMigrationNeeded } from './legacy-migration';
|
||||
// Cache utilities
|
||||
export { TTLCache, ReactiveTTLMap, type TTLCacheOptions } from './cache-ttl';
|
||||
|
||||
// Redaction utilities
|
||||
export { redactValue } from './redact';
|
||||
|
||||
// Request inspection utilities
|
||||
export {
|
||||
getRequestUrl,
|
||||
getRequestMethod,
|
||||
getRequestBody,
|
||||
summarizeRequestBody,
|
||||
formatDiagnosticErrorMessage,
|
||||
extractJsonRpcMethods,
|
||||
type RequestBodySummary
|
||||
} from './request-helpers';
|
||||
|
||||
// Abort signal utilities
|
||||
export {
|
||||
throwIfAborted,
|
||||
|
||||
14
tools/server/webui/src/lib/utils/redact.ts
Normal file
14
tools/server/webui/src/lib/utils/redact.ts
Normal file
@@ -0,0 +1,14 @@
|
||||
/**
|
||||
* Redacts a sensitive value, optionally showing the last N characters.
|
||||
*
|
||||
* @param value - The value to redact
|
||||
* @param showLastChars - If provided, reveals the last N characters with a leading mask
|
||||
* @returns The redacted string
|
||||
*/
|
||||
export function redactValue(value: string, showLastChars?: number): string {
|
||||
if (showLastChars) {
|
||||
return `....${value.slice(-showLastChars)}`;
|
||||
}
|
||||
|
||||
return '[redacted]';
|
||||
}
|
||||
111
tools/server/webui/src/lib/utils/request-helpers.ts
Normal file
111
tools/server/webui/src/lib/utils/request-helpers.ts
Normal file
@@ -0,0 +1,111 @@
|
||||
/**
|
||||
* HTTP request inspection utilities for diagnostic logging.
|
||||
* These helpers extract metadata from fetch-style request arguments
|
||||
* without exposing sensitive payload data.
|
||||
*/
|
||||
|
||||
export interface RequestBodySummary {
|
||||
kind: string;
|
||||
size?: number;
|
||||
}
|
||||
|
||||
export function getRequestUrl(input: RequestInfo | URL): string {
|
||||
if (typeof input === 'string') {
|
||||
return input;
|
||||
}
|
||||
|
||||
if (input instanceof URL) {
|
||||
return input.href;
|
||||
}
|
||||
|
||||
return input.url;
|
||||
}
|
||||
|
||||
export function getRequestMethod(
|
||||
input: RequestInfo | URL,
|
||||
init?: RequestInit,
|
||||
baseInit?: RequestInit
|
||||
): string {
|
||||
if (init?.method) {
|
||||
return init.method;
|
||||
}
|
||||
|
||||
if (typeof Request !== 'undefined' && input instanceof Request) {
|
||||
return input.method;
|
||||
}
|
||||
|
||||
return baseInit?.method ?? 'GET';
|
||||
}
|
||||
|
||||
export function getRequestBody(
|
||||
input: RequestInfo | URL,
|
||||
init?: RequestInit
|
||||
): BodyInit | null | undefined {
|
||||
if (init?.body !== undefined) {
|
||||
return init.body;
|
||||
}
|
||||
|
||||
if (typeof Request !== 'undefined' && input instanceof Request) {
|
||||
return input.body;
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
export function summarizeRequestBody(body: BodyInit | null | undefined): RequestBodySummary {
|
||||
if (body == null) {
|
||||
return { kind: 'empty' };
|
||||
}
|
||||
|
||||
if (typeof body === 'string') {
|
||||
return { kind: 'string', size: body.length };
|
||||
}
|
||||
|
||||
if (body instanceof Blob) {
|
||||
return { kind: 'blob', size: body.size };
|
||||
}
|
||||
|
||||
if (body instanceof URLSearchParams) {
|
||||
return { kind: 'urlsearchparams', size: body.toString().length };
|
||||
}
|
||||
|
||||
if (body instanceof FormData) {
|
||||
return { kind: 'formdata' };
|
||||
}
|
||||
|
||||
if (body instanceof ArrayBuffer) {
|
||||
return { kind: 'arraybuffer', size: body.byteLength };
|
||||
}
|
||||
|
||||
if (ArrayBuffer.isView(body)) {
|
||||
return { kind: body.constructor.name, size: body.byteLength };
|
||||
}
|
||||
|
||||
return { kind: typeof body };
|
||||
}
|
||||
|
||||
export function formatDiagnosticErrorMessage(error: unknown): string {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
|
||||
return message.includes('Failed to fetch') ? `${message} (check CORS?)` : message;
|
||||
}
|
||||
|
||||
export function extractJsonRpcMethods(body: BodyInit | null | undefined): string[] | undefined {
|
||||
if (typeof body !== 'string') {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(body);
|
||||
const messages = Array.isArray(parsed) ? parsed : [parsed];
|
||||
const methods = messages
|
||||
.map((message: Record<string, unknown>) =>
|
||||
typeof message?.method === 'string' ? (message.method as string) : undefined
|
||||
)
|
||||
.filter((method: string | undefined): method is string => Boolean(method));
|
||||
|
||||
return methods.length > 0 ? methods : undefined;
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
import { NEWLINE_SEPARATOR } from '$lib/constants';
|
||||
|
||||
/**
|
||||
* Returns a shortened preview of the provided content capped at the given length.
|
||||
* Appends an ellipsis when the content exceeds the maximum.
|
||||
@@ -5,3 +7,16 @@
|
||||
export function getPreviewText(content: string, max = 150): string {
|
||||
return content.length > max ? content.slice(0, max) + '...' : content;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a single-line title from a potentially multi-line prompt.
|
||||
* Uses the first non-empty line if `useFirstLine` is true.
|
||||
*/
|
||||
export function generateConversationTitle(content: string, useFirstLine: boolean = false): string {
|
||||
if (useFirstLine) {
|
||||
const firstLine = content.split(NEWLINE_SEPARATOR).find((line) => line.trim().length > 0);
|
||||
return firstLine ? firstLine.trim() : content.trim();
|
||||
}
|
||||
|
||||
return content.trim();
|
||||
}
|
||||
|
||||
252
tools/server/webui/tests/unit/mcp-service.test.ts
Normal file
252
tools/server/webui/tests/unit/mcp-service.test.ts
Normal file
@@ -0,0 +1,252 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||
import { Client } from '@modelcontextprotocol/sdk/client';
|
||||
import { MCPService } from '$lib/services/mcp.service';
|
||||
import { MCPConnectionPhase, MCPTransportType } from '$lib/enums';
|
||||
import type { MCPConnectionLog, MCPServerConfig } from '$lib/types';
|
||||
|
||||
type DiagnosticFetchFactory = (
|
||||
serverName: string,
|
||||
config: MCPServerConfig,
|
||||
baseInit: RequestInit,
|
||||
targetUrl: URL,
|
||||
useProxy: boolean,
|
||||
onLog?: (log: MCPConnectionLog) => void
|
||||
) => { fetch: typeof fetch; disable: () => void };
|
||||
|
||||
const createDiagnosticFetch = (
|
||||
config: MCPServerConfig,
|
||||
onLog?: (log: MCPConnectionLog) => void,
|
||||
baseInit: RequestInit = {}
|
||||
) =>
|
||||
(
|
||||
MCPService as unknown as { createDiagnosticFetch: DiagnosticFetchFactory }
|
||||
).createDiagnosticFetch('test-server', config, baseInit, new URL(config.url), false, onLog);
|
||||
|
||||
describe('MCPService', () => {
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it('stops transport phase logging after handshake diagnostics are disabled', async () => {
|
||||
const logs: MCPConnectionLog[] = [];
|
||||
const response = new Response('{}', {
|
||||
status: 200,
|
||||
headers: { 'content-type': 'application/json' }
|
||||
});
|
||||
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue(response));
|
||||
|
||||
const config: MCPServerConfig = {
|
||||
url: 'https://example.com/mcp',
|
||||
transport: MCPTransportType.STREAMABLE_HTTP
|
||||
};
|
||||
|
||||
const controller = createDiagnosticFetch(config, (log) => logs.push(log));
|
||||
|
||||
await controller.fetch(config.url, { method: 'POST', body: '{}' });
|
||||
expect(logs).toHaveLength(2);
|
||||
expect(logs.every((log) => log.message.includes('https://example.com/mcp'))).toBe(true);
|
||||
|
||||
controller.disable();
|
||||
await controller.fetch(config.url, { method: 'POST', body: '{}' });
|
||||
|
||||
expect(logs).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('redacts all configured custom headers in diagnostic request logs', async () => {
|
||||
const logs: MCPConnectionLog[] = [];
|
||||
const response = new Response('{}', {
|
||||
status: 200,
|
||||
headers: { 'content-type': 'application/json' }
|
||||
});
|
||||
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue(response));
|
||||
|
||||
const config: MCPServerConfig = {
|
||||
url: 'https://example.com/mcp',
|
||||
transport: MCPTransportType.STREAMABLE_HTTP,
|
||||
headers: {
|
||||
'x-auth-token': 'secret-token',
|
||||
'x-vendor-api-key': 'secret-key'
|
||||
}
|
||||
};
|
||||
|
||||
const controller = createDiagnosticFetch(config, (log) => logs.push(log), {
|
||||
headers: config.headers
|
||||
});
|
||||
|
||||
await controller.fetch(config.url, {
|
||||
method: 'POST',
|
||||
headers: { 'content-type': 'application/json' },
|
||||
body: '{}'
|
||||
});
|
||||
|
||||
expect(logs).toHaveLength(2);
|
||||
expect(logs[0].details).toMatchObject({
|
||||
request: {
|
||||
headers: {
|
||||
'x-auth-token': '[redacted]',
|
||||
'x-vendor-api-key': '[redacted]',
|
||||
'content-type': 'application/json'
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
it('partially redacts mcp-session-id in diagnostic request and response logs', async () => {
|
||||
const logs: MCPConnectionLog[] = [];
|
||||
const response = new Response('{}', {
|
||||
status: 200,
|
||||
headers: {
|
||||
'content-type': 'application/json',
|
||||
'mcp-session-id': 'session-response-67890'
|
||||
}
|
||||
});
|
||||
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue(response));
|
||||
|
||||
const config: MCPServerConfig = {
|
||||
url: 'https://example.com/mcp',
|
||||
transport: MCPTransportType.STREAMABLE_HTTP
|
||||
};
|
||||
|
||||
const controller = createDiagnosticFetch(config, (log) => logs.push(log));
|
||||
|
||||
await controller.fetch(config.url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'content-type': 'application/json',
|
||||
'mcp-session-id': 'session-request-12345'
|
||||
},
|
||||
body: '{}'
|
||||
});
|
||||
|
||||
expect(logs).toHaveLength(2);
|
||||
expect(logs[0].details).toMatchObject({
|
||||
request: {
|
||||
headers: {
|
||||
'content-type': 'application/json',
|
||||
'mcp-session-id': '....12345'
|
||||
}
|
||||
}
|
||||
});
|
||||
expect(logs[1].details).toMatchObject({
|
||||
response: {
|
||||
headers: {
|
||||
'content-type': 'application/json',
|
||||
'mcp-session-id': '....67890'
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
it('extracts JSON-RPC methods without logging the raw request body', async () => {
|
||||
const logs: MCPConnectionLog[] = [];
|
||||
const response = new Response('{}', {
|
||||
status: 200,
|
||||
headers: { 'content-type': 'application/json' }
|
||||
});
|
||||
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue(response));
|
||||
|
||||
const config: MCPServerConfig = {
|
||||
url: 'https://example.com/mcp',
|
||||
transport: MCPTransportType.STREAMABLE_HTTP
|
||||
};
|
||||
|
||||
const controller = createDiagnosticFetch(config, (log) => logs.push(log));
|
||||
|
||||
await controller.fetch(config.url, {
|
||||
method: 'POST',
|
||||
body: JSON.stringify([
|
||||
{ jsonrpc: '2.0', id: 1, method: 'initialize' },
|
||||
{ jsonrpc: '2.0', method: 'notifications/initialized' }
|
||||
])
|
||||
});
|
||||
|
||||
expect(logs[0].details).toMatchObject({
|
||||
request: {
|
||||
method: 'POST',
|
||||
body: {
|
||||
kind: 'string',
|
||||
size: expect.any(Number)
|
||||
},
|
||||
jsonRpcMethods: ['initialize', 'notifications/initialized']
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
it('adds a CORS hint to Failed to fetch diagnostic log messages', async () => {
|
||||
const logs: MCPConnectionLog[] = [];
|
||||
const fetchError = new TypeError('Failed to fetch');
|
||||
|
||||
vi.stubGlobal('fetch', vi.fn().mockRejectedValue(fetchError));
|
||||
|
||||
const config: MCPServerConfig = {
|
||||
url: 'http://localhost:8000/mcp',
|
||||
transport: MCPTransportType.STREAMABLE_HTTP
|
||||
};
|
||||
|
||||
const controller = createDiagnosticFetch(config, (log) => logs.push(log));
|
||||
|
||||
await expect(controller.fetch(config.url, { method: 'POST', body: '{}' })).rejects.toThrow(
|
||||
'Failed to fetch'
|
||||
);
|
||||
|
||||
expect(logs).toHaveLength(2);
|
||||
expect(logs[1].message).toBe(
|
||||
'HTTP POST http://localhost:8000/mcp failed: Failed to fetch (check CORS?)'
|
||||
);
|
||||
});
|
||||
|
||||
it('detaches phase error logging after the initialize handshake completes', async () => {
|
||||
const phaseLogs: Array<{ phase: MCPConnectionPhase; log: MCPConnectionLog }> = [];
|
||||
const stopPhaseLogging = vi.fn();
|
||||
let emitClientError: ((error: Error) => void) | undefined;
|
||||
|
||||
vi.spyOn(MCPService, 'createTransport').mockReturnValue({
|
||||
transport: {} as never,
|
||||
type: MCPTransportType.WEBSOCKET,
|
||||
stopPhaseLogging
|
||||
});
|
||||
vi.spyOn(MCPService, 'listTools').mockResolvedValue([]);
|
||||
vi.spyOn(Client.prototype, 'getServerVersion').mockReturnValue(undefined);
|
||||
vi.spyOn(Client.prototype, 'getServerCapabilities').mockReturnValue(undefined);
|
||||
vi.spyOn(Client.prototype, 'getInstructions').mockReturnValue(undefined);
|
||||
vi.spyOn(Client.prototype, 'connect').mockImplementation(async function (this: Client) {
|
||||
emitClientError = (error: Error) => this.onerror?.(error);
|
||||
this.onerror?.(new Error('handshake protocol error'));
|
||||
});
|
||||
|
||||
await MCPService.connect(
|
||||
'test-server',
|
||||
{
|
||||
url: 'ws://example.com/mcp',
|
||||
transport: MCPTransportType.WEBSOCKET
|
||||
},
|
||||
undefined,
|
||||
undefined,
|
||||
(phase, log) => phaseLogs.push({ phase, log })
|
||||
);
|
||||
|
||||
expect(stopPhaseLogging).toHaveBeenCalledTimes(1);
|
||||
expect(
|
||||
phaseLogs.filter(
|
||||
({ phase, log }) =>
|
||||
phase === MCPConnectionPhase.ERROR &&
|
||||
log.message === 'Protocol error: handshake protocol error'
|
||||
)
|
||||
).toHaveLength(1);
|
||||
|
||||
emitClientError?.(new Error('runtime protocol error'));
|
||||
|
||||
expect(
|
||||
phaseLogs.filter(
|
||||
({ phase, log }) =>
|
||||
phase === MCPConnectionPhase.ERROR &&
|
||||
log.message === 'Protocol error: runtime protocol error'
|
||||
)
|
||||
).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
tools/server/webui/tests/unit/redact.test.ts — new file, 20 lines
@@ -0,0 +1,20 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import { redactValue } from '$lib/utils/redact';
|
||||
|
||||
describe('redactValue', () => {
|
||||
it('returns [redacted] by default', () => {
|
||||
expect(redactValue('secret-token')).toBe('[redacted]');
|
||||
});
|
||||
|
||||
it('shows last N characters when showLastChars is provided', () => {
|
||||
expect(redactValue('session-abc12', 5)).toBe('....abc12');
|
||||
});
|
||||
|
||||
it('handles value shorter than showLastChars', () => {
|
||||
expect(redactValue('ab', 5)).toBe('....ab');
|
||||
});
|
||||
|
||||
it('returns [redacted] when showLastChars is 0', () => {
|
||||
expect(redactValue('secret', 0)).toBe('[redacted]');
|
||||
});
|
||||
});
|
||||
tools/server/webui/tests/unit/request-helpers.test.ts — new file, 124 lines
@@ -0,0 +1,124 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import {
|
||||
getRequestUrl,
|
||||
getRequestMethod,
|
||||
getRequestBody,
|
||||
summarizeRequestBody,
|
||||
formatDiagnosticErrorMessage,
|
||||
extractJsonRpcMethods
|
||||
} from '$lib/utils/request-helpers';
|
||||
|
||||
describe('getRequestUrl', () => {
|
||||
it('returns a plain string input as-is', () => {
|
||||
expect(getRequestUrl('https://example.com/mcp')).toBe('https://example.com/mcp');
|
||||
});
|
||||
|
||||
it('returns href from a URL object', () => {
|
||||
expect(getRequestUrl(new URL('https://example.com/mcp'))).toBe('https://example.com/mcp');
|
||||
});
|
||||
|
||||
it('returns url from a Request object', () => {
|
||||
const req = new Request('https://example.com/mcp');
|
||||
expect(getRequestUrl(req)).toBe('https://example.com/mcp');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getRequestMethod', () => {
|
||||
it('prefers method from init', () => {
|
||||
expect(getRequestMethod('https://example.com', { method: 'POST' })).toBe('POST');
|
||||
});
|
||||
|
||||
it('falls back to Request.method', () => {
|
||||
const req = new Request('https://example.com', { method: 'PUT' });
|
||||
expect(getRequestMethod(req)).toBe('PUT');
|
||||
});
|
||||
|
||||
it('falls back to baseInit.method', () => {
|
||||
expect(getRequestMethod('https://example.com', undefined, { method: 'DELETE' })).toBe('DELETE');
|
||||
});
|
||||
|
||||
it('defaults to GET', () => {
|
||||
expect(getRequestMethod('https://example.com')).toBe('GET');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getRequestBody', () => {
|
||||
it('returns body from init', () => {
|
||||
expect(getRequestBody('https://example.com', { body: 'payload' })).toBe('payload');
|
||||
});
|
||||
|
||||
it('returns undefined when no body is present', () => {
|
||||
expect(getRequestBody('https://example.com')).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('summarizeRequestBody', () => {
|
||||
it('returns empty for null', () => {
|
||||
expect(summarizeRequestBody(null)).toEqual({ kind: 'empty' });
|
||||
});
|
||||
|
||||
it('returns empty for undefined', () => {
|
||||
expect(summarizeRequestBody(undefined)).toEqual({ kind: 'empty' });
|
||||
});
|
||||
|
||||
it('returns string kind with size', () => {
|
||||
expect(summarizeRequestBody('hello')).toEqual({ kind: 'string', size: 5 });
|
||||
});
|
||||
|
||||
it('returns blob kind with size', () => {
|
||||
const blob = new Blob(['abc']);
|
||||
expect(summarizeRequestBody(blob)).toEqual({ kind: 'blob', size: 3 });
|
||||
});
|
||||
|
||||
it('returns formdata kind', () => {
|
||||
expect(summarizeRequestBody(new FormData())).toEqual({ kind: 'formdata' });
|
||||
});
|
||||
|
||||
it('returns arraybuffer kind with size', () => {
|
||||
expect(summarizeRequestBody(new ArrayBuffer(8))).toEqual({ kind: 'arraybuffer', size: 8 });
|
||||
});
|
||||
});
|
||||
|
||||
describe('formatDiagnosticErrorMessage', () => {
|
||||
it('appends CORS hint for Failed to fetch', () => {
|
||||
expect(formatDiagnosticErrorMessage(new TypeError('Failed to fetch'))).toBe(
|
||||
'Failed to fetch (check CORS?)'
|
||||
);
|
||||
});
|
||||
|
||||
it('passes through other error messages unchanged', () => {
|
||||
expect(formatDiagnosticErrorMessage(new Error('timeout'))).toBe('timeout');
|
||||
});
|
||||
|
||||
it('handles non-Error values', () => {
|
||||
expect(formatDiagnosticErrorMessage('some string')).toBe('some string');
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractJsonRpcMethods', () => {
|
||||
it('extracts methods from a JSON-RPC array', () => {
|
||||
const body = JSON.stringify([
|
||||
{ jsonrpc: '2.0', id: 1, method: 'initialize' },
|
||||
{ jsonrpc: '2.0', method: 'notifications/initialized' }
|
||||
]);
|
||||
expect(extractJsonRpcMethods(body)).toEqual(['initialize', 'notifications/initialized']);
|
||||
});
|
||||
|
||||
it('extracts method from a single JSON-RPC message', () => {
|
||||
const body = JSON.stringify({ jsonrpc: '2.0', id: 1, method: 'tools/list' });
|
||||
expect(extractJsonRpcMethods(body)).toEqual(['tools/list']);
|
||||
});
|
||||
|
||||
it('returns undefined for non-string body', () => {
|
||||
expect(extractJsonRpcMethods(null)).toBeUndefined();
|
||||
expect(extractJsonRpcMethods(undefined)).toBeUndefined();
|
||||
});
|
||||
|
||||
it('returns undefined for invalid JSON', () => {
|
||||
expect(extractJsonRpcMethods('not json')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('returns undefined when no methods found', () => {
|
||||
expect(extractJsonRpcMethods(JSON.stringify({ foo: 'bar' }))).toBeUndefined();
|
||||
});
|
||||
});
|
||||
tools/server/webui/tests/unit/sanitize-headers.test.ts — new file, 55 lines
@@ -0,0 +1,55 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import { sanitizeHeaders } from '$lib/utils/api-headers';
|
||||
|
||||
describe('sanitizeHeaders', () => {
|
||||
it('returns empty object for undefined input', () => {
|
||||
expect(sanitizeHeaders()).toEqual({});
|
||||
});
|
||||
|
||||
it('passes through non-sensitive headers', () => {
|
||||
const headers = new Headers({ 'content-type': 'application/json', accept: 'text/html' });
|
||||
expect(sanitizeHeaders(headers)).toEqual({
|
||||
'content-type': 'application/json',
|
||||
accept: 'text/html'
|
||||
});
|
||||
});
|
||||
|
||||
it('redacts known sensitive headers', () => {
|
||||
const headers = new Headers({
|
||||
authorization: 'Bearer secret',
|
||||
'x-api-key': 'key-123',
|
||||
'content-type': 'application/json'
|
||||
});
|
||||
const result = sanitizeHeaders(headers);
|
||||
expect(result.authorization).toBe('[redacted]');
|
||||
expect(result['x-api-key']).toBe('[redacted]');
|
||||
expect(result['content-type']).toBe('application/json');
|
||||
});
|
||||
|
||||
it('partially redacts headers specified in partialRedactHeaders', () => {
|
||||
const headers = new Headers({ 'mcp-session-id': 'session-12345' });
|
||||
const partial = new Map([['mcp-session-id', 5]]);
|
||||
expect(sanitizeHeaders(headers, undefined, partial)['mcp-session-id']).toBe('....12345');
|
||||
});
|
||||
|
||||
it('fully redacts mcp-session-id when no partialRedactHeaders is given', () => {
|
||||
const headers = new Headers({ 'mcp-session-id': 'session-12345' });
|
||||
expect(sanitizeHeaders(headers)['mcp-session-id']).toBe('[redacted]');
|
||||
});
|
||||
|
||||
it('redacts extra headers provided by the caller', () => {
|
||||
const headers = new Headers({
|
||||
'x-vendor-key': 'vendor-secret',
|
||||
'content-type': 'application/json'
|
||||
});
|
||||
const result = sanitizeHeaders(headers, ['x-vendor-key']);
|
||||
expect(result['x-vendor-key']).toBe('[redacted]');
|
||||
expect(result['content-type']).toBe('application/json');
|
||||
});
|
||||
|
||||
it('handles case-insensitive extra header names', () => {
|
||||
const headers = new Headers({ 'X-Custom-Token': 'token-value' });
|
||||
const result = sanitizeHeaders(headers, ['X-CUSTOM-TOKEN']);
|
||||
expect(result['x-custom-token']).toBe('[redacted]');
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user